routes.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. package server
  2. import (
  3. "encoding/json"
  4. "io"
  5. "log"
  6. "net"
  7. "net/http"
  8. "os"
  9. "path/filepath"
  10. "strings"
  11. "time"
  12. "dario.cat/mergo"
  13. "github.com/gin-gonic/gin"
  14. "github.com/jmorganca/ollama/api"
  15. "github.com/jmorganca/ollama/llama"
  16. )
  17. func cacheDir() string {
  18. home, err := os.UserHomeDir()
  19. if err != nil {
  20. panic(err)
  21. }
  22. return filepath.Join(home, ".ollama")
  23. }
  24. func generate(c *gin.Context) {
  25. start := time.Now()
  26. var req api.GenerateRequest
  27. if err := c.ShouldBindJSON(&req); err != nil {
  28. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  29. return
  30. }
  31. model, err := GetModel(req.Model)
  32. if err != nil {
  33. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  34. return
  35. }
  36. opts := api.DefaultOptions()
  37. if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
  38. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  39. return
  40. }
  41. if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
  42. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  43. return
  44. }
  45. prompt, err := model.Prompt(req)
  46. if err != nil {
  47. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  48. return
  49. }
  50. llm, err := llama.New(model.ModelPath, opts)
  51. if err != nil {
  52. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  53. return
  54. }
  55. defer llm.Close()
  56. ch := make(chan any)
  57. go func() {
  58. defer close(ch)
  59. llm.Predict(req.Context, prompt, func(r api.GenerateResponse) {
  60. r.Model = req.Model
  61. r.CreatedAt = time.Now().UTC()
  62. if r.Done {
  63. r.TotalDuration = time.Since(start)
  64. }
  65. ch <- r
  66. })
  67. }()
  68. streamResponse(c, ch)
  69. }
  70. func pull(c *gin.Context) {
  71. var req api.PullRequest
  72. if err := c.ShouldBindJSON(&req); err != nil {
  73. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  74. return
  75. }
  76. ch := make(chan any)
  77. go func() {
  78. defer close(ch)
  79. fn := func(r api.ProgressResponse) {
  80. ch <- r
  81. }
  82. if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
  83. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  84. return
  85. }
  86. }()
  87. streamResponse(c, ch)
  88. }
  89. func push(c *gin.Context) {
  90. var req api.PushRequest
  91. if err := c.ShouldBindJSON(&req); err != nil {
  92. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  93. return
  94. }
  95. ch := make(chan any)
  96. go func() {
  97. defer close(ch)
  98. fn := func(r api.ProgressResponse) {
  99. ch <- r
  100. }
  101. if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
  102. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  103. return
  104. }
  105. }()
  106. streamResponse(c, ch)
  107. }
  108. func create(c *gin.Context) {
  109. var req api.CreateRequest
  110. if err := c.ShouldBindJSON(&req); err != nil {
  111. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  112. return
  113. }
  114. ch := make(chan any)
  115. go func() {
  116. defer close(ch)
  117. fn := func(status string) {
  118. ch <- api.CreateProgress{
  119. Status: status,
  120. }
  121. }
  122. if err := CreateModel(req.Name, req.Path, fn); err != nil {
  123. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  124. return
  125. }
  126. }()
  127. streamResponse(c, ch)
  128. }
  129. func list(c *gin.Context) {
  130. var models []api.ListResponseModel
  131. fp, err := GetManifestPath()
  132. if err != nil {
  133. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  134. return
  135. }
  136. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  137. if err != nil {
  138. return err
  139. }
  140. if !info.IsDir() {
  141. fi, err := os.Stat(path)
  142. if err != nil {
  143. log.Printf("skipping file: %s", fp)
  144. return nil
  145. }
  146. path := path[len(fp)+1:]
  147. slashIndex := strings.LastIndex(path, "/")
  148. if slashIndex == -1 {
  149. return nil
  150. }
  151. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  152. mp := ParseModelPath(tag)
  153. manifest, err := GetManifest(mp)
  154. if err != nil {
  155. log.Printf("skipping file: %s", fp)
  156. return nil
  157. }
  158. model := api.ListResponseModel{
  159. Name: mp.GetShortTagname(),
  160. Size: manifest.GetTotalSize(),
  161. ModifiedAt: fi.ModTime(),
  162. }
  163. models = append(models, model)
  164. }
  165. return nil
  166. })
  167. if err != nil {
  168. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  169. return
  170. }
  171. c.JSON(http.StatusOK, api.ListResponse{models})
  172. }
  173. func Serve(ln net.Listener) error {
  174. r := gin.Default()
  175. r.GET("/", func(c *gin.Context) {
  176. c.String(http.StatusOK, "Ollama is running")
  177. })
  178. r.POST("/api/pull", pull)
  179. r.POST("/api/generate", generate)
  180. r.POST("/api/create", create)
  181. r.POST("/api/push", push)
  182. r.GET("/api/tags", list)
  183. log.Printf("Listening on %s", ln.Addr())
  184. s := &http.Server{
  185. Handler: r,
  186. }
  187. return s.Serve(ln)
  188. }
  189. func streamResponse(c *gin.Context, ch chan any) {
  190. c.Stream(func(w io.Writer) bool {
  191. val, ok := <-ch
  192. if !ok {
  193. return false
  194. }
  195. bts, err := json.Marshal(val)
  196. if err != nil {
  197. return false
  198. }
  199. bts = append(bts, '\n')
  200. if _, err := w.Write(bts); err != nil {
  201. return false
  202. }
  203. return true
  204. })
  205. }