routes.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. package server
  2. import (
  3. "encoding/json"
  4. "io"
  5. "log"
  6. "net"
  7. "net/http"
  8. "os"
  9. "path/filepath"
  10. "strings"
  11. "text/template"
  12. "time"
  13. "dario.cat/mergo"
  14. "github.com/gin-gonic/gin"
  15. "github.com/jmorganca/ollama/api"
  16. "github.com/jmorganca/ollama/llama"
  17. )
  // cacheDir returns the ".ollama" directory located directly under the
  // current user's home directory, as reported by os.UserHomeDir.
  //
  // It panics if the home directory cannot be determined.
  func cacheDir() string {
  	homeDir, homeErr := os.UserHomeDir()
  	if homeErr == nil {
  		return filepath.Join(homeDir, ".ollama")
  	}
  	panic(homeErr)
  }
  25. func generate(c *gin.Context) {
  26. start := time.Now()
  27. var req api.GenerateRequest
  28. if err := c.ShouldBindJSON(&req); err != nil {
  29. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  30. return
  31. }
  32. model, err := GetModel(req.Model)
  33. if err != nil {
  34. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  35. return
  36. }
  37. opts := api.DefaultOptions()
  38. if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
  39. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  40. return
  41. }
  42. if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
  43. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  44. return
  45. }
  46. templ, err := template.New("").Parse(model.Prompt)
  47. if err != nil {
  48. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  49. return
  50. }
  51. var sb strings.Builder
  52. if err = templ.Execute(&sb, req); err != nil {
  53. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  54. return
  55. }
  56. req.Prompt = sb.String()
  57. llm, err := llama.New(model.ModelPath, opts)
  58. if err != nil {
  59. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  60. return
  61. }
  62. defer llm.Close()
  63. ch := make(chan any)
  64. go func() {
  65. defer close(ch)
  66. llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) {
  67. r.Model = req.Model
  68. r.CreatedAt = time.Now().UTC()
  69. if r.Done {
  70. r.TotalDuration = time.Since(start)
  71. }
  72. ch <- r
  73. })
  74. }()
  75. streamResponse(c, ch)
  76. }
  77. func pull(c *gin.Context) {
  78. var req api.PullRequest
  79. if err := c.ShouldBindJSON(&req); err != nil {
  80. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  81. return
  82. }
  83. ch := make(chan any)
  84. go func() {
  85. defer close(ch)
  86. fn := func(r api.ProgressResponse) {
  87. ch <- r
  88. }
  89. if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
  90. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  91. return
  92. }
  93. }()
  94. streamResponse(c, ch)
  95. }
  96. func push(c *gin.Context) {
  97. var req api.PushRequest
  98. if err := c.ShouldBindJSON(&req); err != nil {
  99. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  100. return
  101. }
  102. ch := make(chan any)
  103. go func() {
  104. defer close(ch)
  105. fn := func(r api.ProgressResponse) {
  106. ch <- r
  107. }
  108. if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
  109. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  110. return
  111. }
  112. }()
  113. streamResponse(c, ch)
  114. }
  115. func create(c *gin.Context) {
  116. var req api.CreateRequest
  117. if err := c.ShouldBindJSON(&req); err != nil {
  118. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  119. return
  120. }
  121. // NOTE consider passing the entire Modelfile in the json instead of the path to it
  122. file, err := os.Open(req.Path)
  123. if err != nil {
  124. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  125. return
  126. }
  127. defer file.Close()
  128. ch := make(chan any)
  129. go func() {
  130. defer close(ch)
  131. fn := func(status string) {
  132. ch <- api.CreateProgress{
  133. Status: status,
  134. }
  135. }
  136. if err := CreateModel(req.Name, file, fn); err != nil {
  137. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  138. return
  139. }
  140. }()
  141. streamResponse(c, ch)
  142. }
  143. func list(c *gin.Context) {
  144. var models []api.ListResponseModel
  145. fp, err := GetManifestPath()
  146. if err != nil {
  147. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  148. return
  149. }
  150. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  151. if err != nil {
  152. return err
  153. }
  154. if !info.IsDir() {
  155. fi, err := os.Stat(path)
  156. if err != nil {
  157. log.Printf("skipping file: %s", fp)
  158. return nil
  159. }
  160. path := path[len(fp)+1:]
  161. slashIndex := strings.LastIndex(path, "/")
  162. if slashIndex == -1 {
  163. return nil
  164. }
  165. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  166. mp := ParseModelPath(tag)
  167. manifest, err := GetManifest(mp)
  168. if err != nil {
  169. log.Printf("skipping file: %s", fp)
  170. return nil
  171. }
  172. model := api.ListResponseModel{
  173. Name: mp.GetShortTagname(),
  174. Size: manifest.GetTotalSize(),
  175. ModifiedAt: fi.ModTime(),
  176. }
  177. models = append(models, model)
  178. }
  179. return nil
  180. })
  181. if err != nil {
  182. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  183. return
  184. }
  185. c.JSON(http.StatusOK, api.ListResponse{models})
  186. }
  187. func Serve(ln net.Listener) error {
  188. r := gin.Default()
  189. r.GET("/", func(c *gin.Context) {
  190. c.String(http.StatusOK, "Ollama is running")
  191. })
  192. r.POST("/api/pull", pull)
  193. r.POST("/api/generate", generate)
  194. r.POST("/api/create", create)
  195. r.POST("/api/push", push)
  196. r.GET("/api/tags", list)
  197. log.Printf("Listening on %s", ln.Addr())
  198. s := &http.Server{
  199. Handler: r,
  200. }
  201. return s.Serve(ln)
  202. }
  203. func streamResponse(c *gin.Context, ch chan any) {
  204. c.Stream(func(w io.Writer) bool {
  205. val, ok := <-ch
  206. if !ok {
  207. return false
  208. }
  209. bts, err := json.Marshal(val)
  210. if err != nil {
  211. return false
  212. }
  213. bts = append(bts, '\n')
  214. if _, err := w.Write(bts); err != nil {
  215. return false
  216. }
  217. return true
  218. })
  219. }