routes.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. package server
  2. import (
  3. "encoding/json"
  4. "io"
  5. "log"
  6. "net"
  7. "net/http"
  8. "os"
  9. "path/filepath"
  10. "strings"
  11. "time"
  12. "dario.cat/mergo"
  13. "github.com/gin-gonic/gin"
  14. "github.com/jmorganca/ollama/api"
  15. "github.com/jmorganca/ollama/llama"
  16. )
  17. func GenerateHandler(c *gin.Context) {
  18. start := time.Now()
  19. var req api.GenerateRequest
  20. if err := c.ShouldBindJSON(&req); err != nil {
  21. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  22. return
  23. }
  24. model, err := GetModel(req.Model)
  25. if err != nil {
  26. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  27. return
  28. }
  29. opts := api.DefaultOptions()
  30. if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
  31. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  32. return
  33. }
  34. if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
  35. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  36. return
  37. }
  38. prompt, err := model.Prompt(req)
  39. if err != nil {
  40. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  41. return
  42. }
  43. llm, err := llama.New(model.ModelPath, opts)
  44. if err != nil {
  45. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  46. return
  47. }
  48. defer llm.Close()
  49. ch := make(chan any)
  50. go func() {
  51. defer close(ch)
  52. fn := func(r api.GenerateResponse) {
  53. r.Model = req.Model
  54. r.CreatedAt = time.Now().UTC()
  55. if r.Done {
  56. r.TotalDuration = time.Since(start)
  57. }
  58. ch <- r
  59. }
  60. if err := llm.Predict(req.Context, prompt, fn); err != nil {
  61. ch <- gin.H{"error": err.Error()}
  62. }
  63. }()
  64. streamResponse(c, ch)
  65. }
  66. func PullModelHandler(c *gin.Context) {
  67. var req api.PullRequest
  68. if err := c.ShouldBindJSON(&req); err != nil {
  69. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  70. return
  71. }
  72. ch := make(chan any)
  73. go func() {
  74. defer close(ch)
  75. fn := func(r api.ProgressResponse) {
  76. ch <- r
  77. }
  78. regOpts := &RegistryOptions{
  79. Insecure: req.Insecure,
  80. Username: req.Username,
  81. Password: req.Password,
  82. }
  83. if err := PullModel(req.Name, regOpts, fn); err != nil {
  84. ch <- gin.H{"error": err.Error()}
  85. }
  86. }()
  87. streamResponse(c, ch)
  88. }
  89. func PushModelHandler(c *gin.Context) {
  90. var req api.PushRequest
  91. if err := c.ShouldBindJSON(&req); err != nil {
  92. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  93. return
  94. }
  95. ch := make(chan any)
  96. go func() {
  97. defer close(ch)
  98. fn := func(r api.ProgressResponse) {
  99. ch <- r
  100. }
  101. regOpts := &RegistryOptions{
  102. Insecure: req.Insecure,
  103. Username: req.Username,
  104. Password: req.Password,
  105. }
  106. if err := PushModel(req.Name, regOpts, fn); err != nil {
  107. ch <- gin.H{"error": err.Error()}
  108. }
  109. }()
  110. streamResponse(c, ch)
  111. }
  112. func CreateModelHandler(c *gin.Context) {
  113. var req api.CreateRequest
  114. if err := c.ShouldBindJSON(&req); err != nil {
  115. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  116. return
  117. }
  118. ch := make(chan any)
  119. go func() {
  120. defer close(ch)
  121. fn := func(status string) {
  122. ch <- api.CreateProgress{
  123. Status: status,
  124. }
  125. }
  126. if err := CreateModel(req.Name, req.Path, fn); err != nil {
  127. ch <- gin.H{"error": err.Error()}
  128. }
  129. }()
  130. streamResponse(c, ch)
  131. }
  132. func DeleteModelHandler(c *gin.Context) {
  133. var req api.DeleteRequest
  134. if err := c.ShouldBindJSON(&req); err != nil {
  135. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  136. return
  137. }
  138. ch := make(chan any)
  139. go func() {
  140. defer close(ch)
  141. fn := func(r api.ProgressResponse) {
  142. ch <- r
  143. }
  144. if err := DeleteModel(req.Name, fn); err != nil {
  145. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  146. return
  147. }
  148. }()
  149. streamResponse(c, ch)
  150. }
  151. func ListModelsHandler(c *gin.Context) {
  152. var models []api.ListResponseModel
  153. fp, err := GetManifestPath()
  154. if err != nil {
  155. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  156. return
  157. }
  158. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  159. if err != nil {
  160. return err
  161. }
  162. if !info.IsDir() {
  163. fi, err := os.Stat(path)
  164. if err != nil {
  165. log.Printf("skipping file: %s", fp)
  166. return nil
  167. }
  168. path := path[len(fp)+1:]
  169. slashIndex := strings.LastIndex(path, "/")
  170. if slashIndex == -1 {
  171. return nil
  172. }
  173. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  174. mp := ParseModelPath(tag)
  175. manifest, err := GetManifest(mp)
  176. if err != nil {
  177. log.Printf("skipping file: %s", fp)
  178. return nil
  179. }
  180. model := api.ListResponseModel{
  181. Name: mp.GetShortTagname(),
  182. Size: manifest.GetTotalSize(),
  183. ModifiedAt: fi.ModTime(),
  184. }
  185. models = append(models, model)
  186. }
  187. return nil
  188. })
  189. if err != nil {
  190. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  191. return
  192. }
  193. c.JSON(http.StatusOK, api.ListResponse{models})
  194. }
  195. func Serve(ln net.Listener) error {
  196. r := gin.Default()
  197. r.GET("/", func(c *gin.Context) {
  198. c.String(http.StatusOK, "Ollama is running")
  199. })
  200. r.POST("/api/pull", PullModelHandler)
  201. r.POST("/api/generate", GenerateHandler)
  202. r.POST("/api/create", CreateModelHandler)
  203. r.POST("/api/push", PushModelHandler)
  204. r.GET("/api/tags", ListModelsHandler)
  205. r.DELETE("/api/delete", DeleteModelHandler)
  206. log.Printf("Listening on %s", ln.Addr())
  207. s := &http.Server{
  208. Handler: r,
  209. }
  210. return s.Serve(ln)
  211. }
  212. func streamResponse(c *gin.Context, ch chan any) {
  213. c.Stream(func(w io.Writer) bool {
  214. val, ok := <-ch
  215. if !ok {
  216. return false
  217. }
  218. bts, err := json.Marshal(val)
  219. if err != nil {
  220. return false
  221. }
  222. bts = append(bts, '\n')
  223. if _, err := w.Write(bts); err != nil {
  224. return false
  225. }
  226. return true
  227. })
  228. }