routes.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. package server
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net"
  9. "net/http"
  10. "os"
  11. "path/filepath"
  12. "strings"
  13. "time"
  14. "dario.cat/mergo"
  15. "github.com/gin-contrib/cors"
  16. "github.com/gin-gonic/gin"
  17. "github.com/jmorganca/ollama/api"
  18. "github.com/jmorganca/ollama/llama"
  19. )
  20. func GenerateHandler(c *gin.Context) {
  21. start := time.Now()
  22. var req api.GenerateRequest
  23. if err := c.ShouldBindJSON(&req); err != nil {
  24. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  25. return
  26. }
  27. model, err := GetModel(req.Model)
  28. if err != nil {
  29. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  30. return
  31. }
  32. opts := api.DefaultOptions()
  33. if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
  34. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  35. return
  36. }
  37. if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
  38. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  39. return
  40. }
  41. prompt, err := model.Prompt(req)
  42. if err != nil {
  43. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  44. return
  45. }
  46. llm, err := llama.New(model.ModelPath, opts)
  47. if err != nil {
  48. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  49. return
  50. }
  51. defer llm.Close()
  52. ch := make(chan any)
  53. go func() {
  54. defer close(ch)
  55. fn := func(r api.GenerateResponse) {
  56. r.Model = req.Model
  57. r.CreatedAt = time.Now().UTC()
  58. if r.Done {
  59. r.TotalDuration = time.Since(start)
  60. }
  61. ch <- r
  62. }
  63. if err := llm.Predict(req.Context, prompt, fn); err != nil {
  64. ch <- gin.H{"error": err.Error()}
  65. }
  66. }()
  67. streamResponse(c, ch)
  68. }
  69. func PullModelHandler(c *gin.Context) {
  70. var req api.PullRequest
  71. if err := c.ShouldBindJSON(&req); err != nil {
  72. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  73. return
  74. }
  75. ch := make(chan any)
  76. go func() {
  77. defer close(ch)
  78. fn := func(r api.ProgressResponse) {
  79. ch <- r
  80. }
  81. regOpts := &RegistryOptions{
  82. Insecure: req.Insecure,
  83. Username: req.Username,
  84. Password: req.Password,
  85. }
  86. if err := PullModel(req.Name, regOpts, fn); err != nil {
  87. ch <- gin.H{"error": err.Error()}
  88. }
  89. }()
  90. streamResponse(c, ch)
  91. }
  92. func PushModelHandler(c *gin.Context) {
  93. var req api.PushRequest
  94. if err := c.ShouldBindJSON(&req); err != nil {
  95. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  96. return
  97. }
  98. ch := make(chan any)
  99. go func() {
  100. defer close(ch)
  101. fn := func(r api.ProgressResponse) {
  102. ch <- r
  103. }
  104. regOpts := &RegistryOptions{
  105. Insecure: req.Insecure,
  106. Username: req.Username,
  107. Password: req.Password,
  108. }
  109. if err := PushModel(req.Name, regOpts, fn); err != nil {
  110. ch <- gin.H{"error": err.Error()}
  111. }
  112. }()
  113. streamResponse(c, ch)
  114. }
  115. func CreateModelHandler(c *gin.Context) {
  116. var req api.CreateRequest
  117. if err := c.ShouldBindJSON(&req); err != nil {
  118. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  119. return
  120. }
  121. ch := make(chan any)
  122. go func() {
  123. defer close(ch)
  124. fn := func(resp api.ProgressResponse) {
  125. ch <- resp
  126. }
  127. if err := CreateModel(req.Name, req.Path, fn); err != nil {
  128. ch <- gin.H{"error": err.Error()}
  129. }
  130. }()
  131. streamResponse(c, ch)
  132. }
  133. func DeleteModelHandler(c *gin.Context) {
  134. var req api.DeleteRequest
  135. if err := c.ShouldBindJSON(&req); err != nil {
  136. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  137. return
  138. }
  139. if err := DeleteModel(req.Name); err != nil {
  140. if os.IsNotExist(err) {
  141. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
  142. } else {
  143. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  144. }
  145. return
  146. }
  147. }
  148. func ListModelsHandler(c *gin.Context) {
  149. var models []api.ListResponseModel
  150. fp, err := GetManifestPath()
  151. if err != nil {
  152. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  153. return
  154. }
  155. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  156. if err != nil {
  157. if errors.Is(err, os.ErrNotExist) {
  158. log.Printf("manifest file does not exist: %s", fp)
  159. return nil
  160. }
  161. return err
  162. }
  163. if !info.IsDir() {
  164. fi, err := os.Stat(path)
  165. if err != nil {
  166. log.Printf("skipping file: %s", fp)
  167. return nil
  168. }
  169. path := path[len(fp)+1:]
  170. slashIndex := strings.LastIndex(path, "/")
  171. if slashIndex == -1 {
  172. return nil
  173. }
  174. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  175. mp := ParseModelPath(tag)
  176. manifest, err := GetManifest(mp)
  177. if err != nil {
  178. log.Printf("skipping file: %s", fp)
  179. return nil
  180. }
  181. model := api.ListResponseModel{
  182. Name: mp.GetShortTagname(),
  183. Size: manifest.GetTotalSize(),
  184. ModifiedAt: fi.ModTime(),
  185. }
  186. models = append(models, model)
  187. }
  188. return nil
  189. })
  190. if err != nil {
  191. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  192. return
  193. }
  194. c.JSON(http.StatusOK, api.ListResponse{models})
  195. }
  196. func CopyModelHandler(c *gin.Context) {
  197. var req api.CopyRequest
  198. if err := c.ShouldBindJSON(&req); err != nil {
  199. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  200. return
  201. }
  202. if err := CopyModel(req.Source, req.Destination); err != nil {
  203. if os.IsNotExist(err) {
  204. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
  205. } else {
  206. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  207. }
  208. return
  209. }
  210. }
  211. func Serve(ln net.Listener) error {
  212. config := cors.DefaultConfig()
  213. config.AllowWildcard = true
  214. // only allow http/https from localhost
  215. config.AllowOrigins = []string{
  216. "http://localhost",
  217. "http://localhost:*",
  218. "https://localhost",
  219. "https://localhost:*",
  220. "http://127.0.0.1",
  221. "http://127.0.0.1:*",
  222. "https://127.0.0.1",
  223. "https://127.0.0.1:*",
  224. }
  225. r := gin.Default()
  226. r.Use(cors.New(config))
  227. r.GET("/", func(c *gin.Context) {
  228. c.String(http.StatusOK, "Ollama is running")
  229. })
  230. r.POST("/api/pull", PullModelHandler)
  231. r.POST("/api/generate", GenerateHandler)
  232. r.POST("/api/create", CreateModelHandler)
  233. r.POST("/api/push", PushModelHandler)
  234. r.POST("/api/copy", CopyModelHandler)
  235. r.GET("/api/tags", ListModelsHandler)
  236. r.DELETE("/api/delete", DeleteModelHandler)
  237. log.Printf("Listening on %s", ln.Addr())
  238. s := &http.Server{
  239. Handler: r,
  240. }
  241. return s.Serve(ln)
  242. }
  243. func streamResponse(c *gin.Context, ch chan any) {
  244. c.Stream(func(w io.Writer) bool {
  245. val, ok := <-ch
  246. if !ok {
  247. return false
  248. }
  249. bts, err := json.Marshal(val)
  250. if err != nil {
  251. return false
  252. }
  253. bts = append(bts, '\n')
  254. if _, err := w.Write(bts); err != nil {
  255. return false
  256. }
  257. return true
  258. })
  259. }