runner.go

package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"log/slog"
	"net"
	"net/http"
	"strconv"
	"sync"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llama"
)
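
// Sequence tracks the decoding state of a single request: how much of the
// prompt has been evaluated, which tokens are still pending, and the channel
// the generated text streams back on.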
type Sequence struct {
	// nPast is the number of tokens evaluated so far
	nPast int
	// nPrompt is the total number of prompt tokens; unlike tokens it is not
	// consumed during batching, so prompt() stays correct for prompts longer
	// than one batch
	nPrompt int
	// tokens left to evaluate
	tokens []int
	// responses streams generated text back to the HTTP handler
	responses chan string
	// samplingCtx holds this sequence's sampling state
	samplingCtx *llama.SamplingContext
}

// prompt returns true if the prompt is still being processed. The last
// prompt token is treated like a generated token so that its logits are
// computed and the first response token can be sampled from them.
func (s *Sequence) prompt() bool {
	return s.nPast < s.nPrompt-1
}
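
// DefaultParams returns zero-valued sampling parameters; it is currently
// unused, as NewSequence builds its parameters from the request instead.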
func DefaultParams() llama.SamplingParams {
	return llama.SamplingParams{}
}
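
// NewSequence tokenizes the prompt, builds a sampling context from the
// request's sampling options, and primes it with the prompt tokens.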
func (s *Server) NewSequence(r Request) *Sequence {
	samplingParams := llama.SamplingParams{
		TopK:           r.TopK,
		TopP:           r.TopP,
		TfsZ:           r.TFSZ,
		TypicalP:       r.TypicalP,
		Temp:           r.Temperature,
		PenaltyRepeat:  r.RepeatPenalty,
		PenaltyFreq:    r.FrequencyPenalty,
		PenaltyPresent: r.PresencePenalty,
		Mirostat:       r.Mirostat,
		MirostatTau:    r.MirostatTau,
		MirostatEta:    r.MirostatEta,
		PenalizeNl:     r.PenalizeNewline,
		Seed:           uint32(r.Seed),
	}

	// TODO: use the model's context size instead of a hard-coded limit
	tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true)
	if err != nil {
		panic(err)
	}

	// prime the sampling context with the prompt tokens
	sc := llama.NewSamplingContext(samplingParams)
	for _, t := range tokens {
		sc.Accept(s.lc, t, false)
	}

	return &Sequence{
		nPrompt:     len(tokens),
		tokens:      tokens,
		responses:   make(chan string, 1),
		samplingCtx: sc,
	}
}
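
// Server multiplexes incoming requests onto a fixed number of sequence
// slots that are decoded together in a single batch.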
type Server struct {
	model *llama.Model
	lc    *llama.Context
	// cc is the clip context for multimodal (projector) models, if any
	cc *llama.ClipContext
	// parallel is the number of parallel requests to handle
	parallel int
	// seqs is the list of parallel sequences being evaluated
	seqs []*Sequence
	// mu guards seqs; cond is signaled when a sequence slot is filled
	mu   sync.Mutex
	cond *sync.Cond
}
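
// allNil reports whether no sequences are currently active.
// The caller must hold s.mu.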
func (s *Server) allNil() bool {
	for _, item := range s.seqs {
		if item != nil {
			return false
		}
	}
	return true
}
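
// run is the main decode loop: it gathers pending tokens from every active
// sequence into a single batch, decodes it, and samples the next token for
// each sequence that has finished processing its prompt.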
func (s *Server) run(ctx context.Context) {
	// TODO: size the batch from n_batch instead of hard-coding 512
	batch := llama.NewBatch(512, 0, s.parallel)
	defer batch.Free()

	for {
		select {
		case <-ctx.Done():
			return
		default:
			s.mu.Lock()
			for s.allNil() {
				s.cond.Wait() // wait until a sequence is added
			}
			s.mu.Unlock()

			slog.Debug("processing batch", "seqs", len(s.seqs))

			// prepare the batch: gather pending tokens from every active
			// sequence and record where each sequence's logits will land
			ibatch := make([]int, s.parallel)
			for i, seq := range s.seqs {
				if seq == nil {
					continue
				}

				for j, t := range seq.tokens {
					// TODO: make this n_batch
					if j >= 512 {
						break
					}

					// only the last prompt token and generated tokens need logits
					batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
					seq.nPast++

					// remember the batch index of the newest token for this
					// sequence; its logits are what we sample from after decoding
					ibatch[i] = batch.NumTokens() - 1
				}
			}

			if err := s.lc.Decode(batch); err != nil {
				panic(fmt.Sprintf("failed to decode batch: %v", err))
			}

			for i, seq := range s.seqs {
				if seq == nil {
					continue
				}

				// don't sample during prompt processing; drop the tokens
				// that were just evaluated and continue with the next batch
				if seq.prompt() {
					if len(seq.tokens) < 512 {
						seq.tokens = []int{}
					} else {
						seq.tokens = seq.tokens[512:]
					}
					continue
				}

				// sample a token
				// TODO: sample based on the sequence
				token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
				seq.samplingCtx.Accept(s.lc, token, true)
				slog.Debug("sampled", "token", token, "piece", s.model.TokenToPiece(token))

				// alternative: greedy sampling directly from the logits
				// logits := s.lc.GetLogitsIth(ibatch[i])
				// token := s.lc.SampleTokenGreedy(logits)

				// on an end-of-generation token, release the slot without
				// streaming the token's text
				if s.model.TokenIsEog(token) {
					s.lc.KvCacheSeqRm(i, 0, -1)
					close(seq.responses)
					s.seqs[i] = nil
					continue
				}

				seq.responses <- s.model.TokenToPiece(token)
				seq.tokens = []int{token}
			}

			batch.Clear()
		}
	}
}
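
// Request is the JSON body accepted by the server. It embeds api.Options so
// the standard Ollama sampling options can be set per request.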
type Request struct {
	Prompt string   `json:"prompt"`
	Images []string `json:"images"`

	api.Options
}

// Response carries the text of a single sampled token.
type Response struct {
	Token string `json:"token"`
}
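
// handler decodes a Request, claims a free sequence slot, and streams one
// Response object per generated token using chunked transfer encoding.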
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
	// apply the default options, then let the request body override them
	var request Request
	request.Options = api.DefaultOptions()
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	// streaming requires a flushable response writer; check before any
	// headers are written
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	// set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")
	w.WriteHeader(http.StatusOK)

	seq := s.NewSequence(request)

	// claim a free sequence slot and wake the decode loop
	// TODO: queue or reject the request when every slot is busy
	s.mu.Lock()
	for i, sq := range s.seqs {
		if sq == nil {
			s.seqs[i] = seq
			s.cond.Signal()
			break
		}
	}
	s.mu.Unlock()

	for token := range seq.responses {
		if err := json.NewEncoder(w).Encode(&Response{
			Token: token,
		}); err != nil {
			log.Println("Failed to encode result:", err)
			return
		}
		flusher.Flush()
	}
}

func main() {
	mpath := flag.String("model", "", "Path to model binary file")
	ppath := flag.String("projector", "", "Path to projector binary file")
	parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
	port := flag.Int("port", 8080, "Port to expose the server on")
	flag.Parse()

	// load the model
	llama.BackendInit()
	params := llama.NewModelParams()
	model := llama.LoadModelFromFile(*mpath, params)
	ctxParams := llama.NewContextParams()
	lc := llama.NewContextWithModel(model, ctxParams)
	if lc == nil {
		panic("Failed to create context")
	}

	// optionally load a clip context for multimodal (projector) models
	var cc *llama.ClipContext
	if *ppath != "" {
		cc = llama.NewClipContext(*ppath)
		if cc == nil {
			panic("Failed to create clip context")
		}
	}

	server := &Server{
		model:    model,
		lc:       lc,
		cc:       cc,
		parallel: *parallel,
		seqs:     make([]*Sequence, *parallel),
	}
	server.cond = sync.NewCond(&server.mu)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		log.Fatal("Listen error:", err)
	}
	defer listener.Close()

	httpServer := http.Server{
		Handler: http.HandlerFunc(server.handler),
	}

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
	}
}
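
// Example usage (the model path and prompt below are placeholders):
//
//	go run . -model /path/to/model.gguf -parallel 2 -port 8080
//	curl -s http://127.0.0.1:8080/ -d '{"prompt": "why is the sky blue?"}'
//
// The server answers any path; each line of the chunked response is a JSON
// object of the form {"token": "..."}.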