runner.go

package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"log/slog"
	"net"
	"net/http"
	"runtime"
	"strconv"
	"strings"
	"sync"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llama"
)

type Sequence struct {
	// number of tokens evaluated
	nPast int

	// tokens left to evaluate
	tokens []int

	// channel to send responses over
	responses chan string

	samplingCtx *llama.SamplingContext

	// channel to send back the embedding if embedding only
	embedding chan []float32

	// stop sequences
	stop []string

	// true if an embedding is to be returned instead of text generation
	embeddingOnly bool
}

// prompt returns true if the prompt is still being processed
func (s *Sequence) prompt() bool {
	return s.nPast < len(s.tokens)-1
}

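// NewSequence tokenizes the prompt and returns a Sequence ready to be
// scheduled. If sampling params are provided, a sampling context is
// created and seeded with the prompt tokens.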
func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
	tokens, err := s.lc.Model().Tokenize(prompt, 2048, false, true)
	if err != nil {
		panic(err)
	}

	var sc *llama.SamplingContext
	if params != nil {
		sc = llama.NewSamplingContext(*params)
		for _, t := range tokens {
			sc.Accept(s.lc, t, false)
		}
	}

	return &Sequence{
		tokens:        tokens,
		responses:     make(chan string, 1),
		embedding:     make(chan []float32, 1),
		samplingCtx:   sc,
		embeddingOnly: embedding,
		stop:          stop,
	}
}

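// Server wraps the loaded model and its contexts and tracks the parallel
// sequences currently being evaluated.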
type Server struct {
	model *llama.Model
	lc    *llama.Context
	cc    *llama.ClipContext

	batchSize int

	// parallel is the number of parallel requests to handle
	parallel int

	// seqs is the list of parallel sequences being evaluated
	// TODO (jmorganca): this can probably be moved into run()
	seqs []*Sequence

	mu sync.Mutex

	cond *sync.Cond
}

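// allNil reports whether every sequence slot is currently empty.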
func (s *Server) allNil() bool {
	for _, item := range s.seqs {
		if item != nil {
			return false
		}
	}

	return true
}

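// findStop reports whether any of the stop strings appears in sequence,
// returning the stop string that matched.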
func findStop(sequence string, stops []string) (bool, string) {
	for _, stop := range stops {
		if strings.Contains(sequence, stop) {
			return true, stop
		}
	}

	return false, ""
}

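// containsStopSuffix reports whether sequence ends with a prefix of any
// stop string, meaning a stop sequence may still be in the process of
// being generated.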
func containsStopSuffix(sequence string, stops []string) bool {
	for _, stop := range stops {
		for i := 1; i <= len(stop); i++ {
			if strings.HasSuffix(sequence, stop[:i]) {
				return true
			}
		}
	}

	return false
}

// truncateStop removes the provided stop string from pieces, returning
// the remaining pieces and truncating the final piece if required.
func truncateStop(pieces []string, stop string) []string {
	joined := strings.Join(pieces, "")

	index := strings.Index(joined, stop)
	if index == -1 {
		return pieces
	}

	joined = joined[:index]

	// Split the truncated string back into pieces of their original lengths
	lengths := make([]int, len(pieces))
	for i, piece := range pieces {
		lengths[i] = len(piece)
	}

	var result []string
	start := 0
	for _, length := range lengths {
		if start >= len(joined) {
			break
		}

		end := start + length
		if end > len(joined) {
			end = len(joined)
		}

		result = append(result, joined[start:end])
		start = end
	}

	return result
}

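// run is the main decode loop: it batches tokens from every active
// sequence, decodes the batch, then either collects an embedding or
// samples the next token for each sequence until ctx is cancelled.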
func (s *Server) run(ctx context.Context) {
	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
	defer batch.Free()

	// build up stop sequences as we recognize them
	// TODO (jmorganca): simplify this
	pieces := make([][]string, s.parallel)

	for {
		select {
		case <-ctx.Done():
			return
		default:
			slog.Info("Processing batch", "seqs", len(s.seqs))
			s.mu.Lock()
			for s.allNil() {
				s.cond.Wait() // Wait until an item is added
			}
			s.mu.Unlock()

			// prepare the batch
			ibatch := make([]int, s.parallel)
			for i, seq := range s.seqs {
				if seq == nil {
					continue
				}

				for j, t := range seq.tokens {
					// todo: make this n_batch
					if j > s.batchSize {
						break
					}

					batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
					seq.nPast++

					if seq.prompt() {
						ibatch[i] = batch.NumTokens() + 1
					}
				}
			}

			err := s.lc.Decode(batch)
			if err != nil {
				panic("Failed to decode")
			}

			for i, seq := range s.seqs {
				if seq == nil {
					continue
				}

				// don't sample prompt processing
				if seq.prompt() {
					if len(seq.tokens) < s.batchSize {
						seq.tokens = []int{}
					} else {
						seq.tokens = seq.tokens[s.batchSize:]
					}

					continue
				}

				// if we're done processing the prompt, generate an embedding and return it
				if seq.embeddingOnly {
					embd := s.lc.GetEmbeddingsSeq(i)
					if embd == nil {
						embd = s.lc.GetEmbeddingsIth(ibatch[i])
					}

					seq.embedding <- embd
					close(seq.embedding)
					s.lc.KvCacheSeqRm(i, 0, -1)
					s.seqs[i] = nil
					continue
				}

				// sample a token
				// logits := s.lc.GetLogitsIth(ibatch[i])
				// token := s.lc.SampleTokenGreedy(logits)
				token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
				seq.samplingCtx.Accept(s.lc, token, true)

				piece := s.model.TokenToPiece(token)
				slog.Info("sampled", "piece", piece)

				// if it's an end of sequence token, break
				// TODO: just end this sequence
				if s.model.TokenIsEog(token) {
					// TODO: end the sequence instead of quitting the pool
					s.lc.KvCacheSeqRm(i, 0, -1)

					// TODO (jmorganca): we should send this back
					// as it's important for the /api/generate context
					// seq.responses <- piece
					close(seq.responses)
					seq.samplingCtx.Free()
					pieces[i] = []string{}
					s.seqs[i] = nil
					continue
				}

				seq.tokens = []int{token}

				pieces[i] = append(pieces[i], piece)
				sequence := strings.Join(pieces[i], "")
				if ok, stop := findStop(sequence, seq.stop); ok {
					slog.Info("hit stop token", "stop", seq.stop)

					truncated := truncateStop(pieces[i], stop)
					for _, p := range truncated {
						seq.responses <- p
					}

					s.lc.KvCacheSeqRm(i, 0, -1)
					close(seq.responses)
					seq.samplingCtx.Free()
					pieces[i] = []string{}
					s.seqs[i] = nil
					continue
				}

				if containsStopSuffix(sequence, seq.stop) {
					continue
				}

				for _, p := range pieces[i] {
					seq.responses <- p
				}

				pieces[i] = []string{}
			}

			batch.Clear()
		}
	}
}

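// CompletionRequest is the JSON body accepted by the /completion endpoint.
// Any api.Options fields may be supplied alongside the prompt.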
type CompletionRequest struct {
	Prompt  string   `json:"prompt"`
	Images  []string `json:"images"`
	Grammar string   `json:"grammar"`
	Stop    []string `json:"stop"`

	api.Options
}

type CompletionResponse struct {
	Token string `json:"token"`
}

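// completion decodes a CompletionRequest, schedules a new sequence, and
// streams the sampled tokens back as chunked JSON responses.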
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
	var req CompletionRequest
	req.Options = api.DefaultOptions()
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")
	w.WriteHeader(http.StatusOK)

	var samplingParams llama.SamplingParams
	samplingParams.TopK = req.TopK
	samplingParams.TopP = req.TopP
	samplingParams.TfsZ = req.TFSZ
	samplingParams.TypicalP = req.TypicalP
	samplingParams.Temp = req.Temperature
	samplingParams.PenaltyRepeat = req.RepeatPenalty
	samplingParams.PenaltyFreq = req.FrequencyPenalty
	samplingParams.PenaltyPresent = req.PresencePenalty
	samplingParams.Mirostat = req.Mirostat
	samplingParams.MirostatTau = req.MirostatTau
	samplingParams.MirostatEta = req.MirostatEta
	samplingParams.PenalizeNl = req.PenalizeNewline
	samplingParams.Seed = uint32(req.Seed)
	samplingParams.Grammar = req.Grammar

	seq := s.NewSequence(req.Prompt, req.Stop, &samplingParams, false)

	// TODO (jmorganca): add to sequence queue instead of
	// failing if a slot isn't available
	s.mu.Lock()
	for i, sq := range s.seqs {
		if sq == nil {
			s.seqs[i] = seq
			s.cond.Signal()
			break
		}
	}
	s.mu.Unlock()

	// stream the response
	for token := range seq.responses {
		if err := json.NewEncoder(w).Encode(&CompletionResponse{
			Token: token,
		}); err != nil {
			log.Println("Failed to encode result:", err)
			return
		}

		flusher, ok := w.(http.Flusher)
		if !ok {
			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
			return
		}

		flusher.Flush()
	}
}

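// EmbeddingRequest is the JSON body accepted by the /embeddings endpoint.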
type EmbeddingRequest struct {
	Prompt string `json:"prompt"`
}

type EmbeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}

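// embeddings schedules an embedding-only sequence for the prompt and
// returns the resulting vector as JSON.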
// TODO (jmorganca): is it safe to do this concurrently with decoding?
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
	var req EmbeddingRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	seq := s.NewSequence(req.Prompt, nil, nil, true)

	s.mu.Lock()
	for i, sq := range s.seqs {
		if sq == nil {
			s.seqs[i] = seq
			s.cond.Signal()
			break
		}
	}
	s.mu.Unlock()

	embedding := <-seq.embedding

	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
		Embedding: embedding,
	}); err != nil {
		log.Println("Failed to encode result:", err)
		return
	}
}

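// main loads the model (plus optional projector and LoRA), starts the
// decode loop, and serves the HTTP API.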
func main() {
	mpath := flag.String("model", "", "Path to model binary file")
	ppath := flag.String("projector", "", "Path to projector binary file")
	parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
	batchSize := flag.Int("batch-size", 512, "Batch size")
	nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
	mainGpu := flag.Int("main-gpu", 0, "Main GPU")
	flashAttention := flag.Bool("flash-attention", false, "Enable flash attention")
	numCtx := flag.Int("num-ctx", 2048, "Context (or KV cache) size")
	lpath := flag.String("lora", "", "Path to lora layer file")
	port := flag.Int("port", 8080, "Port to expose the server on")
	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
	flag.Parse()

	// load the model
	llama.BackendInit()
	params := llama.NewModelParams(*nGpuLayers, *mainGpu)
	model := llama.LoadModelFromFile(*mpath, params)

	if *lpath != "" {
		model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
	}

	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
	lc := llama.NewContextWithModel(model, ctxParams)
	if lc == nil {
		panic("Failed to create context")
	}

	var cc *llama.ClipContext
	if *ppath != "" {
		cc = llama.NewClipContext(*ppath)
		if cc == nil {
			panic("Failed to create clip context")
		}
	}

	server := &Server{
		model:     model,
		lc:        lc,
		cc:        cc,
		batchSize: *batchSize,
		parallel:  *parallel,
		seqs:      make([]*Sequence, *parallel),
	}

	server.cond = sync.NewCond(&server.mu)

	ctx, cancel := context.WithCancel(context.Background())
	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return
	}
	defer listener.Close()

	mux := http.NewServeMux()
	mux.HandleFunc("/embeddings", server.embeddings)
	mux.HandleFunc("/completion", server.completion)

	httpServer := http.Server{
		Handler: mux,
	}

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
	}

	cancel()
}