routes.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. package server
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "io"
  6. "log"
  7. "net"
  8. "net/http"
  9. "os"
  10. "path/filepath"
  11. "strings"
  12. "time"
  13. "dario.cat/mergo"
  14. "github.com/gin-gonic/gin"
  15. "github.com/jmorganca/ollama/api"
  16. "github.com/jmorganca/ollama/llama"
  17. )
  18. func GenerateHandler(c *gin.Context) {
  19. start := time.Now()
  20. var req api.GenerateRequest
  21. if err := c.ShouldBindJSON(&req); err != nil {
  22. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  23. return
  24. }
  25. model, err := GetModel(req.Model)
  26. if err != nil {
  27. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  28. return
  29. }
  30. opts := api.DefaultOptions()
  31. if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
  32. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  33. return
  34. }
  35. if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
  36. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  37. return
  38. }
  39. prompt, err := model.Prompt(req)
  40. if err != nil {
  41. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  42. return
  43. }
  44. llm, err := llama.New(model.ModelPath, opts)
  45. if err != nil {
  46. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  47. return
  48. }
  49. defer llm.Close()
  50. ch := make(chan any)
  51. go func() {
  52. defer close(ch)
  53. fn := func(r api.GenerateResponse) {
  54. r.Model = req.Model
  55. r.CreatedAt = time.Now().UTC()
  56. if r.Done {
  57. r.TotalDuration = time.Since(start)
  58. }
  59. ch <- r
  60. }
  61. if err := llm.Predict(req.Context, prompt, fn); err != nil {
  62. ch <- gin.H{"error": err.Error()}
  63. }
  64. }()
  65. streamResponse(c, ch)
  66. }
  67. func PullModelHandler(c *gin.Context) {
  68. var req api.PullRequest
  69. if err := c.ShouldBindJSON(&req); err != nil {
  70. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  71. return
  72. }
  73. ch := make(chan any)
  74. go func() {
  75. defer close(ch)
  76. fn := func(r api.ProgressResponse) {
  77. ch <- r
  78. }
  79. regOpts := &RegistryOptions{
  80. Insecure: req.Insecure,
  81. Username: req.Username,
  82. Password: req.Password,
  83. }
  84. if err := PullModel(req.Name, regOpts, fn); err != nil {
  85. ch <- gin.H{"error": err.Error()}
  86. }
  87. }()
  88. streamResponse(c, ch)
  89. }
  90. func PushModelHandler(c *gin.Context) {
  91. var req api.PushRequest
  92. if err := c.ShouldBindJSON(&req); err != nil {
  93. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  94. return
  95. }
  96. ch := make(chan any)
  97. go func() {
  98. defer close(ch)
  99. fn := func(r api.ProgressResponse) {
  100. ch <- r
  101. }
  102. regOpts := &RegistryOptions{
  103. Insecure: req.Insecure,
  104. Username: req.Username,
  105. Password: req.Password,
  106. }
  107. if err := PushModel(req.Name, regOpts, fn); err != nil {
  108. ch <- gin.H{"error": err.Error()}
  109. }
  110. }()
  111. streamResponse(c, ch)
  112. }
  113. func CreateModelHandler(c *gin.Context) {
  114. var req api.CreateRequest
  115. if err := c.ShouldBindJSON(&req); err != nil {
  116. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  117. return
  118. }
  119. ch := make(chan any)
  120. go func() {
  121. defer close(ch)
  122. fn := func(status string) {
  123. ch <- api.CreateProgress{
  124. Status: status,
  125. }
  126. }
  127. if err := CreateModel(req.Name, req.Path, fn); err != nil {
  128. ch <- gin.H{"error": err.Error()}
  129. }
  130. }()
  131. streamResponse(c, ch)
  132. }
  133. func DeleteModelHandler(c *gin.Context) {
  134. var req api.DeleteRequest
  135. if err := c.ShouldBindJSON(&req); err != nil {
  136. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  137. return
  138. }
  139. ch := make(chan any)
  140. go func() {
  141. defer close(ch)
  142. fn := func(r api.ProgressResponse) {
  143. ch <- r
  144. }
  145. if err := DeleteModel(req.Name, fn); err != nil {
  146. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  147. return
  148. }
  149. }()
  150. streamResponse(c, ch)
  151. }
  152. func ListModelsHandler(c *gin.Context) {
  153. var models []api.ListResponseModel
  154. fp, err := GetManifestPath()
  155. if err != nil {
  156. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  157. return
  158. }
  159. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  160. if err != nil {
  161. if errors.Is(err, os.ErrNotExist) {
  162. log.Printf("manifest file does not exist: %s", fp)
  163. return nil
  164. }
  165. return err
  166. }
  167. if !info.IsDir() {
  168. fi, err := os.Stat(path)
  169. if err != nil {
  170. log.Printf("skipping file: %s", fp)
  171. return nil
  172. }
  173. path := path[len(fp)+1:]
  174. slashIndex := strings.LastIndex(path, "/")
  175. if slashIndex == -1 {
  176. return nil
  177. }
  178. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  179. mp := ParseModelPath(tag)
  180. manifest, err := GetManifest(mp)
  181. if err != nil {
  182. log.Printf("skipping file: %s", fp)
  183. return nil
  184. }
  185. model := api.ListResponseModel{
  186. Name: mp.GetShortTagname(),
  187. Size: manifest.GetTotalSize(),
  188. ModifiedAt: fi.ModTime(),
  189. }
  190. models = append(models, model)
  191. }
  192. return nil
  193. })
  194. if err != nil {
  195. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  196. return
  197. }
  198. c.JSON(http.StatusOK, api.ListResponse{models})
  199. }
  200. func Serve(ln net.Listener) error {
  201. r := gin.Default()
  202. r.GET("/", func(c *gin.Context) {
  203. c.String(http.StatusOK, "Ollama is running")
  204. })
  205. r.POST("/api/pull", PullModelHandler)
  206. r.POST("/api/generate", GenerateHandler)
  207. r.POST("/api/create", CreateModelHandler)
  208. r.POST("/api/push", PushModelHandler)
  209. r.GET("/api/tags", ListModelsHandler)
  210. r.DELETE("/api/delete", DeleteModelHandler)
  211. log.Printf("Listening on %s", ln.Addr())
  212. s := &http.Server{
  213. Handler: r,
  214. }
  215. return s.Serve(ln)
  216. }
  217. func streamResponse(c *gin.Context, ch chan any) {
  218. c.Stream(func(w io.Writer) bool {
  219. val, ok := <-ch
  220. if !ok {
  221. return false
  222. }
  223. bts, err := json.Marshal(val)
  224. if err != nil {
  225. return false
  226. }
  227. bts = append(bts, '\n')
  228. if _, err := w.Write(bts); err != nil {
  229. return false
  230. }
  231. return true
  232. })
  233. }