runner.go

package main

import (
    "context"
    "encoding/json"
    "flag"
    "fmt"
    "log"
    "log/slog"
    "math"
    "net"
    "net/http"
    "runtime"
    "strconv"
    "strings"
    "sync"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/llama"
)

type Sequence struct {
    // number of tokens evaluated
    nPast int

    // tokens left to evaluate
    tokens []int

    // channel to send responses over
    responses chan string

    // number of tokens to predict
    numPredict int

    samplingCtx *llama.SamplingContext

    // channel to send back the embedding if embedding only
    embedding chan []float32

    // stop sequences
    stop []string

    // true if an embedding should be returned instead of generated text
    embeddingOnly bool

    doneReason string
}

// prompt returns true if the prompt is still being processed
func (s *Sequence) prompt() bool {
    return s.nPast < len(s.tokens)-1
}

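// NewSequence tokenizes the prompt, primes the sampling context with the
// prompt tokens, and returns a Sequence ready to be picked up by run().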
func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
    tokens, err := s.lc.Model().Tokenize(prompt, false, true)
    if err != nil {
        panic(err)
    }

    // truncate to fit the context window (keeps only the first numCtx tokens)
    // TODO: this shouldn't happen and will severely impact generation
    // quality. instead we should ensure to cut the prompt in the API.
    if len(tokens) > s.numCtx {
        tokens = tokens[:s.numCtx]
    }

    var sc *llama.SamplingContext
    if params != nil {
        sc = llama.NewSamplingContext(*params)
        for _, t := range tokens {
            sc.Accept(s.lc, t, false)
        }
    }

    return &Sequence{
        tokens:        tokens,
        responses:     make(chan string, 1),
        embedding:     make(chan []float32, 1),
        samplingCtx:   sc,
        embeddingOnly: embedding,
        stop:          stop,
    }
}

type Server struct {
    model *llama.Model
    lc    *llama.Context
    cc    *llama.ClipContext

    batchSize int

    // parallel is the number of parallel requests to handle
    parallel int

    // seqs is the list of parallel sequences being evaluated
    // TODO (jmorganca): this can probably be moved into run()
    seqs []*Sequence

    // context window size
    numCtx int

    mu   sync.Mutex
    cond *sync.Cond

    progress float32
    status   string
}

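// allNil reports whether every sequence slot is currently free.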
func (s *Server) allNil() bool {
    for _, item := range s.seqs {
        if item != nil {
            return false
        }
    }
    return true
}

func findStop(sequence string, stops []string) (bool, string) {
    for _, stop := range stops {
        if strings.Contains(sequence, stop) {
            return true, stop
        }
    }
    return false, ""
}

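// containsStopSuffix reports whether sequence ends with a prefix of any stop
// string, i.e. a stop sequence may still be in the middle of being generated.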
func containsStopSuffix(sequence string, stops []string) bool {
    for _, stop := range stops {
        for i := 1; i <= len(stop); i++ {
            if strings.HasSuffix(sequence, stop[:i]) {
                return true
            }
        }
    }
    return false
}

// truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating
// the last piece if required
func truncateStop(pieces []string, stop string) []string {
    joined := strings.Join(pieces, "")

    index := strings.Index(joined, stop)
    if index == -1 {
        return pieces
    }

    joined = joined[:index]

    // Split truncated string back into pieces of original lengths
    lengths := make([]int, len(pieces))
    for i, piece := range pieces {
        lengths[i] = len(piece)
    }

    var result []string
    start := 0
    for _, length := range lengths {
        if start >= len(joined) {
            break
        }

        end := start + length
        if end > len(joined) {
            end = len(joined)
        }

        result = append(result, joined[start:end])
        start = end
    }

    return result
}

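// run is the main decode loop: it gathers pending tokens from every active
// sequence into a single batch, decodes the batch, then samples a token per
// sequence, checks for stop conditions, and streams pieces back over each
// sequence's response channel.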
func (s *Server) run(ctx context.Context) {
    batch := llama.NewBatch(s.batchSize, 0, s.parallel)
    defer batch.Free()

    // build up stop sequences as we recognize them
    // TODO (jmorganca): simplify this
    pieces := make([][]string, s.parallel)

    for {
        select {
        case <-ctx.Done():
            return
        default:
            slog.Info("Processing batch", "seqs", len(s.seqs))
            s.mu.Lock()
            for s.allNil() {
                s.cond.Wait() // Wait until an item is added
            }
            s.mu.Unlock()

            // prepare the batch
            ibatch := make([]int, s.parallel)
            for i, seq := range s.seqs {
                if seq == nil {
                    continue
                }

                // we've reached the context limit
                if seq.nPast > s.numCtx {
                    seq.doneReason = "limit"
                    close(seq.responses)
                    s.lc.KvCacheSeqRm(i, 0, -1)
                    s.seqs[i] = nil
                    continue
                }

                for j, t := range seq.tokens {
                    // todo: make this n_batch
                    if j > s.batchSize {
                        break
                    }

                    batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
                    seq.nPast++

                    if seq.prompt() {
                        ibatch[i] = batch.NumTokens() + 1
                    }
                }
            }

            err := s.lc.Decode(batch)
            if err != nil {
                panic("Failed to decode")
            }

            for i, seq := range s.seqs {
                if seq == nil {
                    continue
                }

                // don't sample while the prompt is still being processed
                if seq.prompt() {
                    if len(seq.tokens) < s.batchSize {
                        seq.tokens = []int{}
                    } else {
                        seq.tokens = seq.tokens[s.batchSize:]
                    }

                    continue
                }

                // if we're done processing the prompt, generate the embedding and return
                if seq.embeddingOnly {
                    embd := s.lc.GetEmbeddingsSeq(i)
                    if embd == nil {
                        embd = s.lc.GetEmbeddingsIth(ibatch[i])
                    }

                    seq.embedding <- embd
                    close(seq.embedding)
                    s.lc.KvCacheSeqRm(i, 0, -1)
                    s.seqs[i] = nil
                    continue
                }

                // sample a token
                // logits := s.lc.GetLogitsIth(ibatch[i])
                // token := s.lc.SampleTokenGreedy(logits)
                token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
                seq.samplingCtx.Accept(s.lc, token, true)

                piece := s.model.TokenToPiece(token)
                slog.Info("sampled", "piece", piece)

                // if it's an end-of-sequence token, end this sequence
                if s.model.TokenIsEog(token) {
                    s.lc.KvCacheSeqRm(i, 0, -1)
                    // TODO (jmorganca): we should send this back
                    // as it's important for the /api/generate context
                    // seq.responses <- piece
                    seq.doneReason = "stop"
                    close(seq.responses)
                    seq.samplingCtx.Free()
                    pieces[i] = []string{}
                    s.seqs[i] = nil
                    continue
                }

                seq.tokens = []int{token}

                pieces[i] = append(pieces[i], piece)
                sequence := strings.Join(pieces[i], "")
                if ok, stop := findStop(sequence, seq.stop); ok {
                    slog.Info("hit stop token", "stop", seq.stop)

                    truncated := truncateStop(pieces[i], stop)
                    for _, p := range truncated {
                        seq.responses <- p
                    }

                    s.lc.KvCacheSeqRm(i, 0, -1)
                    seq.doneReason = "stop"
                    close(seq.responses)
                    seq.samplingCtx.Free()
                    pieces[i] = []string{}
                    s.seqs[i] = nil
                    continue
                }

                if containsStopSuffix(sequence, seq.stop) {
                    continue
                }

                for _, p := range pieces[i] {
                    seq.responses <- p
                }

                pieces[i] = []string{}
            }

            batch.Clear()
        }
    }
}

type CompletionRequest struct {
    Prompt  string   `json:"prompt"`
    Images  []string `json:"images"`
    Grammar string   `json:"grammar"`
    Stop    []string `json:"stop"`

    api.Options
}

type CompletionResponse struct {
    Token string `json:"token"`
}

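// completion streams generated tokens back to the client as a sequence of
// JSON objects, one per sampled piece. For example, a request body such as
//
//    {"prompt": "why is the sky blue?", "stop": ["\n\n"], "temperature": 0.8}
//
// (other api.Options fields can be set the same way) produces a stream of
// {"token": "..."} objects until a stop condition is reached.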
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
    var req CompletionRequest
    req.Options = api.DefaultOptions()
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        http.Error(w, "Bad request", http.StatusBadRequest)
        return
    }

    // Set the headers to indicate streaming
    w.Header().Set("Content-Type", "application/json")
    w.Header().Set("Transfer-Encoding", "chunked")
    w.WriteHeader(http.StatusOK)

    var samplingParams llama.SamplingParams
    samplingParams.TopK = req.TopK
    samplingParams.TopP = req.TopP
    samplingParams.TfsZ = req.TFSZ
    samplingParams.TypicalP = req.TypicalP
    samplingParams.Temp = req.Temperature
    samplingParams.PenaltyRepeat = req.RepeatPenalty
    samplingParams.PenaltyFreq = req.FrequencyPenalty
    samplingParams.PenaltyPresent = req.PresencePenalty
    samplingParams.Mirostat = req.Mirostat
    samplingParams.MirostatTau = req.MirostatTau
    samplingParams.MirostatEta = req.MirostatEta
    samplingParams.PenalizeNl = req.PenalizeNewline
    samplingParams.Seed = uint32(req.Seed)
    samplingParams.Grammar = req.Grammar

    seq := s.NewSequence(req.Prompt, req.Stop, &samplingParams, false)

    // TODO (jmorganca): add to sequence queue instead of
    // failing if a slot isn't available
    s.mu.Lock()
    for i, sq := range s.seqs {
        if sq == nil {
            s.seqs[i] = seq
            s.cond.Signal()
            break
        }
    }
    s.mu.Unlock()

    // stream the response
    for token := range seq.responses {
        if err := json.NewEncoder(w).Encode(&CompletionResponse{
            Token: token,
        }); err != nil {
            log.Println("Failed to encode result:", err)
            return
        }

        flusher, ok := w.(http.Flusher)
        if !ok {
            http.Error(w, "Streaming not supported", http.StatusInternalServerError)
            return
        }

        flusher.Flush()
    }
}

type EmbeddingRequest struct {
    Prompt string `json:"prompt"`
}

type EmbeddingResponse struct {
    Embedding []float32 `json:"embedding"`
}

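// embeddings runs the prompt through the model in embedding-only mode and
// returns the embedding vector as JSON, e.g. {"embedding": [0.1, -0.2, ...]}
// for a request body of {"prompt": "..."}.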
// TODO (jmorganca): is it safe to do this concurrently with decoding?
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
    var req EmbeddingRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        http.Error(w, "Bad request", http.StatusBadRequest)
        return
    }

    w.Header().Set("Content-Type", "application/json")

    seq := s.NewSequence(req.Prompt, nil, nil, true)

    s.mu.Lock()
    for i, sq := range s.seqs {
        if sq == nil {
            s.seqs[i] = seq
            s.cond.Signal()
            break
        }
    }
    s.mu.Unlock()

    embedding := <-seq.embedding

    if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
        Embedding: embedding,
    }); err != nil {
        log.Println("Failed to encode result:", err)
        return
    }
}

type HealthResponse struct {
    Status   string  `json:"status"`
    Progress float32 `json:"progress"`
}

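// health reports the server's current status ("loading" or "ready") and the
// model load progress so callers can poll for readiness.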
// TODO (jmorganca): is it safe to do this concurrently with decoding?
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "application/json")
    if err := json.NewEncoder(w).Encode(&HealthResponse{
        Status:   s.status,
        Progress: s.progress,
    }); err != nil {
        log.Println("Failed to encode result:", err)
        return
    }
}

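// A typical invocation (binary name illustrative), using the flags defined
// below:
//
//    ./runner -model /path/to/model.gguf -port 8080 -parallel 2 -num-ctx 4096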
func main() {
    mpath := flag.String("model", "", "Path to model binary file")
    ppath := flag.String("projector", "", "Path to projector binary file")
    parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
    batchSize := flag.Int("batch-size", 512, "Batch size")
    nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
    mainGpu := flag.Int("main-gpu", 0, "Main GPU")
    flashAttention := flag.Bool("flash-attention", false, "Enable flash attention")
    numCtx := flag.Int("num-ctx", 2048, "Context (or KV cache) size")
    lpath := flag.String("lora", "", "Path to lora layer file")
    port := flag.Int("port", 8080, "Port to expose the server on")
    threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
    flag.Parse()

    server := &Server{
        numCtx:    *numCtx,
        batchSize: *batchSize,
        parallel:  *parallel,
        seqs:      make([]*Sequence, *parallel),
        status:    "loading",
    }

    // load the model
    llama.BackendInit()
    params := llama.NewModelParams(*nGpuLayers, *mainGpu, func(progress float32) {
        slog.Info("Loading model", "progress %", math.Round(float64(progress*100)))
        server.progress = progress
    })
    server.model = llama.LoadModelFromFile(*mpath, params)

    if *lpath != "" {
        server.model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
    }

    ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
    server.lc = llama.NewContextWithModel(server.model, ctxParams)

    if *ppath != "" {
        server.cc = llama.NewClipContext(*ppath)
    }

    server.cond = sync.NewCond(&server.mu)

    ctx, cancel := context.WithCancel(context.Background())
    go server.run(ctx)

    addr := "127.0.0.1:" + strconv.Itoa(*port)
    listener, err := net.Listen("tcp", addr)
    if err != nil {
        fmt.Println("Listen error:", err)
        return
    }
    defer listener.Close()

    mux := http.NewServeMux()
    mux.HandleFunc("/embeddings", server.embeddings)
    mux.HandleFunc("/completion", server.completion)
    mux.HandleFunc("/health", server.health)

    httpServer := http.Server{
        Handler: mux,
    }

    server.status = "ready"

    log.Println("Server listening on", addr)
    if err := httpServer.Serve(listener); err != nil {
        log.Fatal("server error:", err)
    }

    cancel()
}