routes.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
  1. package server
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"reflect"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llm"
	"github.com/jmorganca/ollama/version"
)
  28. var mode string = gin.DebugMode
  29. func init() {
  30. switch mode {
  31. case gin.DebugMode:
  32. case gin.ReleaseMode:
  33. case gin.TestMode:
  34. default:
  35. mode = gin.DebugMode
  36. }
  37. gin.SetMode(mode)
  38. }
  39. var loaded struct {
  40. mu sync.Mutex
  41. llm llm.LLM
  42. expireAt time.Time
  43. expireTimer *time.Timer
  44. digest string
  45. options api.Options
  46. }
  47. var defaultSessionDuration = 5 * time.Minute
  48. // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
  49. func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
  50. opts := api.DefaultOptions()
  51. if err := opts.FromMap(model.Options); err != nil {
  52. log.Printf("could not load model options: %v", err)
  53. return err
  54. }
  55. if err := opts.FromMap(reqOpts); err != nil {
  56. return err
  57. }
  58. // check if the loaded model is still running in a subprocess, in case something unexpected happened
  59. if loaded.llm != nil {
  60. if err := loaded.llm.Ping(ctx); err != nil {
  61. log.Print("loaded llm process not responding, closing now")
  62. // the subprocess is no longer running, so close it
  63. loaded.llm.Close()
  64. loaded.llm = nil
  65. loaded.digest = ""
  66. }
  67. }
  68. if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
  69. if loaded.llm != nil {
  70. log.Println("changing loaded model")
  71. loaded.llm.Close()
  72. loaded.llm = nil
  73. loaded.digest = ""
  74. }
  75. llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
  76. if err != nil {
  77. return err
  78. }
  79. // set cache values before modifying opts
  80. loaded.llm = llmModel
  81. loaded.digest = model.Digest
  82. loaded.options = opts
  83. if opts.NumKeep < 0 {
  84. promptWithSystem, err := model.Prompt(api.GenerateRequest{})
  85. if err != nil {
  86. return err
  87. }
  88. promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}})
  89. if err != nil {
  90. return err
  91. }
  92. tokensWithSystem, err := llmModel.Encode(ctx, promptWithSystem)
  93. if err != nil {
  94. return err
  95. }
  96. tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
  97. if err != nil {
  98. return err
  99. }
  100. opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem)
  101. llmModel.SetOptions(opts)
  102. }
  103. }
  104. loaded.expireAt = time.Now().Add(sessionDuration)
  105. if loaded.expireTimer == nil {
  106. loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
  107. loaded.mu.Lock()
  108. defer loaded.mu.Unlock()
  109. if time.Now().Before(loaded.expireAt) {
  110. return
  111. }
  112. if loaded.llm == nil {
  113. return
  114. }
  115. loaded.llm.Close()
  116. loaded.llm = nil
  117. loaded.digest = ""
  118. })
  119. }
  120. loaded.expireTimer.Reset(sessionDuration)
  121. return nil
  122. }
  123. func GenerateHandler(c *gin.Context) {
  124. loaded.mu.Lock()
  125. defer loaded.mu.Unlock()
  126. checkpointStart := time.Now()
  127. var req api.GenerateRequest
  128. if err := c.ShouldBindJSON(&req); err != nil {
  129. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  130. return
  131. }
  132. model, err := GetModel(req.Model)
  133. if err != nil {
  134. var pErr *fs.PathError
  135. if errors.As(err, &pErr) {
  136. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
  137. return
  138. }
  139. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  140. return
  141. }
  142. workDir := c.GetString("workDir")
  143. // TODO: set this duration from the request if specified
  144. sessionDuration := defaultSessionDuration
  145. if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
  146. if errors.Is(err, api.ErrInvalidOpts) {
  147. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  148. return
  149. }
  150. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  151. return
  152. }
  153. checkpointLoaded := time.Now()
  154. prompt, err := model.Prompt(req)
  155. if err != nil {
  156. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  157. return
  158. }
  159. ch := make(chan any)
  160. go func() {
  161. defer close(ch)
  162. fn := func(r api.GenerateResponse) {
  163. loaded.expireAt = time.Now().Add(sessionDuration)
  164. loaded.expireTimer.Reset(sessionDuration)
  165. r.Model = req.Model
  166. r.CreatedAt = time.Now().UTC()
  167. if r.Done {
  168. r.TotalDuration = time.Since(checkpointStart)
  169. r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
  170. }
  171. ch <- r
  172. }
  173. // an empty request loads the model
  174. if req.Prompt == "" && req.Template == "" && req.System == "" {
  175. ch <- api.GenerateResponse{Model: req.Model, Done: true}
  176. } else {
  177. if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
  178. ch <- gin.H{"error": err.Error()}
  179. }
  180. }
  181. }()
  182. if req.Stream != nil && !*req.Stream {
  183. var response api.GenerateResponse
  184. generated := ""
  185. for resp := range ch {
  186. if r, ok := resp.(api.GenerateResponse); ok {
  187. generated += r.Response
  188. response = r
  189. } else {
  190. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  191. return
  192. }
  193. }
  194. response.Response = generated
  195. c.JSON(http.StatusOK, response)
  196. return
  197. }
  198. streamResponse(c, ch)
  199. }
  200. func EmbeddingHandler(c *gin.Context) {
  201. loaded.mu.Lock()
  202. defer loaded.mu.Unlock()
  203. var req api.EmbeddingRequest
  204. if err := c.ShouldBindJSON(&req); err != nil {
  205. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  206. return
  207. }
  208. model, err := GetModel(req.Model)
  209. if err != nil {
  210. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  211. return
  212. }
  213. workDir := c.GetString("workDir")
  214. if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil {
  215. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  216. return
  217. }
  218. if !loaded.options.EmbeddingOnly {
  219. c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
  220. return
  221. }
  222. embedding, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
  223. if err != nil {
  224. log.Printf("embedding generation failed: %v", err)
  225. c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
  226. return
  227. }
  228. resp := api.EmbeddingResponse{
  229. Embedding: embedding,
  230. }
  231. c.JSON(http.StatusOK, resp)
  232. }
  233. func PullModelHandler(c *gin.Context) {
  234. var req api.PullRequest
  235. if err := c.ShouldBindJSON(&req); err != nil {
  236. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  237. return
  238. }
  239. ch := make(chan any)
  240. go func() {
  241. defer close(ch)
  242. fn := func(r api.ProgressResponse) {
  243. ch <- r
  244. }
  245. regOpts := &RegistryOptions{
  246. Insecure: req.Insecure,
  247. }
  248. ctx, cancel := context.WithCancel(c.Request.Context())
  249. defer cancel()
  250. if err := PullModel(ctx, req.Name, regOpts, fn); err != nil {
  251. ch <- gin.H{"error": err.Error()}
  252. }
  253. }()
  254. if req.Stream != nil && !*req.Stream {
  255. waitForStream(c, ch)
  256. return
  257. }
  258. streamResponse(c, ch)
  259. }
  260. func PushModelHandler(c *gin.Context) {
  261. var req api.PushRequest
  262. if err := c.ShouldBindJSON(&req); err != nil {
  263. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  264. return
  265. }
  266. ch := make(chan any)
  267. go func() {
  268. defer close(ch)
  269. fn := func(r api.ProgressResponse) {
  270. ch <- r
  271. }
  272. regOpts := &RegistryOptions{
  273. Insecure: req.Insecure,
  274. }
  275. ctx := context.Background()
  276. if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
  277. ch <- gin.H{"error": err.Error()}
  278. }
  279. }()
  280. if req.Stream != nil && !*req.Stream {
  281. waitForStream(c, ch)
  282. return
  283. }
  284. streamResponse(c, ch)
  285. }
  286. func CreateModelHandler(c *gin.Context) {
  287. var req api.CreateRequest
  288. if err := c.ShouldBindJSON(&req); err != nil {
  289. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  290. return
  291. }
  292. workDir := c.GetString("workDir")
  293. ch := make(chan any)
  294. go func() {
  295. defer close(ch)
  296. fn := func(resp api.ProgressResponse) {
  297. ch <- resp
  298. }
  299. ctx, cancel := context.WithCancel(c.Request.Context())
  300. defer cancel()
  301. if err := CreateModel(ctx, workDir, req.Name, req.Path, fn); err != nil {
  302. ch <- gin.H{"error": err.Error()}
  303. }
  304. }()
  305. if req.Stream != nil && !*req.Stream {
  306. waitForStream(c, ch)
  307. return
  308. }
  309. streamResponse(c, ch)
  310. }
  311. func DeleteModelHandler(c *gin.Context) {
  312. var req api.DeleteRequest
  313. if err := c.ShouldBindJSON(&req); err != nil {
  314. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  315. return
  316. }
  317. if err := DeleteModel(req.Name); err != nil {
  318. if os.IsNotExist(err) {
  319. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
  320. } else {
  321. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  322. }
  323. return
  324. }
  325. manifestsPath, err := GetManifestPath()
  326. if err != nil {
  327. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  328. return
  329. }
  330. if err := PruneDirectory(manifestsPath); err != nil {
  331. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  332. return
  333. }
  334. c.JSON(http.StatusOK, nil)
  335. }
  336. func ShowModelHandler(c *gin.Context) {
  337. var req api.ShowRequest
  338. if err := c.ShouldBindJSON(&req); err != nil {
  339. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  340. return
  341. }
  342. resp, err := GetModelInfo(req.Name)
  343. if err != nil {
  344. if os.IsNotExist(err) {
  345. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
  346. } else {
  347. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  348. }
  349. return
  350. }
  351. c.JSON(http.StatusOK, resp)
  352. }
  353. func GetModelInfo(name string) (*api.ShowResponse, error) {
  354. model, err := GetModel(name)
  355. if err != nil {
  356. return nil, err
  357. }
  358. resp := &api.ShowResponse{
  359. License: strings.Join(model.License, "\n"),
  360. System: model.System,
  361. Template: model.Template,
  362. }
  363. mf, err := ShowModelfile(model)
  364. if err != nil {
  365. return nil, err
  366. }
  367. resp.Modelfile = mf
  368. var params []string
  369. cs := 30
  370. for k, v := range model.Options {
  371. switch val := v.(type) {
  372. case string:
  373. params = append(params, fmt.Sprintf("%-*s %s", cs, k, val))
  374. case int:
  375. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(val)))
  376. case float64:
  377. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(val, 'f', 0, 64)))
  378. case bool:
  379. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(val)))
  380. case []interface{}:
  381. for _, nv := range val {
  382. switch nval := nv.(type) {
  383. case string:
  384. params = append(params, fmt.Sprintf("%-*s %s", cs, k, nval))
  385. case int:
  386. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(nval)))
  387. case float64:
  388. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(nval, 'f', 0, 64)))
  389. case bool:
  390. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(nval)))
  391. }
  392. }
  393. }
  394. }
  395. resp.Parameters = strings.Join(params, "\n")
  396. return resp, nil
  397. }
  398. func ListModelsHandler(c *gin.Context) {
  399. var models []api.ModelResponse
  400. fp, err := GetManifestPath()
  401. if err != nil {
  402. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  403. return
  404. }
  405. walkFunc := func(path string, info os.FileInfo, _ error) error {
  406. if !info.IsDir() {
  407. dir, file := filepath.Split(path)
  408. dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
  409. tag := strings.Join([]string{dir, file}, ":")
  410. mp := ParseModelPath(tag)
  411. manifest, digest, err := GetManifest(mp)
  412. if err != nil {
  413. log.Printf("skipping file: %s", fp)
  414. return nil
  415. }
  416. models = append(models, api.ModelResponse{
  417. Name: mp.GetShortTagname(),
  418. Size: manifest.GetTotalSize(),
  419. Digest: digest,
  420. ModifiedAt: info.ModTime(),
  421. })
  422. }
  423. return nil
  424. }
  425. if err := filepath.Walk(fp, walkFunc); err != nil {
  426. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  427. return
  428. }
  429. c.JSON(http.StatusOK, api.ListResponse{Models: models})
  430. }
  431. func CopyModelHandler(c *gin.Context) {
  432. var req api.CopyRequest
  433. if err := c.ShouldBindJSON(&req); err != nil {
  434. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  435. return
  436. }
  437. if err := CopyModel(req.Source, req.Destination); err != nil {
  438. if os.IsNotExist(err) {
  439. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
  440. } else {
  441. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  442. }
  443. return
  444. }
  445. }
  446. var defaultAllowOrigins = []string{
  447. "localhost",
  448. "127.0.0.1",
  449. "0.0.0.0",
  450. }
  451. func Serve(ln net.Listener, allowOrigins []string) error {
  452. config := cors.DefaultConfig()
  453. config.AllowWildcard = true
  454. config.AllowOrigins = allowOrigins
  455. for _, allowOrigin := range defaultAllowOrigins {
  456. config.AllowOrigins = append(config.AllowOrigins,
  457. fmt.Sprintf("http://%s", allowOrigin),
  458. fmt.Sprintf("https://%s", allowOrigin),
  459. fmt.Sprintf("http://%s:*", allowOrigin),
  460. fmt.Sprintf("https://%s:*", allowOrigin),
  461. )
  462. }
  463. workDir, err := os.MkdirTemp("", "ollama")
  464. if err != nil {
  465. return err
  466. }
  467. defer os.RemoveAll(workDir)
  468. r := gin.Default()
  469. r.Use(
  470. cors.New(config),
  471. func(c *gin.Context) {
  472. c.Set("workDir", workDir)
  473. c.Next()
  474. },
  475. )
  476. r.POST("/api/pull", PullModelHandler)
  477. r.POST("/api/generate", GenerateHandler)
  478. r.POST("/api/embeddings", EmbeddingHandler)
  479. r.POST("/api/create", CreateModelHandler)
  480. r.POST("/api/push", PushModelHandler)
  481. r.POST("/api/copy", CopyModelHandler)
  482. r.DELETE("/api/delete", DeleteModelHandler)
  483. r.POST("/api/show", ShowModelHandler)
  484. for _, method := range []string{http.MethodGet, http.MethodHead} {
  485. r.Handle(method, "/", func(c *gin.Context) {
  486. c.String(http.StatusOK, "Ollama is running")
  487. })
  488. r.Handle(method, "/api/tags", ListModelsHandler)
  489. }
  490. log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
  491. s := &http.Server{
  492. Handler: r,
  493. }
  494. // listen for a ctrl+c and stop any loaded llm
  495. signals := make(chan os.Signal, 1)
  496. signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
  497. go func() {
  498. <-signals
  499. if loaded.llm != nil {
  500. loaded.llm.Close()
  501. }
  502. os.RemoveAll(workDir)
  503. os.Exit(0)
  504. }()
  505. if runtime.GOOS == "linux" {
  506. // check compatibility to log warnings
  507. if _, err := llm.CheckVRAM(); err != nil {
  508. log.Printf("Warning: GPU support may not enabled, check you have installed install GPU drivers: %v", err)
  509. }
  510. }
  511. return s.Serve(ln)
  512. }
  513. func waitForStream(c *gin.Context, ch chan interface{}) {
  514. c.Header("Content-Type", "application/json")
  515. for resp := range ch {
  516. switch r := resp.(type) {
  517. case api.ProgressResponse:
  518. if r.Status == "success" {
  519. c.JSON(http.StatusOK, r)
  520. return
  521. }
  522. case gin.H:
  523. if errorMsg, ok := r["error"].(string); ok {
  524. c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
  525. return
  526. } else {
  527. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
  528. return
  529. }
  530. default:
  531. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
  532. return
  533. }
  534. }
  535. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
  536. }
  537. func streamResponse(c *gin.Context, ch chan any) {
  538. c.Header("Content-Type", "application/x-ndjson")
  539. c.Stream(func(w io.Writer) bool {
  540. val, ok := <-ch
  541. if !ok {
  542. return false
  543. }
  544. bts, err := json.Marshal(val)
  545. if err != nil {
  546. log.Printf("streamResponse: json.Marshal failed with %s", err)
  547. return false
  548. }
  549. // Delineate chunks with new-line delimiter
  550. bts = append(bts, '\n')
  551. if _, err := w.Write(bts); err != nil {
  552. log.Printf("streamResponse: w.Write failed with %s", err)
  553. return false
  554. }
  555. return true
  556. })
  557. }