routes.go
package server

import (
	"embed"
	"encoding/json"
	"errors"
	"io"
	"log"
	"math"
	"net"
	"net/http"
	"os"
	"path"
	"runtime"
	"strings"
	"text/template"

	"github.com/gin-gonic/gin"
	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llama"
	"github.com/lithammer/fuzzysearch/fuzzy"
)

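// Prompt templates are embedded into the binary at build time; each
// *.prompt file under templates/ becomes a named text/template.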
//go:embed templates/*
var templatesFS embed.FS

var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))

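// cacheDir returns the directory where models are cached on disk,
// currently ~/.ollama.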
func cacheDir() string {
	home, err := os.UserHomeDir()
	if err != nil {
		panic(err)
	}

	return path.Join(home, ".ollama")
}

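// generate handles POST /api/generate: it loads the requested model,
// renders the best-matching prompt template, and streams generated
// tokens back to the client as newline-delimited JSON.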
func generate(c *gin.Context) {
	var req api.GenerateRequest
	req.ModelOptions = api.DefaultModelOptions
	req.PredictOptions = api.DefaultPredictOptions
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
		return
	}

	if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
		req.Model = remoteModel.FullName()
	}

	if _, err := os.Stat(req.Model); err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
			return
		}

		// the model path does not exist as given; fall back to the cache directory
		req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
	}

	modelOpts := getModelOpts(req)
	modelOpts.NGPULayers = 1 // hard-code this for now

	model, err := llama.New(req.Model, modelOpts)
	if err != nil {
		log.Printf("loading the model failed: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
		return
	}
	defer model.Free()

	templateNames := make([]string, 0, len(templates.Templates()))
	for _, t := range templates.Templates() {
		templateNames = append(templateNames, t.Name())
	}

	// pick the prompt template whose name is the closest fuzzy match to the model name
	match, _ := matchRankOne(path.Base(req.Model), templateNames)
	if tmpl := templates.Lookup(match); tmpl != nil {
		var sb strings.Builder
		if err := tmpl.Execute(&sb, req); err != nil {
			log.Printf("prompt template failed: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
			return
		}

		req.Prompt = sb.String()
	}

	ch := make(chan string)
	model.SetTokenCallback(func(token string) bool {
		ch <- token
		return true
	})

	predictOpts := getPredictOpts(req)

	go func() {
		defer close(ch)
		// a panic here would crash the whole server, so log instead;
		// closing ch ends the response stream either way
		if _, err := model.Predict(req.Prompt, predictOpts); err != nil {
			log.Printf("predict failed: %v", err)
		}
	}()

	// stream each generated token back to the client as a newline-delimited JSON object
	c.Stream(func(w io.Writer) bool {
		token, ok := <-ch
		if !ok {
			return false
		}

		resp := api.GenerateResponse{
			Response: token,
		}

		bts, err := json.Marshal(resp)
		if err != nil {
			return false
		}

		bts = append(bts, '\n')
		if _, err := w.Write(bts); err != nil {
			return false
		}

		return true
	})
}

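// Serve registers the HTTP routes and serves the API on the provided
// listener, e.g.:
//
//	ln, err := net.Listen("tcp", "127.0.0.1:11434")
//	if err != nil {
//		log.Fatal(err)
//	}
//	log.Fatal(server.Serve(ln))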
func Serve(ln net.Listener) error {
	r := gin.Default()

	r.GET("/", func(c *gin.Context) {
		c.String(http.StatusOK, "Ollama is running")
	})

	r.POST("/api/pull", func(c *gin.Context) {
		var req api.PullRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
			return
		}

		progressCh := make(chan api.PullProgress)
		go func() {
			defer close(progressCh)
			if err := pull(req.Model, progressCh); err != nil {
				var opError *net.OpError
				if errors.As(err, &opError) {
					// report directory connectivity problems as a 502
					result := api.PullProgress{
						Error: api.Error{
							Code:    http.StatusBadGateway,
							Message: "failed to get models from directory",
						},
					}
					c.JSON(http.StatusBadGateway, result)
					return
				}
				c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
				return
			}
		}()

		// stream pull progress updates as newline-delimited JSON
		c.Stream(func(w io.Writer) bool {
			progress, ok := <-progressCh
			if !ok {
				return false
			}

			bts, err := json.Marshal(progress)
			if err != nil {
				return false
			}

			bts = append(bts, '\n')
			if _, err := w.Write(bts); err != nil {
				return false
			}

			return true
		})
	})

	r.POST("/api/generate", generate)

	log.Printf("Listening on %s", ln.Addr())
	s := &http.Server{
		Handler: r,
	}

	return s.Serve(ln)
}

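// matchRankOne returns the target with the smallest Levenshtein
// distance to source; a lower rank means a closer match.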
func matchRankOne(source string, targets []string) (bestMatch string, bestRank int) {
	bestRank = math.MaxInt
	for _, target := range targets {
		if rank := fuzzy.LevenshteinDistance(source, target); bestRank > rank {
			bestRank = rank
			bestMatch = target
		}
	}

	return
}

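// getModelOpts maps the model options from the API request onto
// llama.ModelOptions.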
func getModelOpts(req api.GenerateRequest) llama.ModelOptions {
	return llama.ModelOptions{
		ContextSize: req.ModelOptions.ContextSize,
		Seed:        req.ModelOptions.Seed,
		F16Memory:   req.ModelOptions.F16Memory,
		MLock:       req.ModelOptions.MLock,
		Embeddings:  req.ModelOptions.Embeddings,
		MMap:        req.ModelOptions.MMap,
		LowVRAM:     req.ModelOptions.LowVRAM,
		NBatch:      req.ModelOptions.NBatch,
		VocabOnly:   req.ModelOptions.VocabOnly,
		NUMA:        req.ModelOptions.NUMA,
		NGPULayers:  req.ModelOptions.NGPULayers,
		MainGPU:     req.ModelOptions.MainGPU,
		TensorSplit: req.ModelOptions.TensorSplit,
	}
}

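// getPredictOpts maps the prediction options from the API request onto
// llama.PredictOptions.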
func getPredictOpts(req api.GenerateRequest) llama.PredictOptions {
	var opts llama.PredictOptions

	// Threads == -1 means "use all available CPUs"
	if req.PredictOptions.Threads == -1 {
		opts.Threads = runtime.NumCPU()
	} else {
		opts.Threads = req.PredictOptions.Threads
	}

	opts.Seed = req.PredictOptions.Seed
	opts.Tokens = req.PredictOptions.Tokens
	opts.Penalty = req.PredictOptions.Penalty
	opts.Repeat = req.PredictOptions.Repeat
	opts.Batch = req.PredictOptions.Batch
	opts.NKeep = req.PredictOptions.NKeep
	opts.TopK = req.PredictOptions.TopK
	opts.TopP = req.PredictOptions.TopP
	opts.TailFreeSamplingZ = req.PredictOptions.TailFreeSamplingZ
	opts.TypicalP = req.PredictOptions.TypicalP
	opts.Temperature = req.PredictOptions.Temperature
	opts.FrequencyPenalty = req.PredictOptions.FrequencyPenalty
	opts.PresencePenalty = req.PredictOptions.PresencePenalty
	opts.Mirostat = req.PredictOptions.Mirostat
	opts.MirostatTAU = req.PredictOptions.MirostatTAU
	opts.MirostatETA = req.PredictOptions.MirostatETA
	opts.MMap = req.PredictOptions.MMap

	return opts
}