routes.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. package server
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net"
  9. "net/http"
  10. "os"
  11. "path/filepath"
  12. "strings"
  13. "sync"
  14. "time"
  15. "dario.cat/mergo"
  16. "github.com/gin-contrib/cors"
  17. "github.com/gin-gonic/gin"
  18. "github.com/jmorganca/ollama/api"
  19. "github.com/jmorganca/ollama/llama"
  20. )
  21. var mu sync.Mutex
  22. var activeSession struct {
  23. ID int64
  24. *llama.LLM
  25. }
  // GenerateHandler handles POST /api/generate. It binds an api.GenerateRequest,
// (re)loads the requested model into activeSession unless the request
// continues the live session, and streams the model's api.GenerateResponse
// values back through streamResponse.
//
// The whole request runs under mu, and the handler does not return (and so
// does not unlock) until the producer goroutine has finished. This ensures
// no later request can close the model while a prediction is still using it.
func GenerateHandler(c *gin.Context) {
	mu.Lock()
	defer mu.Unlock()

	start := time.Now()

	var req api.GenerateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	model, err := GetModel(req.Model)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Load a fresh model unless the request names the live session. The nil
	// check covers a previous load that failed after the old model had
	// already been closed.
	if req.SessionID == 0 || req.SessionID != activeSession.ID || activeSession.LLM == nil {
		if activeSession.LLM != nil {
			activeSession.Close()
			activeSession.LLM = nil
		}
		// Invalidate the old ID right away so that, if loading fails below, a
		// later request carrying the old ID cannot be treated as live while
		// LLM is nil.
		activeSession.ID = 0

		opts := api.DefaultOptions()
		if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		llm, err := llama.New(model.ModelPath, opts)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		activeSession.ID = time.Now().UnixNano()
		activeSession.LLM = llm
	}

	prompt, err := model.Prompt(req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Capture the session so the producer goroutine does not read package
	// state directly.
	sessionID := activeSession.ID
	llm := activeSession.LLM

	ch := make(chan any)
	go func() {
		defer close(ch)
		fn := func(r api.GenerateResponse) {
			r.Model = req.Model
			r.CreatedAt = time.Now().UTC()
			r.SessionID = sessionID
			if r.Done {
				r.TotalDuration = time.Since(start)
			}
			ch <- r
		}
		if err := llm.Predict(req.Context, prompt, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	streamResponse(c, ch)

	// streamResponse may stop reading early (client gone, write error).
	// Drain ch so the producer finishes instead of blocking forever on a send.
	// Only then may the deferred Unlock let another request close this model.
	for range ch {
	}
}
  // PullModelHandler handles POST /api/pull. It downloads the requested model
// with PullModel in a background goroutine and streams every
// api.ProgressResponse, or a final {"error": ...} object, through
// streamResponse.
func PullModelHandler(c *gin.Context) {
	var req api.PullRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	progress := make(chan any)
	go func() {
		defer close(progress)

		opts := &RegistryOptions{
			Insecure: req.Insecure,
			Username: req.Username,
			Password: req.Password,
		}
		report := func(update api.ProgressResponse) { progress <- update }

		if pullErr := PullModel(req.Name, opts, report); pullErr != nil {
			progress <- gin.H{"error": pullErr.Error()}
		}
	}()

	streamResponse(c, progress)
}
  // PushModelHandler handles POST /api/push. It uploads the named model with
// PushModel in a background goroutine and streams every api.ProgressResponse,
// or a final {"error": ...} object, through streamResponse.
func PushModelHandler(c *gin.Context) {
	var req api.PushRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	progress := make(chan any)
	go func() {
		defer close(progress)

		opts := &RegistryOptions{
			Insecure: req.Insecure,
			Username: req.Username,
			Password: req.Password,
		}
		report := func(update api.ProgressResponse) { progress <- update }

		if pushErr := PushModel(req.Name, opts, report); pushErr != nil {
			progress <- gin.H{"error": pushErr.Error()}
		}
	}()

	streamResponse(c, progress)
}
  // CreateModelHandler handles POST /api/create. It builds a model named
// req.Name from the file at req.Path via CreateModel. It streams each
// api.ProgressResponse, or a final {"error": ...} object, through
// streamResponse.
func CreateModelHandler(c *gin.Context) {
	var req api.CreateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		// Use the same "error" key as every other handler so clients can
		// parse failures uniformly (this previously used "message").
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	ch := make(chan any)
	go func() {
		defer close(ch)
		fn := func(resp api.ProgressResponse) {
			ch <- resp
		}
		if err := CreateModel(req.Name, req.Path, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	streamResponse(c, ch)
}
  // DeleteModelHandler handles DELETE /api/delete. It removes the model named in
// the api.DeleteRequest via DeleteModel. It responds 404 when the model does
// not exist and 500 for any other failure; on success it writes no body.
func DeleteModelHandler(c *gin.Context) {
	var req api.DeleteRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if err := DeleteModel(req.Name); err != nil {
		// errors.Is (unlike os.IsNotExist) also recognizes a not-exist
		// error that DeleteModel has wrapped with extra context.
		if errors.Is(err, os.ErrNotExist) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
}
  164. func ListModelsHandler(c *gin.Context) {
  165. var models []api.ListResponseModel
  166. fp, err := GetManifestPath()
  167. if err != nil {
  168. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  169. return
  170. }
  171. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  172. if err != nil {
  173. if errors.Is(err, os.ErrNotExist) {
  174. log.Printf("manifest file does not exist: %s", fp)
  175. return nil
  176. }
  177. return err
  178. }
  179. if !info.IsDir() {
  180. fi, err := os.Stat(path)
  181. if err != nil {
  182. log.Printf("skipping file: %s", fp)
  183. return nil
  184. }
  185. path := path[len(fp)+1:]
  186. slashIndex := strings.LastIndex(path, "/")
  187. if slashIndex == -1 {
  188. return nil
  189. }
  190. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  191. mp := ParseModelPath(tag)
  192. manifest, err := GetManifest(mp)
  193. if err != nil {
  194. log.Printf("skipping file: %s", fp)
  195. return nil
  196. }
  197. model := api.ListResponseModel{
  198. Name: mp.GetShortTagname(),
  199. Size: manifest.GetTotalSize(),
  200. ModifiedAt: fi.ModTime(),
  201. }
  202. models = append(models, model)
  203. }
  204. return nil
  205. })
  206. if err != nil {
  207. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  208. return
  209. }
  210. c.JSON(http.StatusOK, api.ListResponse{models})
  211. }
  // CopyModelHandler handles POST /api/copy. It duplicates the model
// req.Source under the name req.Destination via CopyModel. It responds 404
// when the source model does not exist and 500 for any other failure; on
// success it writes no body.
func CopyModelHandler(c *gin.Context) {
	var req api.CopyRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if err := CopyModel(req.Source, req.Destination); err != nil {
		// errors.Is (unlike os.IsNotExist) also recognizes a not-exist
		// error that CopyModel has wrapped with extra context.
		if errors.Is(err, os.ErrNotExist) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
}
  // Serve builds the API router and serves HTTP on ln, returning whatever error
// http.Server.Serve returns.
//
// CORS is restricted to http/https origins on localhost and 127.0.0.1, with or
// without an explicit port.
func Serve(ln net.Listener) error {
	corsConfig := cors.DefaultConfig()
	corsConfig.AllowWildcard = true
	// only allow http/https from localhost
	corsConfig.AllowOrigins = []string{
		"http://localhost",
		"http://localhost:*",
		"https://localhost",
		"https://localhost:*",
		"http://127.0.0.1",
		"http://127.0.0.1:*",
		"https://127.0.0.1",
		"https://127.0.0.1:*",
	}

	router := gin.Default()
	router.Use(cors.New(corsConfig))

	router.GET("/", func(c *gin.Context) {
		c.String(http.StatusOK, "Ollama is running")
	})

	// API endpoints, registered in table order.
	routes := []struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}{
		{http.MethodPost, "/api/pull", PullModelHandler},
		{http.MethodPost, "/api/generate", GenerateHandler},
		{http.MethodPost, "/api/create", CreateModelHandler},
		{http.MethodPost, "/api/push", PushModelHandler},
		{http.MethodPost, "/api/copy", CopyModelHandler},
		{http.MethodGet, "/api/tags", ListModelsHandler},
		{http.MethodDelete, "/api/delete", DeleteModelHandler},
	}
	for _, rt := range routes {
		router.Handle(rt.method, rt.path, rt.handler)
	}

	log.Printf("Listening on %s", ln.Addr())

	srv := &http.Server{Handler: router}
	return srv.Serve(ln)
}
  // streamResponse writes each value received from ch to the response as one
// JSON document followed by '\n' (newline-delimited JSON), until ch is closed.
//
// If streaming stops early, the remaining values are drained and discarded.
// Early stops happen when the client goes away, a value fails to marshal, or
// a write fails. Draining lets the producer goroutine finish instead of
// blocking forever on its send. As a result, streamResponse returns only
// after ch has been closed, so the producer must always close ch (every
// caller in this file does so via defer close).
func streamResponse(c *gin.Context, ch chan any) {
	// Runs on every exit path; a no-op when ch was already fully consumed.
	defer func() {
		for range ch {
		}
	}()

	c.Stream(func(w io.Writer) bool {
		val, ok := <-ch
		if !ok {
			return false
		}

		bts, err := json.Marshal(val)
		if err != nil {
			log.Printf("streamResponse: marshal %T: %v", val, err)
			return false
		}
		bts = append(bts, '\n')
		if _, err := w.Write(bts); err != nil {
			return false
		}
		return true
	})
}