routes.go 32 KB


  1. package server
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/base64"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "log"
  12. "log/slog"
  13. "net"
  14. "net/http"
  15. "net/netip"
  16. "os"
  17. "os/signal"
  18. "path/filepath"
  19. "slices"
  20. "strings"
  21. "syscall"
  22. "time"
  23. "github.com/gin-contrib/cors"
  24. "github.com/gin-gonic/gin"
  25. "golang.org/x/crypto/ssh"
  26. "github.com/ollama/ollama/api"
  27. "github.com/ollama/ollama/auth"
  28. "github.com/ollama/ollama/envconfig"
  29. "github.com/ollama/ollama/gpu"
  30. "github.com/ollama/ollama/llm"
  31. "github.com/ollama/ollama/openai"
  32. "github.com/ollama/ollama/parser"
  33. "github.com/ollama/ollama/template"
  34. "github.com/ollama/ollama/types/errtypes"
  35. "github.com/ollama/ollama/types/model"
  36. "github.com/ollama/ollama/version"
  37. )
  38. var mode string = gin.DebugMode
  39. type Server struct {
  40. addr net.Addr
  41. sched *Scheduler
  42. }
  43. func init() {
  44. switch mode {
  45. case gin.DebugMode:
  46. case gin.ReleaseMode:
  47. case gin.TestMode:
  48. default:
  49. mode = gin.DebugMode
  50. }
  51. gin.SetMode(mode)
  52. }
  // errRequired is meant to be wrapped after a field name, e.g.
  // fmt.Errorf("model %w", errRequired) reads "model is required" while still
  // matching errors.Is(err, errRequired).
  var errRequired error = errors.New("is required")
  54. func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
  55. opts := api.DefaultOptions()
  56. if err := opts.FromMap(model.Options); err != nil {
  57. return api.Options{}, err
  58. }
  59. if err := opts.FromMap(requestOpts); err != nil {
  60. return api.Options{}, err
  61. }
  62. return opts, nil
  63. }
  64. // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
  65. // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
  66. func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
  67. if name == "" {
  68. return nil, nil, nil, fmt.Errorf("model %w", errRequired)
  69. }
  70. model, err := GetModel(name)
  71. if err != nil {
  72. return nil, nil, nil, err
  73. }
  74. if err := model.CheckCapabilities(caps...); err != nil {
  75. return nil, nil, nil, fmt.Errorf("%s %w", name, err)
  76. }
  77. opts, err := modelOptions(model, requestOpts)
  78. if err != nil {
  79. return nil, nil, nil, err
  80. }
  81. runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
  82. var runner *runnerRef
  83. select {
  84. case runner = <-runnerCh:
  85. case err = <-errCh:
  86. return nil, nil, nil, err
  87. }
  88. return runner.llama, model, &opts, nil
  89. }
  // GenerateHandler serves POST /api/generate and, behind the OpenAI completions
  // middleware, POST /v1/completions.
  //
  // It validates the request and schedules a runner for a model with the
  // completion capability. It then builds the prompt:
  //   - in raw mode, req.Prompt is sent verbatim;
  //   - otherwise the system prompt, one "[img-N]" placeholder per image, and
  //     the user prompt are rendered through the model's template (or
  //     req.Template), prefixed by the detokenized req.Context when supplied.
  //
  // An empty prompt only loads the model and returns a "load" response.
  //
  // Chunks stream as NDJSON unless req.Stream is explicitly false. In that case
  // the chunk texts are concatenated into one JSON response. The final chunk
  // carries:
  //   - the total and load durations;
  //   - outside raw mode, the token context of prompt+completion, which a
  //     client can send back as req.Context to continue the conversation.
  func (s *Server) GenerateHandler(c *gin.Context) {
  	checkpointStart := time.Now()

  	var req api.GenerateRequest
  	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  		return
  	} else if err != nil {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  		return
  	}

  	if req.Format != "" && req.Format != "json" {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
  		return
  	} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
  		return
  	}

  	caps := []Capability{CapabilityCompletion}
  	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
  	if errors.Is(err, errCapabilityCompletion) {
  		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
  		return
  	} else if err != nil {
  		handleScheduleError(c, req.Model, err)
  		return
  	}

  	// scheduleRunner returns once a runner is available, so this marks the end
  	// of the load phase reported as LoadDuration.
  	checkpointLoaded := time.Now()

  	// An empty prompt only loads the model.
  	if req.Prompt == "" {
  		c.JSON(http.StatusOK, api.GenerateResponse{
  			Model:      req.Model,
  			CreatedAt:  time.Now().UTC(),
  			Done:       true,
  			DoneReason: "load",
  		})
  		return
  	}

  	images := make([]llm.ImageData, len(req.Images))
  	for i := range req.Images {
  		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
  	}

  	prompt := req.Prompt
  	if !req.Raw {
  		var msgs []api.Message
  		if req.System != "" {
  			msgs = append(msgs, api.Message{Role: "system", Content: req.System})
  		} else if m.System != "" {
  			msgs = append(msgs, api.Message{Role: "system", Content: m.System})
  		}

  		// Each image gets an "[img-N]" placeholder user message, N being its ID.
  		for _, i := range images {
  			msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
  		}

  		msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})

  		tmpl := m.Template
  		if req.Template != "" {
  			tmpl, err = template.Parse(req.Template)
  			if err != nil {
  				// The template was supplied by the client, so this is a bad request.
  				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  				return
  			}
  		}

  		var b bytes.Buffer
  		if req.Context != nil {
  			// Named prev (not s) so the Server receiver is not shadowed.
  			prev, err := r.Detokenize(c.Request.Context(), req.Context)
  			if err != nil {
  				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  				return
  			}

  			b.WriteString(prev)
  		}

  		if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
  			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  			return
  		}

  		prompt = b.String()
  	}

  	slog.Debug("generate request", "prompt", prompt, "images", images)

  	ch := make(chan any)
  	go func() {
  		defer close(ch)

  		// generated accumulates the completion text so the final chunk can
  		// report the token context for prompt + completion.
  		var generated strings.Builder
  		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  			Prompt:  prompt,
  			Images:  images,
  			Format:  req.Format,
  			Options: opts,
  		}, func(cr llm.CompletionResponse) {
  			res := api.GenerateResponse{
  				Model:      req.Model,
  				CreatedAt:  time.Now().UTC(),
  				Response:   cr.Content,
  				Done:       cr.Done,
  				DoneReason: cr.DoneReason,
  				Metrics: api.Metrics{
  					PromptEvalCount:    cr.PromptEvalCount,
  					PromptEvalDuration: cr.PromptEvalDuration,
  					EvalCount:          cr.EvalCount,
  					EvalDuration:       cr.EvalDuration,
  				},
  			}

  			generated.WriteString(cr.Content)

  			if cr.Done {
  				res.TotalDuration = time.Since(checkpointStart)
  				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)

  				// Raw prompts bypass the template, so a context built from them
  				// could not be replayed (raw mode also rejects req.Context).
  				if !req.Raw {
  					tokens, err := r.Tokenize(c.Request.Context(), prompt+generated.String())
  					if err != nil {
  						ch <- gin.H{"error": err.Error()}
  						return
  					}
  					res.Context = tokens
  				}
  			}

  			ch <- res
  		}); err != nil {
  			ch <- gin.H{"error": err.Error()}
  		}
  	}()

  	if req.Stream != nil && !*req.Stream {
  		var final api.GenerateResponse
  		var sb strings.Builder
  		for rr := range ch {
  			switch t := rr.(type) {
  			case api.GenerateResponse:
  				sb.WriteString(t.Response)
  				final = t
  			case gin.H:
  				msg, ok := t["error"].(string)
  				if !ok {
  					msg = "unexpected error format in response"
  				}

  				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  				return
  			default:
  				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  				return
  			}
  		}

  		final.Response = sb.String()
  		c.JSON(http.StatusOK, final)
  		return
  	}

  	streamResponse(c, ch)
  }
  // EmbeddingsHandler serves POST /api/embeddings. It schedules a runner for the
  // requested model (no capability requirements) and returns the embedding of
  // req.Prompt. An empty prompt only loads the model and yields an empty
  // embedding. Generation failures are logged and reported as a generic 500.
  func (s *Server) EmbeddingsHandler(c *gin.Context) {
  	var req api.EmbeddingRequest
  	bindErr := c.ShouldBindJSON(&req)
  	switch {
  	case errors.Is(bindErr, io.EOF):
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  		return
  	case bindErr != nil:
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
  		return
  	}

  	runner, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
  	if err != nil {
  		handleScheduleError(c, req.Model, err)
  		return
  	}

  	if req.Prompt == "" {
  		c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
  		return
  	}

  	vector, err := runner.Embedding(c.Request.Context(), req.Prompt)
  	if err != nil {
  		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
  		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
  		return
  	}

  	c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: vector})
  }
  // PullModelHandler serves POST /api/pull.
  //
  // It validates the model name (req.Model, falling back to req.Name) and
  // rejects names that differ only in case from an installed model. It then
  // runs PullModel in a goroutine. Progress is relayed as an NDJSON stream or,
  // when stream is false, as a single final response.
  func (s *Server) PullModelHandler(c *gin.Context) {
  	var req api.PullRequest
  	if err := c.ShouldBindJSON(&req); err != nil {
  		msg := err.Error()
  		if errors.Is(err, io.EOF) {
  			msg = "missing request body"
  		}
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": msg})
  		return
  	}

  	name := model.ParseName(cmp.Or(req.Model, req.Name))
  	if !name.IsValid() {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
  		return
  	}

  	if err := checkNameExists(name); err != nil {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  		return
  	}

  	progress := make(chan any)
  	go func() {
  		defer close(progress)

  		ctx, cancel := context.WithCancel(c.Request.Context())
  		defer cancel()

  		report := func(update api.ProgressResponse) { progress <- update }
  		regOpts := &registryOptions{Insecure: req.Insecure}
  		if err := PullModel(ctx, name.DisplayShortest(), regOpts, report); err != nil {
  			progress <- gin.H{"error": err.Error()}
  		}
  	}()

  	if req.Stream != nil && !*req.Stream {
  		waitForStream(c, progress)
  		return
  	}
  	streamResponse(c, progress)
  }
  // PushModelHandler serves POST /api/push. It hands the model named by req.Model
  // (falling back to req.Name) to PushModel in a goroutine, passing the name
  // through unparsed. Progress is relayed as an NDJSON stream or, when stream is
  // false, as a single final response.
  func (s *Server) PushModelHandler(c *gin.Context) {
  	var req api.PushRequest
  	if err := c.ShouldBindJSON(&req); err != nil {
  		msg := err.Error()
  		if errors.Is(err, io.EOF) {
  			msg = "missing request body"
  		}
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": msg})
  		return
  	}

  	target := cmp.Or(req.Model, req.Name)
  	if target == "" {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  		return
  	}

  	progress := make(chan any)
  	go func() {
  		defer close(progress)

  		ctx, cancel := context.WithCancel(c.Request.Context())
  		defer cancel()

  		report := func(update api.ProgressResponse) { progress <- update }
  		regOpts := &registryOptions{Insecure: req.Insecure}
  		if err := PushModel(ctx, target, regOpts, report); err != nil {
  			progress <- gin.H{"error": err.Error()}
  		}
  	}()

  	if req.Stream != nil && !*req.Stream {
  		waitForStream(c, progress)
  		return
  	}
  	streamResponse(c, progress)
  }
  // checkNameExists returns an error when an installed model's manifest path
  // equals name's case-insensitively but the names are not identical. This lets
  // callers refuse names that differ from an existing model only by letter case.
  // Errors from listing manifests are returned as-is.
  func checkNameExists(name model.Name) error {
  	installed, err := Manifests()
  	if err != nil {
  		return err
  	}

  	want := name.Filepath()
  	for existing := range installed {
  		if existing != name && strings.EqualFold(existing.Filepath(), want) {
  			return errors.New("a model with that name already exists")
  		}
  	}
  	return nil
  }
  // CreateModelHandler serves POST /api/create.
  //
  // The Modelfile comes from req.Modelfile or, when that is empty, from the
  // file at req.Path on the server. The handler then:
  //   - validates the name and rejects case-only collisions;
  //   - parses the Modelfile;
  //   - runs CreateModel in a goroutine, passing the directory of req.Path ("."
  //     when no path was given) and the upper-cased Quantize/Quantization value.
  //
  // Progress is relayed as an NDJSON stream or, when stream is false, as a
  // single final response.
  func (s *Server) CreateModelHandler(c *gin.Context) {
  	var r api.CreateRequest
  	if err := c.ShouldBindJSON(&r); err != nil {
  		msg := err.Error()
  		if errors.Is(err, io.EOF) {
  			msg = "missing request body"
  		}
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": msg})
  		return
  	}

  	name := model.ParseName(cmp.Or(r.Model, r.Name))
  	if !name.IsValid() {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
  		return
  	}

  	if err := checkNameExists(name); err != nil {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  		return
  	}

  	if r.Path == "" && r.Modelfile == "" {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
  		return
  	}

  	// An inline Modelfile wins; the path is only read when no inline text was sent.
  	var source io.Reader
  	if r.Modelfile == "" && r.Path != "" {
  		fh, err := os.Open(r.Path)
  		if err != nil {
  			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
  			return
  		}
  		defer fh.Close()
  		source = fh
  	} else {
  		source = strings.NewReader(r.Modelfile)
  	}

  	parsed, err := parser.ParseFile(source)
  	if err != nil {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  		return
  	}

  	progress := make(chan any)
  	go func() {
  		defer close(progress)

  		ctx, cancel := context.WithCancel(c.Request.Context())
  		defer cancel()

  		report := func(update api.ProgressResponse) { progress <- update }
  		quantization := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
  		if err := CreateModel(ctx, name, filepath.Dir(r.Path), quantization, parsed, report); err != nil {
  			progress <- gin.H{"error": err.Error()}
  		}
  	}()

  	if r.Stream != nil && !*r.Stream {
  		waitForStream(c, progress)
  		return
  	}
  	streamResponse(c, progress)
  }
  // DeleteModelHandler serves DELETE /api/delete. It removes the named model's
  // manifest and then calls RemoveLayers on it.
  //
  // Status codes:
  //   - 400 for a missing body or an invalid name;
  //   - 404 when no manifest exists for the model;
  //   - 500 for any other failure.
  //
  // Nothing is written on success.
  func (s *Server) DeleteModelHandler(c *gin.Context) {
  	var r api.DeleteRequest
  	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  		return
  	} else if err != nil {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  		return
  	}

  	raw := cmp.Or(r.Model, r.Name)
  	n := model.ParseName(raw)
  	if !n.IsValid() {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", raw)})
  		return
  	}

  	m, err := ParseNamedManifest(n)
  	if err != nil {
  		// A missing manifest means the model is not installed: that is a 404,
  		// matching ShowModelHandler, not a server error.
  		if errors.Is(err, os.ErrNotExist) {
  			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", raw)})
  			return
  		}
  		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  		return
  	}

  	if err := m.Remove(); err != nil {
  		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  		return
  	}

  	if err := m.RemoveLayers(); err != nil {
  		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  		return
  	}
  }
  // ShowModelHandler serves POST /api/show (and GET /v1/models/:model behind the
  // OpenAI retrieve middleware). It resolves the model from req.Model, falling
  // back to req.Name, and returns GetModelInfo's result.
  //
  // Error status codes:
  //   - 404 when the model's files do not exist;
  //   - 400 for an invalid model name;
  //   - 500 for anything else.
  func (s *Server) ShowModelHandler(c *gin.Context) {
  	var req api.ShowRequest
  	if err := c.ShouldBindJSON(&req); err != nil {
  		msg := err.Error()
  		if errors.Is(err, io.EOF) {
  			msg = "missing request body"
  		}
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": msg})
  		return
  	}

  	req.Model = cmp.Or(req.Model, req.Name)
  	if req.Model == "" {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  		return
  	}

  	resp, err := GetModelInfo(req)
  	switch {
  	case err == nil:
  		c.JSON(http.StatusOK, resp)
  	case os.IsNotExist(err):
  		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  	case err.Error() == "invalid model name":
  		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  	default:
  		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  	}
  }
  // GetModelInfo assembles the /api/show response for req.Model. The response
  // contains:
  //   - model details;
  //   - the license;
  //   - the system prompt and template (req.System / req.Template override the
  //     stored values);
  //   - the stored messages and parameters;
  //   - a generated Modelfile;
  //   - key/value metadata of the model file and, if present, its first
  //     projector.
  //
  // Parameters are listed in sorted key order so repeated calls return
  // identical text. req.Options are merged into the model's options after
  // Parameters is computed, so they only affect the generated Modelfile.
  //
  // Returned errors come from GetModel, template parsing, an invalid name
  // ("invalid model name"), manifest parsing, or reading metadata.
  func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
  	m, err := GetModel(req.Model)
  	if err != nil {
  		return nil, err
  	}

  	modelDetails := api.ModelDetails{
  		ParentModel:       m.ParentModel,
  		Format:            m.Config.ModelFormat,
  		Family:            m.Config.ModelFamily,
  		Families:          m.Config.ModelFamilies,
  		ParameterSize:     m.Config.ModelType,
  		QuantizationLevel: m.Config.FileType,
  	}

  	if req.System != "" {
  		m.System = req.System
  	}

  	if req.Template != "" {
  		m.Template, err = template.Parse(req.Template)
  		if err != nil {
  			return nil, err
  		}
  	}

  	msgs := make([]api.Message, len(m.Messages))
  	for i, msg := range m.Messages {
  		msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
  	}

  	n := model.ParseName(req.Model)
  	if !n.IsValid() {
  		// ShowModelHandler matches on this exact message to return 400.
  		return nil, errors.New("invalid model name")
  	}

  	manifest, err := ParseNamedManifest(n)
  	if err != nil {
  		return nil, err
  	}

  	resp := &api.ShowResponse{
  		License:    strings.Join(m.License, "\n"),
  		System:     m.System,
  		Template:   m.Template.String(),
  		Details:    modelDetails,
  		Messages:   msgs,
  		ModifiedAt: manifest.fi.ModTime(),
  	}

  	// Map iteration order is random; sort the keys so Parameters is stable.
  	keys := make([]string, 0, len(m.Options))
  	for k := range m.Options {
  		keys = append(keys, k)
  	}
  	slices.Sort(keys)

  	const cs = 30 // width of the parameter-name column
  	var params []string
  	for _, k := range keys {
  		switch val := m.Options[k].(type) {
  		case []interface{}:
  			// Multi-valued parameters (e.g. several stop sequences) get one line each.
  			for _, nv := range val {
  				params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
  			}
  		default:
  			params = append(params, fmt.Sprintf("%-*s %#v", cs, k, val))
  		}
  	}
  	resp.Parameters = strings.Join(params, "\n")

  	// m.Options can be nil, and assigning into a nil map panics, so allocate it
  	// before merging the request options.
  	if len(req.Options) > 0 && m.Options == nil {
  		m.Options = make(map[string]interface{}, len(req.Options))
  	}
  	for k, v := range req.Options {
  		m.Options[k] = v
  	}

  	var sb strings.Builder
  	fmt.Fprintln(&sb, "# Modelfile generated by \"ollama show\"")
  	fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
  	fmt.Fprintf(&sb, "# FROM %s\n\n", m.ShortName)
  	fmt.Fprint(&sb, m.String())
  	resp.Modelfile = sb.String()

  	kvData, err := getKVData(m.ModelPath, req.Verbose)
  	if err != nil {
  		return nil, err
  	}
  	delete(kvData, "general.name")
  	delete(kvData, "tokenizer.chat_template")
  	resp.ModelInfo = kvData

  	if len(m.ProjectorPaths) > 0 {
  		projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
  		if err != nil {
  			return nil, err
  		}
  		resp.ProjectorInfo = projectorData
  	}

  	return resp, nil
  }
  536. func getKVData(digest string, verbose bool) (llm.KV, error) {
  537. maxArraySize := 0
  538. if verbose {
  539. maxArraySize = -1
  540. }
  541. kvData, err := llm.LoadModel(digest, maxArraySize)
  542. if err != nil {
  543. return nil, err
  544. }
  545. kv := kvData.KV()
  546. if !verbose {
  547. for k := range kv {
  548. if t, ok := kv[k].([]any); len(t) > 5 && ok {
  549. kv[k] = []any{}
  550. }
  551. }
  552. }
  553. return kv, nil
  554. }
  // ListModelsHandler serves GET/HEAD /api/tags (and GET /v1/models behind the
  // OpenAI list middleware). It lists every installed model with:
  //   - its size, digest and modification time;
  //   - the details decoded from its config blob.
  //
  // Models are ordered most recently modified first. Models whose config cannot
  // be opened or decoded are logged and skipped.
  func (s *Server) ListModelsHandler(c *gin.Context) {
  	ms, err := Manifests()
  	if err != nil {
  		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  		return
  	}

  	// Non-nil so an empty list encodes as [] rather than null.
  	models := []api.ListModelResponse{}
  	for n, m := range ms {
  		f, err := m.Config.Open()
  		if err != nil {
  			slog.Warn("bad manifest filepath", "name", n, "error", err)
  			continue
  		}

  		var cf ConfigV2
  		err = json.NewDecoder(f).Decode(&cf)
  		// Close immediately instead of deferring: a defer inside this loop would
  		// keep every config file open until the handler returns. The file is
  		// only read, so a close error carries no information worth reporting.
  		f.Close()
  		if err != nil {
  			slog.Warn("bad manifest config", "name", n, "error", err)
  			continue
  		}

  		// tag should never be masked
  		models = append(models, api.ListModelResponse{
  			Model:      n.DisplayShortest(),
  			Name:       n.DisplayShortest(),
  			Size:       m.Size(),
  			Digest:     m.digest,
  			ModifiedAt: m.fi.ModTime(),
  			Details: api.ModelDetails{
  				Format:            cf.ModelFormat,
  				Family:            cf.ModelFamily,
  				Families:          cf.ModelFamilies,
  				ParameterSize:     cf.ModelType,
  				QuantizationLevel: cf.FileType,
  			},
  		})
  	}

  	slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
  		// most recently modified first
  		return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
  	})

  	c.JSON(http.StatusOK, api.ListResponse{Models: models})
  }
  // CopyModelHandler serves POST /api/copy. It duplicates model r.Source under
  // the name r.Destination.
  //
  // Both names must be valid, and the destination must not differ only in case
  // from another installed model. A missing source yields 404; other copy
  // failures yield 500. Nothing is written on success.
  func (s *Server) CopyModelHandler(c *gin.Context) {
  	var r api.CopyRequest
  	if err := c.ShouldBindJSON(&r); err != nil {
  		msg := err.Error()
  		if errors.Is(err, io.EOF) {
  			msg = "missing request body"
  		}
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": msg})
  		return
  	}

  	src, dst := model.ParseName(r.Source), model.ParseName(r.Destination)

  	if !src.IsValid() {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
  		return
  	}

  	if !dst.IsValid() {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
  		return
  	}

  	if err := checkNameExists(dst); err != nil {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  		return
  	}

  	copyErr := CopyModel(src, dst)
  	switch {
  	case errors.Is(copyErr, os.ErrNotExist):
  		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
  	case copyErr != nil:
  		c.JSON(http.StatusInternalServerError, gin.H{"error": copyErr.Error()})
  	}
  }
  // HeadBlobHandler serves HEAD /api/blobs/:digest. It responds:
  //   - 200 when a blob with that digest is stored locally;
  //   - 404 when it is not (or cannot be stat'ed);
  //   - 400 when the digest cannot be mapped to a blob path.
  func (s *Server) HeadBlobHandler(c *gin.Context) {
  	digest := c.Param("digest")

  	blobPath, err := GetBlobsPath(digest)
  	if err != nil {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  		return
  	}

  	if _, statErr := os.Stat(blobPath); statErr != nil {
  		c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", digest)})
  		return
  	}

  	c.Status(http.StatusOK)
  }
  // CreateBlobHandler serves POST /api/blobs/:digest, storing the request body as
  // a blob whose digest must equal the :digest path parameter. It responds:
  //   - 200 when the blob, or an intermediate blob recorded for this digest in
  //     intermediateBlobs, already exists;
  //   - 307 with a LocalLocation header holding the blob path when the client
  //     sends X-Redirect-Create: 1 and IsLocal accepts it (presumably so a
  //     same-host client can place the file itself);
  //   - 201 once the body has been stored and its digest verified;
  //   - 400 for an unmappable digest or a digest mismatch;
  //   - 401 when IsLocal rejects the client's key;
  //   - 500 for filesystem errors.
  //
  // NOTE(review): intermediateBlobs is read and mutated here without any
  // visible locking — confirm it is not accessed concurrently.
  func (s *Server) CreateBlobHandler(c *gin.Context) {
  	digest := c.Param("digest")

  	if ib, ok := intermediateBlobs[digest]; ok {
  		p, err := GetBlobsPath(ib)
  		if err != nil {
  			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  			return
  		}

  		_, statErr := os.Stat(p)
  		switch {
  		case errors.Is(statErr, os.ErrNotExist):
  			slog.Info("evicting intermediate blob which no longer exists", "digest", ib)
  			delete(intermediateBlobs, digest)
  		case statErr != nil:
  			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": statErr.Error()})
  			return
  		default:
  			c.Status(http.StatusOK)
  			return
  		}
  	}

  	path, err := GetBlobsPath(digest)
  	if err != nil {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  		return
  	}

  	_, err = os.Stat(path)
  	switch {
  	case errors.Is(err, os.ErrNotExist):
  		// not stored yet; continue with the upload
  	case err != nil:
  		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  		return
  	default:
  		c.Status(http.StatusOK)
  		return
  	}

  	if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
  		c.Header("LocalLocation", path)
  		c.Status(http.StatusTemporaryRedirect)
  		return
  	}

  	// IsLocal aborts with 401 when the signature is valid but the key is not
  	// this server's. That response is already written, so stop here rather than
  	// consuming the body and writing a second status.
  	if c.IsAborted() {
  		return
  	}

  	layer, err := NewLayer(c.Request.Body, "")
  	if err != nil {
  		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  		return
  	}

  	if layer.Digest != digest {
  		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", digest, layer.Digest)})
  		return
  	}

  	c.Status(http.StatusCreated)
  }
  // IsLocal reports whether the request's Authorization header was signed with
  // this server's own key.
  //
  // The header has the form "<pubkey>:<base64 data>:<base64 signature>". The
  // decoded data must contain exactly three comma-separated fields, and the
  // signature must verify against the supplied ed25519 public key.
  //
  // IsLocal returns false for any of:
  //   - a missing or malformed header;
  //   - a failed verification;
  //   - a failure to read the server's own public key.
  //
  // When the signature verifies but the key is not the server's, IsLocal also
  // aborts the request with 401, so callers should check c.IsAborted() after a
  // false result.
  //
  // NOTE(review): the signed fields are not compared against the actual request
  // (method, URI, timestamp freshness), so a captured header could be replayed —
  // confirm that is acceptable for the redirect-create flow.
  func (s *Server) IsLocal(c *gin.Context) bool {
  	authz := c.GetHeader("Authorization")
  	if authz == "" {
  		return false
  	}

  	parts := strings.Split(authz, ":")
  	if len(parts) != 3 {
  		return false
  	}

  	clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
  	if err != nil {
  		return false
  	}

  	requestData, err := base64.StdEncoding.DecodeString(parts[1])
  	if err != nil {
  		return false
  	}

  	if len(strings.Split(string(requestData), ",")) != 3 {
  		return false
  	}

  	signature, err := base64.StdEncoding.DecodeString(parts[2])
  	if err != nil {
  		return false
  	}

  	if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
  		return false
  	}

  	serverPublicKey, err := auth.GetPublicKey()
  	if err != nil {
  		// A single request must not be able to terminate the whole server
  		// (log.Fatal would exit the process). Report the failure and treat the
  		// client as non-local so the caller falls back to a normal upload.
  		// After Serve calls slog.SetDefault, log output is routed through slog.
  		log.Printf("IsLocal: unable to read server public key: %v", err)
  		return false
  	}

  	if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
  		return true
  	}

  	c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
  	return false
  }
  // isLocalIP reports whether ip is assigned to one of this host's network
  // interfaces, comparing addresses by their string form.
  //
  // Interfaces whose addresses cannot be listed are skipped. If the interface
  // list itself cannot be read, the result is false.
  func isLocalIP(ip netip.Addr) bool {
  	ifaces, err := net.Interfaces()
  	if err != nil {
  		return false
  	}

  	want := ip.String()
  	for _, iface := range ifaces {
  		addrs, err := iface.Addrs()
  		if err != nil {
  			continue
  		}

  		for _, addr := range addrs {
  			assigned, _, err := net.ParseCIDR(addr.String())
  			if err == nil && assigned.String() == want {
  				return true
  			}
  		}
  	}
  	return false
  }
  743. func allowedHost(host string) bool {
  744. if host == "" || host == "localhost" {
  745. return true
  746. }
  747. if hostname, err := os.Hostname(); err == nil && host == hostname {
  748. return true
  749. }
  750. var tlds = []string{
  751. "localhost",
  752. "local",
  753. "internal",
  754. }
  755. // check if the host is a local TLD
  756. for _, tld := range tlds {
  757. if strings.HasSuffix(host, "."+tld) {
  758. return true
  759. }
  760. }
  761. return false
  762. }
  // allowedHostsMiddleware rejects requests whose Host header does not refer to
  // this machine (presumably to block DNS-rebinding from browsers).
  //
  // The check is skipped entirely when addr is nil, or when addr parses as an
  // IP:port that is not a loopback address. Otherwise:
  //   - Host values that parse as a loopback, private or unspecified IP, or as
  //     an IP assigned to a local interface, pass through;
  //   - names accepted by allowedHost pass through, except that OPTIONS
  //     requests are answered with 204 directly;
  //   - everything else is rejected with 403.
  func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
  	return func(c *gin.Context) {
  		if addr == nil {
  			c.Next()
  			return
  		}

  		if listen, err := netip.ParseAddrPort(addr.String()); err == nil && !listen.Addr().IsLoopback() {
  			c.Next()
  			return
  		}

  		host, _, err := net.SplitHostPort(c.Request.Host)
  		if err != nil {
  			host = c.Request.Host
  		}

  		if ip, err := netip.ParseAddr(host); err == nil {
  			if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || isLocalIP(ip) {
  				c.Next()
  				return
  			}
  		}

  		if !allowedHost(host) {
  			c.AbortWithStatus(http.StatusForbidden)
  			return
  		}

  		if c.Request.Method == http.MethodOptions {
  			c.AbortWithStatus(http.StatusNoContent)
  			return
  		}

  		c.Next()
  	}
  }
  // GenerateRoutes builds the gin engine that serves the native Ollama API and
  // the OpenAI-compatible endpoints.
  //
  // CORS allows the origins in envconfig.AllowOrigins, with wildcards and
  // browser-extension origins enabled. Every route sits behind
  // allowedHostsMiddleware.
  func (s *Server) GenerateRoutes() http.Handler {
  	config := cors.DefaultConfig()
  	config.AllowWildcard = true
  	config.AllowBrowserExtensions = true
  	config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}

  	// x-stainless-* headers, presumably sent by OpenAI-style client SDKs.
  	for _, prop := range []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"} {
  		config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
  	}
  	config.AllowOrigins = envconfig.AllowOrigins

  	r := gin.Default()
  	r.Use(cors.New(config), allowedHostsMiddleware(s.addr))

  	// Native API.
  	r.POST("/api/pull", s.PullModelHandler)
  	r.POST("/api/generate", s.GenerateHandler)
  	r.POST("/api/chat", s.ChatHandler)
  	r.POST("/api/embeddings", s.EmbeddingsHandler)
  	r.POST("/api/create", s.CreateModelHandler)
  	r.POST("/api/push", s.PushModelHandler)
  	r.POST("/api/copy", s.CopyModelHandler)
  	r.DELETE("/api/delete", s.DeleteModelHandler)
  	r.POST("/api/show", s.ShowModelHandler)
  	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
  	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
  	r.GET("/api/ps", s.ProcessHandler)

  	// OpenAI compatibility: each route puts an openai middleware in front of a
  	// native handler.
  	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
  	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
  	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
  	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)

  	// Read-only endpoints answer both GET and HEAD.
  	banner := func(c *gin.Context) {
  		c.String(http.StatusOK, "Ollama is running")
  	}
  	versionInfo := func(c *gin.Context) {
  		c.JSON(http.StatusOK, gin.H{"version": version.Version})
  	}
  	for _, method := range []string{http.MethodGet, http.MethodHead} {
  		r.Handle(method, "/", banner)
  		r.Handle(method, "/api/tags", s.ListModelsHandler)
  		r.Handle(method, "/api/version", versionInfo)
  	}

  	return r
  }
  // Serve runs the Ollama HTTP server on ln. Before serving it:
  //   - installs a text slog handler on stderr (debug level when envconfig.Debug
  //     is set);
  //   - fixes blob names and, unless envconfig.NoPrune is set, prunes unused
  //     layers and manifest directories;
  //   - starts the scheduler and initializes the llm library;
  //   - logs detected GPUs.
  //
  // On SIGINT/SIGTERM the HTTP server is closed, the scheduler stopped, all
  // runners unloaded and GPU resources cleaned up, and then Serve returns nil.
  // Any other error is returned immediately. The deferred cancel also stops the
  // scheduler context on those early returns.
  func Serve(ln net.Listener) error {
  	level := slog.LevelInfo
  	if envconfig.Debug {
  		level = slog.LevelDebug
  	}

  	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
  		Level:     level,
  		AddSource: true,
  		ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
  			// Report only the source file name, not its full build path.
  			if attr.Key == slog.SourceKey {
  				source := attr.Value.Any().(*slog.Source)
  				source.File = filepath.Base(source.File)
  			}
  			return attr
  		},
  	})

  	slog.SetDefault(slog.New(handler))

  	// Log the configuration only after installing the handler, so this line
  	// uses the same format and level filtering as every other server log line.
  	slog.Info("server config", "env", envconfig.Values())

  	blobsDir, err := GetBlobsPath("")
  	if err != nil {
  		return err
  	}
  	if err := fixBlobs(blobsDir); err != nil {
  		return err
  	}

  	if !envconfig.NoPrune {
  		// clean up unused layers and manifests
  		if err := PruneLayers(); err != nil {
  			return err
  		}

  		manifestsPath, err := GetManifestPath()
  		if err != nil {
  			return err
  		}

  		if err := PruneDirectory(manifestsPath); err != nil {
  			return err
  		}
  	}

  	ctx, done := context.WithCancel(context.Background())
  	// Cancel on every return path. On the error returns below nothing else
  	// would cancel ctx or its child schedCtx.
  	defer done()
  	schedCtx, schedDone := context.WithCancel(ctx)
  	sched := InitScheduler(schedCtx)
  	s := &Server{addr: ln.Addr(), sched: sched}

  	http.Handle("/", s.GenerateRoutes())

  	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
  	srvr := &http.Server{
  		// Use http.DefaultServeMux so we get net/http/pprof for
  		// free.
  		//
  		// TODO(bmizerany): Decide if we want to make this
  		// configurable so it is not exposed by default, or allow
  		// users to bind it to a different port. This was a quick
  		// and easy way to get pprof, but it may not be the best
  		// way.
  		Handler: nil,
  	}

  	// listen for a ctrl+c and stop any loaded llm
  	signals := make(chan os.Signal, 1)
  	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
  	go func() {
  		<-signals
  		srvr.Close()
  		schedDone()
  		sched.unloadAllRunners()
  		gpu.Cleanup()
  		done()
  	}()

  	if err := llm.Init(); err != nil {
  		return fmt.Errorf("unable to initialize llm library %w", err)
  	}

  	s.sched.Run(schedCtx)

  	// At startup we retrieve GPU information so we can get log messages before loading a model
  	// This will log warnings to the log in case we have problems with detected GPUs
  	gpus := gpu.GetGPUInfo()
  	gpus.LogDetails()

  	err = srvr.Serve(ln)
  	// If server is closed from the signal handler, wait for the ctx to be done
  	// otherwise error out quickly
  	if !errors.Is(err, http.ErrServerClosed) {
  		return err
  	}
  	<-ctx.Done()
  	return nil
  }
  // waitForStream is the non-streaming counterpart of streamResponse. It
  // consumes progress updates from ch until the producer closes it, then writes
  // one JSON response:
  //   - the last api.ProgressResponse with Status "success" (200);
  //   - or the first error reported (500);
  //   - or a 500 if the stream ended without a success.
  //
  // It waits for ch to close rather than returning on the first "success",
  // because a producer may report an intermediate success (e.g. a pull nested
  // inside another operation) and continue working. If it stops reading early,
  // the rest of ch is drained in the background so the producer is never left
  // blocked on an unbuffered send.
  func waitForStream(c *gin.Context, ch chan interface{}) {
  	defer func() {
  		go func() {
  			for range ch {
  			}
  		}()
  	}()

  	c.Header("Content-Type", "application/json")

  	var final api.ProgressResponse
  	succeeded := false
  	for resp := range ch {
  		switch r := resp.(type) {
  		case api.ProgressResponse:
  			if r.Status == "success" {
  				final, succeeded = r, true
  			}
  		case gin.H:
  			errorMsg, ok := r["error"].(string)
  			if !ok {
  				errorMsg = "unexpected error format in progress response"
  			}
  			c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
  			return
  		default:
  			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
  			return
  		}
  	}

  	if !succeeded {
  		c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
  		return
  	}
  	c.JSON(http.StatusOK, final)
  }
  944. func streamResponse(c *gin.Context, ch chan any) {
  945. c.Header("Content-Type", "application/x-ndjson")
  946. c.Stream(func(w io.Writer) bool {
  947. val, ok := <-ch
  948. if !ok {
  949. return false
  950. }
  951. bts, err := json.Marshal(val)
  952. if err != nil {
  953. slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
  954. return false
  955. }
  956. // Delineate chunks with new-line delimiter
  957. bts = append(bts, '\n')
  958. if _, err := w.Write(bts); err != nil {
  959. slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
  960. return false
  961. }
  962. return true
  963. })
  964. }
  965. func (s *Server) ProcessHandler(c *gin.Context) {
  966. models := []api.ProcessModelResponse{}
  967. for _, v := range s.sched.loaded {
  968. model := v.model
  969. modelDetails := api.ModelDetails{
  970. Format: model.Config.ModelFormat,
  971. Family: model.Config.ModelFamily,
  972. Families: model.Config.ModelFamilies,
  973. ParameterSize: model.Config.ModelType,
  974. QuantizationLevel: model.Config.FileType,
  975. }
  976. mr := api.ProcessModelResponse{
  977. Model: model.ShortName,
  978. Name: model.ShortName,
  979. Size: int64(v.estimatedTotal),
  980. SizeVRAM: int64(v.estimatedVRAM),
  981. Digest: model.Digest,
  982. Details: modelDetails,
  983. ExpiresAt: v.expiresAt,
  984. }
  985. // The scheduler waits to set expiresAt, so if a model is loading it's
  986. // possible that it will be set to the unix epoch. For those cases, just
  987. // calculate the time w/ the sessionDuration instead.
  988. var epoch time.Time
  989. if v.expiresAt == epoch {
  990. mr.ExpiresAt = time.Now().Add(v.sessionDuration)
  991. }
  992. models = append(models, mr)
  993. }
  994. slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
  995. // longest duration remaining listed first
  996. return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
  997. })
  998. c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
  999. }
  1000. func (s *Server) ChatHandler(c *gin.Context) {
  1001. var req api.ChatRequest
  1002. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  1003. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  1004. return
  1005. } else if err != nil {
  1006. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1007. return
  1008. }
  1009. caps := []Capability{CapabilityCompletion}
  1010. r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
  1011. if errors.Is(err, errCapabilityCompletion) {
  1012. c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
  1013. return
  1014. } else if err != nil {
  1015. handleScheduleError(c, req.Model, err)
  1016. return
  1017. }
  1018. if len(req.Messages) == 0 {
  1019. c.JSON(http.StatusOK, api.ChatResponse{
  1020. Model: req.Model,
  1021. CreatedAt: time.Now().UTC(),
  1022. Message: api.Message{Role: "assistant"},
  1023. Done: true,
  1024. DoneReason: "load",
  1025. })
  1026. return
  1027. }
  1028. prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
  1029. if err != nil {
  1030. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1031. return
  1032. }
  1033. slog.Debug("chat request", "images", len(images), "prompt", prompt)
  1034. ch := make(chan any)
  1035. go func() {
  1036. defer close(ch)
  1037. if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  1038. Prompt: prompt,
  1039. Images: images,
  1040. Format: req.Format,
  1041. Options: opts,
  1042. }, func(r llm.CompletionResponse) {
  1043. ch <- api.ChatResponse{
  1044. Model: req.Model,
  1045. CreatedAt: time.Now().UTC(),
  1046. Message: api.Message{Role: "assistant", Content: r.Content},
  1047. Done: r.Done,
  1048. DoneReason: r.DoneReason,
  1049. Metrics: api.Metrics{
  1050. PromptEvalCount: r.PromptEvalCount,
  1051. PromptEvalDuration: r.PromptEvalDuration,
  1052. EvalCount: r.EvalCount,
  1053. EvalDuration: r.EvalDuration,
  1054. },
  1055. }
  1056. }); err != nil {
  1057. ch <- gin.H{"error": err.Error()}
  1058. }
  1059. }()
  1060. if req.Stream != nil && !*req.Stream {
  1061. var r api.ChatResponse
  1062. var sb strings.Builder
  1063. for rr := range ch {
  1064. switch t := rr.(type) {
  1065. case api.ChatResponse:
  1066. sb.WriteString(t.Message.Content)
  1067. r = t
  1068. case gin.H:
  1069. msg, ok := t["error"].(string)
  1070. if !ok {
  1071. msg = "unexpected error format in response"
  1072. }
  1073. c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  1074. return
  1075. default:
  1076. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  1077. return
  1078. }
  1079. }
  1080. r.Message.Content = sb.String()
  1081. c.JSON(http.StatusOK, r)
  1082. return
  1083. }
  1084. streamResponse(c, ch)
  1085. }
  1086. func handleScheduleError(c *gin.Context, name string, err error) {
  1087. switch {
  1088. case errors.Is(err, errRequired):
  1089. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1090. case errors.Is(err, context.Canceled):
  1091. c.JSON(499, gin.H{"error": "request canceled"})
  1092. case errors.Is(err, ErrMaxQueue):
  1093. c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
  1094. case errors.Is(err, os.ErrNotExist):
  1095. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
  1096. default:
  1097. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1098. }
  1099. }