main.go

package main

import (
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"net"
	"net/http"
	"sync"

	"github.com/ollama/ollama/llama"
)

// Request is the JSON body of a generation request.
type Request struct {
	Prompt string `json:"prompt"`
}

// Response is a single streamed token.
type Response struct {
	Token string `json:"token"`
}

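// For example, the body {"prompt": "Why is the sky blue?"} yields a
// stream of newline-delimited objects such as {"token": "The"} followed
// by {"token": " sky"}; exact token boundaries depend on the model.
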
// Server owns the loaded model, a shared llama context, and the queue of
// sequences waiting to be processed.
type Server struct {
	model *llama.Model
	lc    *llama.Context
	batch *llama.Batch
	queue chan Sequence
	seqs  []*Sequence

	// mu serializes decode and sampling on the shared context and batch
	mu sync.Mutex
}

// Sequence is one in-flight generation: the tokenized prompt going in and
// a channel of decoded token strings coming back out. The processor closes
// out when the sequence finishes.
type Sequence struct {
	prompt []llama.Token
	out    chan string
}

// schedule sketches the parallel scheduler: keep up to parallel sequences
// active at once, and when one finishes, admit the next from the queue.
// It is left unimplemented in this prototype, which instead processes the
// queue serially in (*Server).process.
func schedule(parallel int, queue <-chan Sequence) {
	// Fill the active sequence slots from the queue; once a sequence
	// finishes, remove it and pull a new one from the queue.
}

// process drains the queue serially: for each sequence it fills a batch
// with the prompt, then repeatedly decodes, samples a token, and sends
// the token text back on the sequence's out channel.
func (s *Server) process() {
	for seq := range s.queue {
		s.batch.Clear()

		// prompt eval: add each prompt token to the batch at its position
		for i, t := range seq.prompt {
			s.batch.Add(t, llama.Pos(i), []llama.SeqId{0}, true)
		}

		// main loop: decode the batch, sample a token, respond, then feed
		// the sampled token back in as the only input of the next step
		for n := s.batch.NumTokens(); n < 2048; n++ {
			s.mu.Lock()
			if err := s.lc.Decode(s.batch); err != nil {
				panic("Failed to decode")
			}
			// sample the most likely token from the last logits
			token := s.lc.SampleTokenGreedy(s.batch)
			s.mu.Unlock()

			// an end of sequence token means the sequence is finished
			if s.model.TokenIsEog(token) {
				break
			}

			// send the token text to the waiting request handler
			seq.out <- s.model.TokenToPiece(token)

			s.batch.Clear()
			s.batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
		}

		// closing out tells the handler this sequence has finished
		close(seq.out)
	}
}

// stream handles a single generation request: it enqueues the tokenized
// prompt and streams sampled tokens back to the client as they arrive.
func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
	var request Request
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")
	w.WriteHeader(http.StatusOK)

	tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
	if err != nil {
		panic(err)
	}

	// out is unbuffered, so generation blocks until the previous token
	// has been written and flushed: a simple form of backpressure
	seq := Sequence{prompt: tokens, out: make(chan string)}
	s.queue <- seq

	// stream tokens until the processor closes out to signal completion
	enc := json.NewEncoder(w)
	for str := range seq.out {
		if err := enc.Encode(&Response{Token: str}); err != nil {
			log.Println("Failed to encode result:", err)
			return
		}
		w.(http.Flusher).Flush()
	}
}

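// Note: with a single process worker the mutex above is not strictly
// required; it is kept so that decode and sampling stay one atomic step
// if several workers ever share lc and batch in parallel.
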
func main() {
	mp := flag.String("model", "", "Path to model binary file")
	parallel := flag.Int("parallel", 1, "Number of parallel requests to handle")
	flag.Parse()

	// load the model
	llama.BackendInit()
	params := llama.NewModelParams()
	model := llama.LoadModelFromFile(*mp, params)
	ctxParams := llama.NewContextParams()
	lc := llama.NewContextWithModel(model, ctxParams)
	if lc == nil {
		panic("Failed to create context")
	}

	server := &Server{
		model: model,
		lc:    lc,
		// batch sizing is an assumption here: room for 512 tokens, no
		// embeddings, one sequence; adjust to the bindings' constructor
		batch: llama.NewBatch(512, 0, 1),
		queue: make(chan Sequence, 256),
		seqs:  make([]*Sequence, *parallel),
	}

	// start the worker that drains the queue and generates tokens
	go server.process()

	addr := "127.0.0.1:8080"
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return
	}
	defer listener.Close()

	httpServer := http.Server{
		Handler: http.HandlerFunc(server.stream),
	}

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
	}
}
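
// Example usage (the model path is illustrative):
//
//	go run . -model ./model.gguf
//	curl -N -d '{"prompt": "Why is the sky blue?"}' http://127.0.0.1:8080/
//
// Each line of the streamed response is a JSON object such as
// {"token": " The"}.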