routes.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. package server
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net"
  9. "net/http"
  10. "os"
  11. "path/filepath"
  12. "strings"
  13. "time"
  14. "dario.cat/mergo"
  15. "github.com/gin-gonic/gin"
  16. "github.com/jmorganca/ollama/api"
  17. "github.com/jmorganca/ollama/llama"
  18. )
  19. func GenerateHandler(c *gin.Context) {
  20. start := time.Now()
  21. var req api.GenerateRequest
  22. if err := c.ShouldBindJSON(&req); err != nil {
  23. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  24. return
  25. }
  26. model, err := GetModel(req.Model)
  27. if err != nil {
  28. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  29. return
  30. }
  31. opts := api.DefaultOptions()
  32. if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
  33. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  34. return
  35. }
  36. if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
  37. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  38. return
  39. }
  40. prompt, err := model.Prompt(req)
  41. if err != nil {
  42. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  43. return
  44. }
  45. llm, err := llama.New(model.ModelPath, opts)
  46. if err != nil {
  47. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  48. return
  49. }
  50. defer llm.Close()
  51. ch := make(chan any)
  52. go func() {
  53. defer close(ch)
  54. fn := func(r api.GenerateResponse) {
  55. r.Model = req.Model
  56. r.CreatedAt = time.Now().UTC()
  57. if r.Done {
  58. r.TotalDuration = time.Since(start)
  59. }
  60. ch <- r
  61. }
  62. if err := llm.Predict(req.Context, prompt, fn); err != nil {
  63. ch <- gin.H{"error": err.Error()}
  64. }
  65. }()
  66. streamResponse(c, ch)
  67. }
  68. func PullModelHandler(c *gin.Context) {
  69. var req api.PullRequest
  70. if err := c.ShouldBindJSON(&req); err != nil {
  71. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  72. return
  73. }
  74. ch := make(chan any)
  75. go func() {
  76. defer close(ch)
  77. fn := func(r api.ProgressResponse) {
  78. ch <- r
  79. }
  80. regOpts := &RegistryOptions{
  81. Insecure: req.Insecure,
  82. Username: req.Username,
  83. Password: req.Password,
  84. }
  85. if err := PullModel(req.Name, regOpts, fn); err != nil {
  86. ch <- gin.H{"error": err.Error()}
  87. }
  88. }()
  89. streamResponse(c, ch)
  90. }
  91. func PushModelHandler(c *gin.Context) {
  92. var req api.PushRequest
  93. if err := c.ShouldBindJSON(&req); err != nil {
  94. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  95. return
  96. }
  97. ch := make(chan any)
  98. go func() {
  99. defer close(ch)
  100. fn := func(r api.ProgressResponse) {
  101. ch <- r
  102. }
  103. regOpts := &RegistryOptions{
  104. Insecure: req.Insecure,
  105. Username: req.Username,
  106. Password: req.Password,
  107. }
  108. if err := PushModel(req.Name, regOpts, fn); err != nil {
  109. ch <- gin.H{"error": err.Error()}
  110. }
  111. }()
  112. streamResponse(c, ch)
  113. }
  114. func CreateModelHandler(c *gin.Context) {
  115. var req api.CreateRequest
  116. if err := c.ShouldBindJSON(&req); err != nil {
  117. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  118. return
  119. }
  120. ch := make(chan any)
  121. go func() {
  122. defer close(ch)
  123. fn := func(status string) {
  124. ch <- api.CreateProgress{
  125. Status: status,
  126. }
  127. }
  128. if err := CreateModel(req.Name, req.Path, fn); err != nil {
  129. ch <- gin.H{"error": err.Error()}
  130. }
  131. }()
  132. streamResponse(c, ch)
  133. }
  134. func DeleteModelHandler(c *gin.Context) {
  135. var req api.DeleteRequest
  136. if err := c.ShouldBindJSON(&req); err != nil {
  137. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  138. return
  139. }
  140. if err := DeleteModel(req.Name); err != nil {
  141. if os.IsNotExist(err) {
  142. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
  143. } else {
  144. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  145. }
  146. return
  147. }
  148. }
  149. func ListModelsHandler(c *gin.Context) {
  150. var models []api.ListResponseModel
  151. fp, err := GetManifestPath()
  152. if err != nil {
  153. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  154. return
  155. }
  156. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  157. if err != nil {
  158. if errors.Is(err, os.ErrNotExist) {
  159. log.Printf("manifest file does not exist: %s", fp)
  160. return nil
  161. }
  162. return err
  163. }
  164. if !info.IsDir() {
  165. fi, err := os.Stat(path)
  166. if err != nil {
  167. log.Printf("skipping file: %s", fp)
  168. return nil
  169. }
  170. path := path[len(fp)+1:]
  171. slashIndex := strings.LastIndex(path, "/")
  172. if slashIndex == -1 {
  173. return nil
  174. }
  175. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  176. mp := ParseModelPath(tag)
  177. manifest, err := GetManifest(mp)
  178. if err != nil {
  179. log.Printf("skipping file: %s", fp)
  180. return nil
  181. }
  182. model := api.ListResponseModel{
  183. Name: mp.GetShortTagname(),
  184. Size: manifest.GetTotalSize(),
  185. ModifiedAt: fi.ModTime(),
  186. }
  187. models = append(models, model)
  188. }
  189. return nil
  190. })
  191. if err != nil {
  192. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  193. return
  194. }
  195. c.JSON(http.StatusOK, api.ListResponse{models})
  196. }
  197. func Serve(ln net.Listener) error {
  198. r := gin.Default()
  199. r.GET("/", func(c *gin.Context) {
  200. c.String(http.StatusOK, "Ollama is running")
  201. })
  202. r.POST("/api/pull", PullModelHandler)
  203. r.POST("/api/generate", GenerateHandler)
  204. r.POST("/api/create", CreateModelHandler)
  205. r.POST("/api/push", PushModelHandler)
  206. r.GET("/api/tags", ListModelsHandler)
  207. r.DELETE("/api/delete", DeleteModelHandler)
  208. log.Printf("Listening on %s", ln.Addr())
  209. s := &http.Server{
  210. Handler: r,
  211. }
  212. return s.Serve(ln)
  213. }
  214. func streamResponse(c *gin.Context, ch chan any) {
  215. c.Stream(func(w io.Writer) bool {
  216. val, ok := <-ch
  217. if !ok {
  218. return false
  219. }
  220. bts, err := json.Marshal(val)
  221. if err != nil {
  222. return false
  223. }
  224. bts = append(bts, '\n')
  225. if _, err := w.Write(bts); err != nil {
  226. return false
  227. }
  228. return true
  229. })
  230. }