routes.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. package server
  2. import (
  3. "encoding/json"
  4. "io"
  5. "log"
  6. "net"
  7. "net/http"
  8. "os"
  9. "path/filepath"
  10. "strings"
  11. "time"
  12. "dario.cat/mergo"
  13. "github.com/gin-gonic/gin"
  14. "github.com/jmorganca/ollama/api"
  15. "github.com/jmorganca/ollama/llama"
  16. )
  17. func generate(c *gin.Context) {
  18. start := time.Now()
  19. var req api.GenerateRequest
  20. if err := c.ShouldBindJSON(&req); err != nil {
  21. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  22. return
  23. }
  24. model, err := GetModel(req.Model)
  25. if err != nil {
  26. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  27. return
  28. }
  29. opts := api.DefaultOptions()
  30. if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
  31. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  32. return
  33. }
  34. if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
  35. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  36. return
  37. }
  38. prompt, err := model.Prompt(req)
  39. if err != nil {
  40. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  41. return
  42. }
  43. llm, err := llama.New(model.ModelPath, opts)
  44. if err != nil {
  45. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  46. return
  47. }
  48. defer llm.Close()
  49. ch := make(chan any)
  50. go func() {
  51. defer close(ch)
  52. fn := func(r api.GenerateResponse) {
  53. r.Model = req.Model
  54. r.CreatedAt = time.Now().UTC()
  55. if r.Done {
  56. r.TotalDuration = time.Since(start)
  57. }
  58. ch <- r
  59. }
  60. if err := llm.Predict(req.Context, prompt, fn); err != nil {
  61. ch <- gin.H{"error": err.Error()}
  62. }
  63. }()
  64. streamResponse(c, ch)
  65. }
  66. func pull(c *gin.Context) {
  67. var req api.PullRequest
  68. if err := c.ShouldBindJSON(&req); err != nil {
  69. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  70. return
  71. }
  72. ch := make(chan any)
  73. go func() {
  74. defer close(ch)
  75. fn := func(r api.ProgressResponse) {
  76. ch <- r
  77. }
  78. if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
  79. ch <- gin.H{"error": err.Error()}
  80. }
  81. }()
  82. streamResponse(c, ch)
  83. }
  84. func push(c *gin.Context) {
  85. var req api.PushRequest
  86. if err := c.ShouldBindJSON(&req); err != nil {
  87. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  88. return
  89. }
  90. ch := make(chan any)
  91. go func() {
  92. defer close(ch)
  93. fn := func(r api.ProgressResponse) {
  94. ch <- r
  95. }
  96. if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
  97. ch <- gin.H{"error": err.Error()}
  98. }
  99. }()
  100. streamResponse(c, ch)
  101. }
  102. func create(c *gin.Context) {
  103. var req api.CreateRequest
  104. if err := c.ShouldBindJSON(&req); err != nil {
  105. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  106. return
  107. }
  108. ch := make(chan any)
  109. go func() {
  110. defer close(ch)
  111. fn := func(status string) {
  112. ch <- api.CreateProgress{
  113. Status: status,
  114. }
  115. }
  116. if err := CreateModel(req.Name, req.Path, fn); err != nil {
  117. ch <- gin.H{"error": err.Error()}
  118. }
  119. }()
  120. streamResponse(c, ch)
  121. }
  122. func list(c *gin.Context) {
  123. var models []api.ListResponseModel
  124. fp, err := GetManifestPath()
  125. if err != nil {
  126. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  127. return
  128. }
  129. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  130. if err != nil {
  131. return err
  132. }
  133. if !info.IsDir() {
  134. fi, err := os.Stat(path)
  135. if err != nil {
  136. log.Printf("skipping file: %s", fp)
  137. return nil
  138. }
  139. path := path[len(fp)+1:]
  140. slashIndex := strings.LastIndex(path, "/")
  141. if slashIndex == -1 {
  142. return nil
  143. }
  144. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  145. mp := ParseModelPath(tag)
  146. manifest, err := GetManifest(mp)
  147. if err != nil {
  148. log.Printf("skipping file: %s", fp)
  149. return nil
  150. }
  151. model := api.ListResponseModel{
  152. Name: mp.GetShortTagname(),
  153. Size: manifest.GetTotalSize(),
  154. ModifiedAt: fi.ModTime(),
  155. }
  156. models = append(models, model)
  157. }
  158. return nil
  159. })
  160. if err != nil {
  161. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  162. return
  163. }
  164. c.JSON(http.StatusOK, api.ListResponse{models})
  165. }
  166. func Serve(ln net.Listener) error {
  167. r := gin.Default()
  168. r.GET("/", func(c *gin.Context) {
  169. c.String(http.StatusOK, "Ollama is running")
  170. })
  171. r.POST("/api/pull", pull)
  172. r.POST("/api/generate", generate)
  173. r.POST("/api/create", create)
  174. r.POST("/api/push", push)
  175. r.GET("/api/tags", list)
  176. log.Printf("Listening on %s", ln.Addr())
  177. s := &http.Server{
  178. Handler: r,
  179. }
  180. return s.Serve(ln)
  181. }
  182. func streamResponse(c *gin.Context, ch chan any) {
  183. c.Stream(func(w io.Writer) bool {
  184. val, ok := <-ch
  185. if !ok {
  186. return false
  187. }
  188. bts, err := json.Marshal(val)
  189. if err != nil {
  190. return false
  191. }
  192. bts = append(bts, '\n')
  193. if _, err := w.Write(bts); err != nil {
  194. return false
  195. }
  196. return true
  197. })
  198. }