routes.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package server
  2. import (
  3. "embed"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "log"
  8. "math"
  9. "net"
  10. "net/http"
  11. "path"
  12. "runtime"
  13. "strings"
  14. "text/template"
  15. "github.com/gin-gonic/gin"
  16. "github.com/lithammer/fuzzysearch/fuzzy"
  17. "github.com/jmorganca/ollama/api"
  18. "github.com/jmorganca/ollama/llama"
  19. )
  20. //go:embed templates/*
  21. var templatesFS embed.FS
  22. var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
  23. func generate(c *gin.Context) {
  24. var req api.GenerateRequest
  25. req.ModelOptions = api.DefaultModelOptions
  26. req.PredictOptions = api.DefaultPredictOptions
  27. if err := c.ShouldBindJSON(&req); err != nil {
  28. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  29. return
  30. }
  31. if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
  32. req.Model = remoteModel.FullName()
  33. }
  34. modelOpts := getModelOpts(req)
  35. modelOpts.NGPULayers = 1 // hard-code this for now
  36. model, err := llama.New(req.Model, modelOpts)
  37. if err != nil {
  38. fmt.Println("Loading the model failed:", err.Error())
  39. return
  40. }
  41. defer model.Free()
  42. templateNames := make([]string, 0, len(templates.Templates()))
  43. for _, template := range templates.Templates() {
  44. templateNames = append(templateNames, template.Name())
  45. }
  46. match, _ := matchRankOne(path.Base(req.Model), templateNames)
  47. if template := templates.Lookup(match); template != nil {
  48. var sb strings.Builder
  49. if err := template.Execute(&sb, req); err != nil {
  50. fmt.Println("Prompt template failed:", err.Error())
  51. return
  52. }
  53. req.Prompt = sb.String()
  54. }
  55. ch := make(chan string)
  56. model.SetTokenCallback(func(token string) bool {
  57. ch <- token
  58. return true
  59. })
  60. predictOpts := getPredictOpts(req)
  61. go func() {
  62. defer close(ch)
  63. _, err := model.Predict(req.Prompt, predictOpts)
  64. if err != nil {
  65. panic(err)
  66. }
  67. }()
  68. c.Stream(func(w io.Writer) bool {
  69. token, ok := <-ch
  70. if !ok {
  71. return false
  72. }
  73. resp := api.GenerateResponse{
  74. Response: token,
  75. }
  76. bts, err := json.Marshal(resp)
  77. if err != nil {
  78. return false
  79. }
  80. bts = append(bts, '\n')
  81. if _, err := w.Write(bts); err != nil {
  82. return false
  83. }
  84. return true
  85. })
  86. }
  87. func Serve(ln net.Listener) error {
  88. r := gin.Default()
  89. r.POST("api/pull", func(c *gin.Context) {
  90. var req api.PullRequest
  91. if err := c.ShouldBindJSON(&req); err != nil {
  92. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  93. return
  94. }
  95. progressCh := make(chan api.PullProgress)
  96. go func() {
  97. defer close(progressCh)
  98. if err := pull(req.Model, progressCh); err != nil {
  99. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  100. return
  101. }
  102. }()
  103. c.Stream(func(w io.Writer) bool {
  104. progress, ok := <-progressCh
  105. if !ok {
  106. return false
  107. }
  108. bts, err := json.Marshal(progress)
  109. if err != nil {
  110. return false
  111. }
  112. bts = append(bts, '\n')
  113. if _, err := w.Write(bts); err != nil {
  114. return false
  115. }
  116. return true
  117. })
  118. })
  119. r.POST("/api/generate", generate)
  120. log.Printf("Listening on %s", ln.Addr())
  121. s := &http.Server{
  122. Handler: r,
  123. }
  124. return s.Serve(ln)
  125. }
  126. func matchRankOne(source string, targets []string) (bestMatch string, bestRank int) {
  127. bestRank = math.MaxInt
  128. for _, target := range targets {
  129. if rank := fuzzy.LevenshteinDistance(source, target); bestRank > rank {
  130. bestRank = rank
  131. bestMatch = target
  132. }
  133. }
  134. return
  135. }
  136. func getModelOpts(req api.GenerateRequest) llama.ModelOptions {
  137. var opts llama.ModelOptions
  138. opts.ContextSize = req.ModelOptions.ContextSize
  139. opts.Seed = req.ModelOptions.Seed
  140. opts.F16Memory = req.ModelOptions.F16Memory
  141. opts.MLock = req.ModelOptions.MLock
  142. opts.Embeddings = req.ModelOptions.Embeddings
  143. opts.MMap = req.ModelOptions.MMap
  144. opts.LowVRAM = req.ModelOptions.LowVRAM
  145. opts.NBatch = req.ModelOptions.NBatch
  146. opts.VocabOnly = req.ModelOptions.VocabOnly
  147. opts.NUMA = req.ModelOptions.NUMA
  148. opts.NGPULayers = req.ModelOptions.NGPULayers
  149. opts.MainGPU = req.ModelOptions.MainGPU
  150. opts.TensorSplit = req.ModelOptions.TensorSplit
  151. return opts
  152. }
  153. func getPredictOpts(req api.GenerateRequest) llama.PredictOptions {
  154. var opts llama.PredictOptions
  155. if req.PredictOptions.Threads == -1 {
  156. opts.Threads = runtime.NumCPU()
  157. } else {
  158. opts.Threads = req.PredictOptions.Threads
  159. }
  160. opts.Seed = req.PredictOptions.Seed
  161. opts.Tokens = req.PredictOptions.Tokens
  162. opts.Penalty = req.PredictOptions.Penalty
  163. opts.Repeat = req.PredictOptions.Repeat
  164. opts.Batch = req.PredictOptions.Batch
  165. opts.NKeep = req.PredictOptions.NKeep
  166. opts.TopK = req.PredictOptions.TopK
  167. opts.TopP = req.PredictOptions.TopP
  168. opts.TailFreeSamplingZ = req.PredictOptions.TailFreeSamplingZ
  169. opts.TypicalP = req.PredictOptions.TypicalP
  170. opts.Temperature = req.PredictOptions.Temperature
  171. opts.FrequencyPenalty = req.PredictOptions.FrequencyPenalty
  172. opts.PresencePenalty = req.PredictOptions.PresencePenalty
  173. opts.Mirostat = req.PredictOptions.Mirostat
  174. opts.MirostatTAU = req.PredictOptions.MirostatTAU
  175. opts.MirostatETA = req.PredictOptions.MirostatETA
  176. opts.MMap = req.PredictOptions.MMap
  177. return opts
  178. }