routes.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. package server
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net"
  9. "net/http"
  10. "os"
  11. "path/filepath"
  12. "strings"
  13. "time"
  14. "dario.cat/mergo"
  15. "github.com/gin-contrib/cors"
  16. "github.com/gin-gonic/gin"
  17. "github.com/jmorganca/ollama/api"
  18. "github.com/jmorganca/ollama/llama"
  19. )
  20. func GenerateHandler(c *gin.Context) {
  21. start := time.Now()
  22. var req api.GenerateRequest
  23. if err := c.ShouldBindJSON(&req); err != nil {
  24. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  25. return
  26. }
  27. model, err := GetModel(req.Model)
  28. if err != nil {
  29. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  30. return
  31. }
  32. opts := api.DefaultOptions()
  33. if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
  34. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  35. return
  36. }
  37. if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
  38. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  39. return
  40. }
  41. prompt, err := model.Prompt(req)
  42. if err != nil {
  43. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  44. return
  45. }
  46. llm, err := llama.New(model.ModelPath, opts)
  47. if err != nil {
  48. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  49. return
  50. }
  51. defer llm.Close()
  52. ch := make(chan any)
  53. go func() {
  54. defer close(ch)
  55. fn := func(r api.GenerateResponse) {
  56. r.Model = req.Model
  57. r.CreatedAt = time.Now().UTC()
  58. if r.Done {
  59. r.TotalDuration = time.Since(start)
  60. }
  61. ch <- r
  62. }
  63. if err := llm.Predict(req.Context, prompt, fn); err != nil {
  64. ch <- gin.H{"error": err.Error()}
  65. }
  66. }()
  67. streamResponse(c, ch)
  68. }
  69. func PullModelHandler(c *gin.Context) {
  70. var req api.PullRequest
  71. if err := c.ShouldBindJSON(&req); err != nil {
  72. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  73. return
  74. }
  75. ch := make(chan any)
  76. go func() {
  77. defer close(ch)
  78. fn := func(r api.ProgressResponse) {
  79. ch <- r
  80. }
  81. regOpts := &RegistryOptions{
  82. Insecure: req.Insecure,
  83. Username: req.Username,
  84. Password: req.Password,
  85. }
  86. if err := PullModel(req.Name, regOpts, fn); err != nil {
  87. ch <- gin.H{"error": err.Error()}
  88. }
  89. }()
  90. streamResponse(c, ch)
  91. }
  92. func PushModelHandler(c *gin.Context) {
  93. var req api.PushRequest
  94. if err := c.ShouldBindJSON(&req); err != nil {
  95. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  96. return
  97. }
  98. ch := make(chan any)
  99. go func() {
  100. defer close(ch)
  101. fn := func(r api.ProgressResponse) {
  102. ch <- r
  103. }
  104. regOpts := &RegistryOptions{
  105. Insecure: req.Insecure,
  106. Username: req.Username,
  107. Password: req.Password,
  108. }
  109. if err := PushModel(req.Name, regOpts, fn); err != nil {
  110. ch <- gin.H{"error": err.Error()}
  111. }
  112. }()
  113. streamResponse(c, ch)
  114. }
  115. func CreateModelHandler(c *gin.Context) {
  116. var req api.CreateRequest
  117. if err := c.ShouldBindJSON(&req); err != nil {
  118. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  119. return
  120. }
  121. ch := make(chan any)
  122. go func() {
  123. defer close(ch)
  124. fn := func(status string) {
  125. ch <- api.CreateProgress{
  126. Status: status,
  127. }
  128. }
  129. if err := CreateModel(req.Name, req.Path, fn); err != nil {
  130. ch <- gin.H{"error": err.Error()}
  131. }
  132. }()
  133. streamResponse(c, ch)
  134. }
  135. func DeleteModelHandler(c *gin.Context) {
  136. var req api.DeleteRequest
  137. if err := c.ShouldBindJSON(&req); err != nil {
  138. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  139. return
  140. }
  141. if err := DeleteModel(req.Name); err != nil {
  142. if os.IsNotExist(err) {
  143. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
  144. } else {
  145. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  146. }
  147. return
  148. }
  149. }
  150. func ListModelsHandler(c *gin.Context) {
  151. var models []api.ListResponseModel
  152. fp, err := GetManifestPath()
  153. if err != nil {
  154. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  155. return
  156. }
  157. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  158. if err != nil {
  159. if errors.Is(err, os.ErrNotExist) {
  160. log.Printf("manifest file does not exist: %s", fp)
  161. return nil
  162. }
  163. return err
  164. }
  165. if !info.IsDir() {
  166. fi, err := os.Stat(path)
  167. if err != nil {
  168. log.Printf("skipping file: %s", fp)
  169. return nil
  170. }
  171. path := path[len(fp)+1:]
  172. slashIndex := strings.LastIndex(path, "/")
  173. if slashIndex == -1 {
  174. return nil
  175. }
  176. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  177. mp := ParseModelPath(tag)
  178. manifest, err := GetManifest(mp)
  179. if err != nil {
  180. log.Printf("skipping file: %s", fp)
  181. return nil
  182. }
  183. model := api.ListResponseModel{
  184. Name: mp.GetShortTagname(),
  185. Size: manifest.GetTotalSize(),
  186. ModifiedAt: fi.ModTime(),
  187. }
  188. models = append(models, model)
  189. }
  190. return nil
  191. })
  192. if err != nil {
  193. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  194. return
  195. }
  196. c.JSON(http.StatusOK, api.ListResponse{models})
  197. }
  198. func CopyModelHandler(c *gin.Context) {
  199. var req api.CopyRequest
  200. if err := c.ShouldBindJSON(&req); err != nil {
  201. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  202. return
  203. }
  204. if err := CopyModel(req.Source, req.Destination); err != nil {
  205. if os.IsNotExist(err) {
  206. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
  207. } else {
  208. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  209. }
  210. return
  211. }
  212. }
  213. func Serve(ln net.Listener) error {
  214. config := cors.DefaultConfig()
  215. config.AllowWildcard = true
  216. // only allow http/https from localhost
  217. config.AllowOrigins = []string{
  218. "http://localhost",
  219. "http://localhost:*",
  220. "https://localhost",
  221. "https://localhost:*",
  222. "http://127.0.0.1",
  223. "http://127.0.0.1:*",
  224. "https://127.0.0.1",
  225. "https://127.0.0.1:*",
  226. }
  227. r := gin.Default()
  228. r.Use(cors.New(config))
  229. r.GET("/", func(c *gin.Context) {
  230. c.String(http.StatusOK, "Ollama is running")
  231. })
  232. r.POST("/api/pull", PullModelHandler)
  233. r.POST("/api/generate", GenerateHandler)
  234. r.POST("/api/create", CreateModelHandler)
  235. r.POST("/api/push", PushModelHandler)
  236. r.POST("/api/copy", CopyModelHandler)
  237. r.GET("/api/tags", ListModelsHandler)
  238. r.DELETE("/api/delete", DeleteModelHandler)
  239. log.Printf("Listening on %s", ln.Addr())
  240. s := &http.Server{
  241. Handler: r,
  242. }
  243. return s.Serve(ln)
  244. }
  245. func streamResponse(c *gin.Context, ch chan any) {
  246. c.Stream(func(w io.Writer) bool {
  247. val, ok := <-ch
  248. if !ok {
  249. return false
  250. }
  251. bts, err := json.Marshal(val)
  252. if err != nil {
  253. return false
  254. }
  255. bts = append(bts, '\n')
  256. if _, err := w.Write(bts); err != nil {
  257. return false
  258. }
  259. return true
  260. })
  261. }