runner.go

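// runner.go is a minimal HTTP server that streams tokens generated by a
// llama.cpp model through the github.com/ollama/ollama/llama bindings.
//
// Usage sketch (assumes the file is built inside the ollama source tree so
// the cgo llama bindings resolve; the model path and prompt below are
// illustrative, not part of the original):
//
//	go run runner.go -model /path/to/model.gguf
//	curl -s http://127.0.0.1:8080 -d '{"prompt":"Why is the sky blue?"}'
//
// Each generated token is written back as a JSON object on its own line,
// e.g. {"token":" blue"}, until an end-of-sequence token or the context
// limit is reached.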
package main

import (
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"net"
	"net/http"
	"sync"

	"github.com/ollama/ollama/llama"
)
// Request is the JSON body accepted by the server.
type Request struct {
	Prompt string `json:"prompt"`
}

// Response carries a single generated token back to the client.
type Response struct {
	Token string `json:"token"`
}

// Server holds the loaded model and its inference context.
type Server struct {
	model *llama.Model
	lc    *llama.Context
}

// mu serializes access to the llama context, which is not safe for
// concurrent use across requests.
var mu sync.Mutex
// stream handles a single request: it tokenizes the prompt, decodes one
// batch at a time, and streams each sampled token back as a JSON object.
func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
	var request Request
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	// Tokenize the prompt before writing any headers so a failure can
	// still be reported as a proper error response.
	tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
	if err != nil {
		http.Error(w, "Failed to tokenize prompt", http.StatusInternalServerError)
		return
	}

	// Set the headers to indicate streaming.
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")
	w.WriteHeader(http.StatusOK)

	enc := json.NewEncoder(w)
	flusher, _ := w.(http.Flusher)

	// Prompt evaluation: load every prompt token into the batch.
	batch := llama.NewBatch(512, 0, 1)
	for i, t := range tokens {
		batch.Add(t, llama.Pos(i), []llama.SeqId{0}, true)
	}

	// Main generation loop: decode the batch, sample the next token,
	// stream it to the client, then feed it back in for the next step.
	for n := batch.NumTokens(); n < 2048; n++ {
		mu.Lock()
		err = s.lc.Decode(batch)
		if err != nil {
			mu.Unlock()
			log.Println("Failed to decode:", err)
			return
		}

		// Sample the most likely next token.
		token := s.lc.SampleTokenGreedy(batch)
		mu.Unlock()

		// Stop on an end-of-sequence token.
		if s.model.TokenIsEog(token) {
			break
		}

		// Convert the token to text and stream it to the client.
		str := s.model.TokenToPiece(token)
		if err := enc.Encode(&Response{Token: str}); err != nil {
			log.Println("Failed to encode result:", err)
			return
		}
		if flusher != nil {
			flusher.Flush()
		}

		// The next decode only needs the newly sampled token.
		batch.Clear()
		batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
	}
}
func main() {
	mp := flag.String("model", "", "Path to model binary file")
	flag.Parse()
	if *mp == "" {
		log.Fatal("-model is required")
	}

	// Load the model and create an inference context.
	llama.BackendInit()
	params := llama.NewModelParams()
	model := llama.LoadModelFromFile(*mp, params)
	if model == nil {
		panic("Failed to load model")
	}
	ctxParams := llama.NewContextParams()
	lc := llama.NewContextWithModel(model, ctxParams)
	if lc == nil {
		panic("Failed to create context")
	}

	server := &Server{
		model: model,
		lc:    lc,
	}

	addr := "127.0.0.1:8080"
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return
	}
	defer listener.Close()

	// Every request is handled by the same streaming handler.
	httpServer := http.Server{
		Handler: http.HandlerFunc(server.stream),
	}
	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
	}
}