routes.go 33 KB


  1. package server
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/base64"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "log"
  12. "log/slog"
  13. "net"
  14. "net/http"
  15. "net/netip"
  16. "os"
  17. "os/signal"
  18. "path/filepath"
  19. "slices"
  20. "strings"
  21. "syscall"
  22. "time"
  23. "github.com/gin-contrib/cors"
  24. "github.com/gin-gonic/gin"
  25. "golang.org/x/crypto/ssh"
  26. "github.com/ollama/ollama/api"
  27. "github.com/ollama/ollama/auth"
  28. "github.com/ollama/ollama/envconfig"
  29. "github.com/ollama/ollama/gpu"
  30. "github.com/ollama/ollama/llm"
  31. "github.com/ollama/ollama/openai"
  32. "github.com/ollama/ollama/parser"
  33. "github.com/ollama/ollama/template"
  34. "github.com/ollama/ollama/types/errtypes"
  35. "github.com/ollama/ollama/types/model"
  36. "github.com/ollama/ollama/version"
  37. )
  38. var mode string = gin.DebugMode
  39. type Server struct {
  40. addr net.Addr
  41. sched *Scheduler
  42. }
  43. func init() {
  44. switch mode {
  45. case gin.DebugMode:
  46. case gin.ReleaseMode:
  47. case gin.TestMode:
  48. default:
  49. mode = gin.DebugMode
  50. }
  51. gin.SetMode(mode)
  52. }
  53. var errRequired = errors.New("is required")
  54. func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
  55. opts := api.DefaultOptions()
  56. if err := opts.FromMap(model.Options); err != nil {
  57. return api.Options{}, err
  58. }
  59. if err := opts.FromMap(requestOpts); err != nil {
  60. return api.Options{}, err
  61. }
  62. return opts, nil
  63. }
  64. // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
  65. // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
  66. func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
  67. if name == "" {
  68. return nil, nil, nil, fmt.Errorf("model %w", errRequired)
  69. }
  70. model, err := GetModel(name)
  71. if err != nil {
  72. return nil, nil, nil, err
  73. }
  74. if err := model.CheckCapabilities(caps...); err != nil {
  75. return nil, nil, nil, fmt.Errorf("%s %w", name, err)
  76. }
  77. opts, err := modelOptions(model, requestOpts)
  78. if err != nil {
  79. return nil, nil, nil, err
  80. }
  81. runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
  82. var runner *runnerRef
  83. select {
  84. case runner = <-runnerCh:
  85. case err = <-errCh:
  86. return nil, nil, nil, err
  87. }
  88. return runner.llama, model, &opts, nil
  89. }
  90. func (s *Server) GenerateHandler(c *gin.Context) {
  91. var req api.GenerateRequest
  92. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  93. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  94. return
  95. } else if err != nil {
  96. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  97. return
  98. }
  99. if req.Format != "" && req.Format != "json" {
  100. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
  101. return
  102. } else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
  103. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
  104. return
  105. }
  106. caps := []Capability{CapabilityCompletion}
  107. r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
  108. if errors.Is(err, errCapabilityCompletion) {
  109. c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
  110. return
  111. } else if err != nil {
  112. handleScheduleError(c, req.Model, err)
  113. return
  114. }
  115. if req.Prompt == "" {
  116. c.JSON(http.StatusOK, api.GenerateResponse{
  117. Model: req.Model,
  118. CreatedAt: time.Now().UTC(),
  119. Done: true,
  120. DoneReason: "load",
  121. })
  122. return
  123. }
  124. images := make([]llm.ImageData, len(req.Images))
  125. for i := range req.Images {
  126. images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
  127. }
  128. prompt := req.Prompt
  129. if !req.Raw {
  130. var msgs []api.Message
  131. if req.System != "" {
  132. msgs = append(msgs, api.Message{Role: "system", Content: req.System})
  133. } else if m.System != "" {
  134. msgs = append(msgs, api.Message{Role: "system", Content: m.System})
  135. }
  136. for _, i := range images {
  137. msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
  138. }
  139. msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
  140. tmpl := m.Template
  141. if req.Template != "" {
  142. tmpl, err = template.Parse(req.Template)
  143. if err != nil {
  144. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  145. return
  146. }
  147. }
  148. var b bytes.Buffer
  149. if req.Context != nil {
  150. s, err := r.Detokenize(c.Request.Context(), req.Context)
  151. if err != nil {
  152. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  153. return
  154. }
  155. b.WriteString(s)
  156. }
  157. if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
  158. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  159. return
  160. }
  161. prompt = b.String()
  162. }
  163. slog.Debug("generate request", "prompt", prompt, "images", images)
  164. ch := make(chan any)
  165. go func() {
  166. defer close(ch)
  167. if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  168. Prompt: prompt,
  169. Images: images,
  170. Format: req.Format,
  171. Options: opts,
  172. }, func(r llm.CompletionResponse) {
  173. ch <- api.GenerateResponse{
  174. Model: req.Model,
  175. CreatedAt: time.Now().UTC(),
  176. Response: r.Content,
  177. Done: r.Done,
  178. DoneReason: r.DoneReason,
  179. Metrics: api.Metrics{
  180. PromptEvalCount: r.PromptEvalCount,
  181. PromptEvalDuration: r.PromptEvalDuration,
  182. EvalCount: r.EvalCount,
  183. EvalDuration: r.EvalDuration,
  184. },
  185. }
  186. }); err != nil {
  187. ch <- gin.H{"error": err.Error()}
  188. }
  189. }()
  190. if req.Stream != nil && !*req.Stream {
  191. var r api.GenerateResponse
  192. var sb strings.Builder
  193. for rr := range ch {
  194. switch t := rr.(type) {
  195. case api.GenerateResponse:
  196. sb.WriteString(t.Response)
  197. r = t
  198. case gin.H:
  199. msg, ok := t["error"].(string)
  200. if !ok {
  201. msg = "unexpected error format in response"
  202. }
  203. c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  204. return
  205. default:
  206. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  207. return
  208. }
  209. }
  210. r.Response = sb.String()
  211. c.JSON(http.StatusOK, r)
  212. return
  213. }
  214. streamResponse(c, ch)
  215. }
  216. func (s *Server) EmbeddingsHandler(c *gin.Context) {
  217. var req api.EmbeddingRequest
  218. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  219. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  220. return
  221. } else if err != nil {
  222. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  223. return
  224. }
  225. r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
  226. if err != nil {
  227. handleScheduleError(c, req.Model, err)
  228. return
  229. }
  230. // an empty request loads the model
  231. if req.Prompt == "" {
  232. c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
  233. return
  234. }
  235. embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
  236. if err != nil {
  237. slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
  238. c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
  239. return
  240. }
  241. c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
  242. }
  243. func (s *Server) PullModelHandler(c *gin.Context) {
  244. var req api.PullRequest
  245. err := c.ShouldBindJSON(&req)
  246. switch {
  247. case errors.Is(err, io.EOF):
  248. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  249. return
  250. case err != nil:
  251. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  252. return
  253. }
  254. name := model.ParseName(cmp.Or(req.Model, req.Name))
  255. if !name.IsValid() {
  256. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
  257. return
  258. }
  259. if err := checkNameExists(name); err != nil {
  260. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  261. return
  262. }
  263. ch := make(chan any)
  264. go func() {
  265. defer close(ch)
  266. fn := func(r api.ProgressResponse) {
  267. ch <- r
  268. }
  269. regOpts := &registryOptions{
  270. Insecure: req.Insecure,
  271. }
  272. ctx, cancel := context.WithCancel(c.Request.Context())
  273. defer cancel()
  274. if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
  275. ch <- gin.H{"error": err.Error()}
  276. }
  277. }()
  278. if req.Stream != nil && !*req.Stream {
  279. waitForStream(c, ch)
  280. return
  281. }
  282. streamResponse(c, ch)
  283. }
  284. func (s *Server) PushModelHandler(c *gin.Context) {
  285. var req api.PushRequest
  286. err := c.ShouldBindJSON(&req)
  287. switch {
  288. case errors.Is(err, io.EOF):
  289. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  290. return
  291. case err != nil:
  292. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  293. return
  294. }
  295. var model string
  296. if req.Model != "" {
  297. model = req.Model
  298. } else if req.Name != "" {
  299. model = req.Name
  300. } else {
  301. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  302. return
  303. }
  304. ch := make(chan any)
  305. go func() {
  306. defer close(ch)
  307. fn := func(r api.ProgressResponse) {
  308. ch <- r
  309. }
  310. regOpts := &registryOptions{
  311. Insecure: req.Insecure,
  312. }
  313. ctx, cancel := context.WithCancel(c.Request.Context())
  314. defer cancel()
  315. if err := PushModel(ctx, model, regOpts, fn); err != nil {
  316. ch <- gin.H{"error": err.Error()}
  317. }
  318. }()
  319. if req.Stream != nil && !*req.Stream {
  320. waitForStream(c, ch)
  321. return
  322. }
  323. streamResponse(c, ch)
  324. }
  325. func checkNameExists(name model.Name) error {
  326. names, err := Manifests()
  327. if err != nil {
  328. return err
  329. }
  330. for n := range names {
  331. if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
  332. return fmt.Errorf("a model with that name already exists")
  333. }
  334. }
  335. return nil
  336. }
  337. func (s *Server) CreateModelHandler(c *gin.Context) {
  338. var r api.CreateRequest
  339. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  340. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  341. return
  342. } else if err != nil {
  343. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  344. return
  345. }
  346. name := model.ParseName(cmp.Or(r.Model, r.Name))
  347. if !name.IsValid() {
  348. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
  349. return
  350. }
  351. if err := checkNameExists(name); err != nil {
  352. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  353. return
  354. }
  355. if r.Path == "" && r.Modelfile == "" {
  356. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
  357. return
  358. }
  359. var sr io.Reader = strings.NewReader(r.Modelfile)
  360. if r.Path != "" && r.Modelfile == "" {
  361. f, err := os.Open(r.Path)
  362. if err != nil {
  363. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
  364. return
  365. }
  366. defer f.Close()
  367. sr = f
  368. }
  369. f, err := parser.ParseFile(sr)
  370. if err != nil {
  371. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  372. return
  373. }
  374. ch := make(chan any)
  375. go func() {
  376. defer close(ch)
  377. fn := func(resp api.ProgressResponse) {
  378. ch <- resp
  379. }
  380. ctx, cancel := context.WithCancel(c.Request.Context())
  381. defer cancel()
  382. quantization := cmp.Or(r.Quantize, r.Quantization)
  383. if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
  384. ch <- gin.H{"error": err.Error()}
  385. }
  386. }()
  387. if r.Stream != nil && !*r.Stream {
  388. waitForStream(c, ch)
  389. return
  390. }
  391. streamResponse(c, ch)
  392. }
  393. func (s *Server) DeleteModelHandler(c *gin.Context) {
  394. var r api.DeleteRequest
  395. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  396. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  397. return
  398. } else if err != nil {
  399. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  400. return
  401. }
  402. n := model.ParseName(cmp.Or(r.Model, r.Name))
  403. if !n.IsValid() {
  404. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
  405. return
  406. }
  407. m, err := ParseNamedManifest(n)
  408. if err != nil {
  409. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  410. return
  411. }
  412. if err := m.Remove(); err != nil {
  413. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  414. return
  415. }
  416. if err := m.RemoveLayers(); err != nil {
  417. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  418. return
  419. }
  420. }
  421. func (s *Server) ShowModelHandler(c *gin.Context) {
  422. var req api.ShowRequest
  423. err := c.ShouldBindJSON(&req)
  424. switch {
  425. case errors.Is(err, io.EOF):
  426. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  427. return
  428. case err != nil:
  429. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  430. return
  431. }
  432. if req.Model != "" {
  433. // noop
  434. } else if req.Name != "" {
  435. req.Model = req.Name
  436. } else {
  437. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  438. return
  439. }
  440. resp, err := GetModelInfo(req)
  441. if err != nil {
  442. switch {
  443. case os.IsNotExist(err):
  444. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  445. case err.Error() == "invalid model name":
  446. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  447. default:
  448. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  449. }
  450. return
  451. }
  452. c.JSON(http.StatusOK, resp)
  453. }
  454. func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
  455. m, err := GetModel(req.Model)
  456. if err != nil {
  457. return nil, err
  458. }
  459. modelDetails := api.ModelDetails{
  460. ParentModel: m.ParentModel,
  461. Format: m.Config.ModelFormat,
  462. Family: m.Config.ModelFamily,
  463. Families: m.Config.ModelFamilies,
  464. ParameterSize: m.Config.ModelType,
  465. QuantizationLevel: m.Config.FileType,
  466. }
  467. if req.System != "" {
  468. m.System = req.System
  469. }
  470. if req.Template != "" {
  471. m.Template, err = template.Parse(req.Template)
  472. if err != nil {
  473. return nil, err
  474. }
  475. }
  476. msgs := make([]api.Message, len(m.Messages))
  477. for i, msg := range m.Messages {
  478. msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
  479. }
  480. n := model.ParseName(req.Model)
  481. if !n.IsValid() {
  482. return nil, fmt.Errorf("invalid model name")
  483. }
  484. manifest, err := ParseNamedManifest(n)
  485. if err != nil {
  486. return nil, err
  487. }
  488. resp := &api.ShowResponse{
  489. License: strings.Join(m.License, "\n"),
  490. System: m.System,
  491. Template: m.Template.String(),
  492. Details: modelDetails,
  493. Messages: msgs,
  494. ModifiedAt: manifest.fi.ModTime(),
  495. }
  496. var params []string
  497. cs := 30
  498. for k, v := range m.Options {
  499. switch val := v.(type) {
  500. case []interface{}:
  501. for _, nv := range val {
  502. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
  503. }
  504. default:
  505. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
  506. }
  507. }
  508. resp.Parameters = strings.Join(params, "\n")
  509. for k, v := range req.Options {
  510. if _, ok := req.Options[k]; ok {
  511. m.Options[k] = v
  512. }
  513. }
  514. var sb strings.Builder
  515. fmt.Fprintln(&sb, "# Modelfile generated by \"ollama show\"")
  516. fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
  517. fmt.Fprintf(&sb, "# FROM %s\n\n", m.ShortName)
  518. fmt.Fprint(&sb, m.String())
  519. resp.Modelfile = sb.String()
  520. kvData, err := getKVData(m.ModelPath, req.Verbose)
  521. if err != nil {
  522. return nil, err
  523. }
  524. delete(kvData, "general.name")
  525. delete(kvData, "tokenizer.chat_template")
  526. resp.ModelInfo = kvData
  527. if len(m.ProjectorPaths) > 0 {
  528. projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
  529. if err != nil {
  530. return nil, err
  531. }
  532. resp.ProjectorInfo = projectorData
  533. }
  534. return resp, nil
  535. }
  536. func getKVData(digest string, verbose bool) (llm.KV, error) {
  537. maxArraySize := 0
  538. if verbose {
  539. maxArraySize = -1
  540. }
  541. kvData, err := llm.LoadModel(digest, maxArraySize)
  542. if err != nil {
  543. return nil, err
  544. }
  545. kv := kvData.KV()
  546. if !verbose {
  547. for k := range kv {
  548. if t, ok := kv[k].([]any); len(t) > 5 && ok {
  549. kv[k] = []any{}
  550. }
  551. }
  552. }
  553. return kv, nil
  554. }
  555. func (s *Server) ListModelsHandler(c *gin.Context) {
  556. ms, err := Manifests()
  557. if err != nil {
  558. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  559. return
  560. }
  561. models := []api.ListModelResponse{}
  562. for n, m := range ms {
  563. f, err := m.Config.Open()
  564. if err != nil {
  565. slog.Warn("bad manifest filepath", "name", n, "error", err)
  566. continue
  567. }
  568. defer f.Close()
  569. var cf ConfigV2
  570. if err := json.NewDecoder(f).Decode(&cf); err != nil {
  571. slog.Warn("bad manifest config", "name", n, "error", err)
  572. continue
  573. }
  574. // tag should never be masked
  575. models = append(models, api.ListModelResponse{
  576. Model: n.DisplayShortest(),
  577. Name: n.DisplayShortest(),
  578. Size: m.Size(),
  579. Digest: m.digest,
  580. ModifiedAt: m.fi.ModTime(),
  581. Details: api.ModelDetails{
  582. Format: cf.ModelFormat,
  583. Family: cf.ModelFamily,
  584. Families: cf.ModelFamilies,
  585. ParameterSize: cf.ModelType,
  586. QuantizationLevel: cf.FileType,
  587. },
  588. })
  589. }
  590. slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
  591. // most recently modified first
  592. return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
  593. })
  594. c.JSON(http.StatusOK, api.ListResponse{Models: models})
  595. }
  596. func (s *Server) CopyModelHandler(c *gin.Context) {
  597. var r api.CopyRequest
  598. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  599. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  600. return
  601. } else if err != nil {
  602. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  603. return
  604. }
  605. src := model.ParseName(r.Source)
  606. if !src.IsValid() {
  607. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
  608. return
  609. }
  610. dst := model.ParseName(r.Destination)
  611. if !dst.IsValid() {
  612. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
  613. return
  614. }
  615. if err := checkNameExists(dst); err != nil {
  616. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  617. return
  618. }
  619. if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
  620. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
  621. } else if err != nil {
  622. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  623. }
  624. }
  625. func (s *Server) HeadBlobHandler(c *gin.Context) {
  626. path, err := GetBlobsPath(c.Param("digest"))
  627. if err != nil {
  628. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  629. return
  630. }
  631. if _, err := os.Stat(path); err != nil {
  632. c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
  633. return
  634. }
  635. c.Status(http.StatusOK)
  636. }
  637. func (s *Server) CreateBlobHandler(c *gin.Context) {
  638. if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
  639. p, err := GetBlobsPath(ib)
  640. if err != nil {
  641. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  642. return
  643. }
  644. if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
  645. slog.Info("evicting intermediate blob which no longer exists", "digest", ib)
  646. delete(intermediateBlobs, c.Param("digest"))
  647. } else if err != nil {
  648. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  649. return
  650. } else {
  651. c.Status(http.StatusOK)
  652. return
  653. }
  654. }
  655. path, err := GetBlobsPath(c.Param("digest"))
  656. if err != nil {
  657. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  658. return
  659. }
  660. _, err = os.Stat(path)
  661. switch {
  662. case errors.Is(err, os.ErrNotExist):
  663. // noop
  664. case err != nil:
  665. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  666. return
  667. default:
  668. c.Status(http.StatusOK)
  669. return
  670. }
  671. fmt.Println("hello")
  672. fmt.Println(s.IsLocal(c))
  673. if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
  674. fmt.Println("entered redirect")
  675. c.Header("LocalLocation", path)
  676. c.Status(http.StatusTemporaryRedirect)
  677. return
  678. }
  679. layer, err := NewLayer(c.Request.Body, "")
  680. if err != nil {
  681. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  682. return
  683. }
  684. if layer.Digest != c.Param("digest") {
  685. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
  686. return
  687. }
  688. c.Status(http.StatusCreated)
  689. }
  690. func (s *Server) IsLocal(c *gin.Context) bool {
  691. fmt.Println("entered islocal")
  692. fmt.Println(c.GetHeader("Authorization"), " is authorization")
  693. if authz := c.GetHeader("Authorization"); authz != "" {
  694. parts := strings.Split(authz, ":")
  695. if len(parts) != 3 {
  696. fmt.Println("failed at lenParts")
  697. return false
  698. }
  699. clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
  700. if err != nil {
  701. fmt.Println("failed at parseAuthorizedKey")
  702. return false
  703. }
  704. // partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
  705. requestData, err := base64.StdEncoding.DecodeString(parts[1])
  706. if err != nil {
  707. fmt.Println("failed at decodeString")
  708. return false
  709. }
  710. partialRequestDataParts := strings.Split(string(requestData), ",")
  711. if len(partialRequestDataParts) != 3 {
  712. fmt.Println("failed at lenPartialRequestDataParts")
  713. return false
  714. }
  715. /* timestamp, err := strconv.ParseInt(partialRequestDataParts[2], 10, 0)
  716. if err != nil {
  717. return false
  718. }
  719. t := time.Unix(timestamp, 0)
  720. if time.Since(t) > 5*time.Minute || time.Until(t) > 5*time.Minute {
  721. // token is invalid if timestamp +/- 5 minutes from current time
  722. return false
  723. } */
  724. /* nonce := partialRequestDataParts[3]
  725. if nonceCache.has(nonce) {
  726. return false
  727. }
  728. nonceCache.add(nonce, 5*time.Minute) */
  729. signature, err := base64.StdEncoding.DecodeString(parts[2])
  730. if err != nil {
  731. fmt.Println("failed at decodeString stdEncoding")
  732. return false
  733. }
  734. if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
  735. fmt.Println("failed at verify")
  736. fmt.Println(err)
  737. return false
  738. }
  739. serverPublicKey, err := auth.GetPublicKey()
  740. if err != nil {
  741. fmt.Println("failed at getPublicKey")
  742. log.Fatal(err)
  743. }
  744. if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
  745. fmt.Println("true")
  746. return true
  747. }
  748. c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
  749. return false
  750. }
  751. return false
  752. }
  // isLocalIP reports whether ip is assigned to one of this machine's
  // network interfaces.
  //
  // Addresses are compared by their textual form. Interfaces whose addresses
  // cannot be listed and addresses that do not parse as CIDR are skipped; if
  // the interfaces cannot be enumerated at all the result is false.
  func isLocalIP(ip netip.Addr) bool {
  	ifaces, err := net.Interfaces()
  	if err != nil {
  		return false
  	}

  	want := ip.String()
  	for i := range ifaces {
  		addrs, err := ifaces[i].Addrs()
  		if err != nil {
  			continue
  		}

  		for _, a := range addrs {
  			candidate, _, err := net.ParseCIDR(a.String())
  			if err != nil {
  				continue
  			}

  			if candidate.String() == want {
  				return true
  			}
  		}
  	}

  	return false
  }
  // allowedHost reports whether host names this machine or a local-only
  // domain: the empty string, "localhost", the machine's own hostname, or
  // any name ending in ".localhost", ".local", or ".internal".
  func allowedHost(host string) bool {
  	switch host {
  	case "", "localhost":
  		return true
  	}

  	if self, err := os.Hostname(); err == nil && self == host {
  		return true
  	}

  	// check if the host is a local TLD
  	for _, tld := range [...]string{"localhost", "local", "internal"} {
  		if strings.HasSuffix(host, "."+tld) {
  			return true
  		}
  	}

  	return false
  }
  791. func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
  792. return func(c *gin.Context) {
  793. if addr == nil {
  794. c.Next()
  795. return
  796. }
  797. if addr, err := netip.ParseAddrPort(addr.String()); err == nil && !addr.Addr().IsLoopback() {
  798. c.Next()
  799. return
  800. }
  801. host, _, err := net.SplitHostPort(c.Request.Host)
  802. if err != nil {
  803. host = c.Request.Host
  804. }
  805. if addr, err := netip.ParseAddr(host); err == nil {
  806. if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
  807. c.Next()
  808. return
  809. }
  810. }
  811. if allowedHost(host) {
  812. if c.Request.Method == http.MethodOptions {
  813. c.AbortWithStatus(http.StatusNoContent)
  814. return
  815. }
  816. c.Next()
  817. return
  818. }
  819. c.AbortWithStatus(http.StatusForbidden)
  820. }
  821. }
  822. func (s *Server) GenerateRoutes() http.Handler {
  823. config := cors.DefaultConfig()
  824. config.AllowWildcard = true
  825. config.AllowBrowserExtensions = true
  826. config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
  827. openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"}
  828. for _, prop := range openAIProperties {
  829. config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
  830. }
  831. config.AllowOrigins = envconfig.AllowOrigins
  832. r := gin.Default()
  833. r.Use(
  834. cors.New(config),
  835. allowedHostsMiddleware(s.addr),
  836. )
  837. r.POST("/api/pull", s.PullModelHandler)
  838. r.POST("/api/generate", s.GenerateHandler)
  839. r.POST("/api/chat", s.ChatHandler)
  840. r.POST("/api/embeddings", s.EmbeddingsHandler)
  841. r.POST("/api/create", s.CreateModelHandler)
  842. r.POST("/api/push", s.PushModelHandler)
  843. r.POST("/api/copy", s.CopyModelHandler)
  844. r.DELETE("/api/delete", s.DeleteModelHandler)
  845. r.POST("/api/show", s.ShowModelHandler)
  846. r.POST("/api/blobs/:digest", s.CreateBlobHandler)
  847. r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
  848. r.GET("/api/ps", s.ProcessHandler)
  849. // Compatibility endpoints
  850. r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
  851. r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
  852. r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
  853. r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
  854. for _, method := range []string{http.MethodGet, http.MethodHead} {
  855. r.Handle(method, "/", func(c *gin.Context) {
  856. c.String(http.StatusOK, "Ollama is running")
  857. })
  858. r.Handle(method, "/api/tags", s.ListModelsHandler)
  859. r.Handle(method, "/api/version", func(c *gin.Context) {
  860. c.JSON(http.StatusOK, gin.H{"version": version.Version})
  861. })
  862. }
  863. return r
  864. }
  865. func Serve(ln net.Listener) error {
  866. level := slog.LevelInfo
  867. if envconfig.Debug {
  868. level = slog.LevelDebug
  869. }
  870. slog.Info("server config", "env", envconfig.Values())
  871. handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
  872. Level: level,
  873. AddSource: true,
  874. ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
  875. if attr.Key == slog.SourceKey {
  876. source := attr.Value.Any().(*slog.Source)
  877. source.File = filepath.Base(source.File)
  878. }
  879. return attr
  880. },
  881. })
  882. slog.SetDefault(slog.New(handler))
  883. blobsDir, err := GetBlobsPath("")
  884. if err != nil {
  885. return err
  886. }
  887. if err := fixBlobs(blobsDir); err != nil {
  888. return err
  889. }
  890. if !envconfig.NoPrune {
  891. // clean up unused layers and manifests
  892. if err := PruneLayers(); err != nil {
  893. return err
  894. }
  895. manifestsPath, err := GetManifestPath()
  896. if err != nil {
  897. return err
  898. }
  899. if err := PruneDirectory(manifestsPath); err != nil {
  900. return err
  901. }
  902. }
  903. ctx, done := context.WithCancel(context.Background())
  904. schedCtx, schedDone := context.WithCancel(ctx)
  905. sched := InitScheduler(schedCtx)
  906. s := &Server{addr: ln.Addr(), sched: sched}
  907. http.Handle("/", s.GenerateRoutes())
  908. slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
  909. srvr := &http.Server{
  910. // Use http.DefaultServeMux so we get net/http/pprof for
  911. // free.
  912. //
  913. // TODO(bmizerany): Decide if we want to make this
  914. // configurable so it is not exposed by default, or allow
  915. // users to bind it to a different port. This was a quick
  916. // and easy way to get pprof, but it may not be the best
  917. // way.
  918. Handler: nil,
  919. }
  920. // listen for a ctrl+c and stop any loaded llm
  921. signals := make(chan os.Signal, 1)
  922. signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
  923. go func() {
  924. <-signals
  925. srvr.Close()
  926. schedDone()
  927. sched.unloadAllRunners()
  928. gpu.Cleanup()
  929. done()
  930. }()
  931. if err := llm.Init(); err != nil {
  932. return fmt.Errorf("unable to initialize llm library %w", err)
  933. }
  934. s.sched.Run(schedCtx)
  935. // At startup we retrieve GPU information so we can get log messages before loading a model
  936. // This will log warnings to the log in case we have problems with detected GPUs
  937. gpus := gpu.GetGPUInfo()
  938. gpus.LogDetails()
  939. err = srvr.Serve(ln)
  940. // If server is closed from the signal handler, wait for the ctx to be done
  941. // otherwise error out quickly
  942. if !errors.Is(err, http.ErrServerClosed) {
  943. return err
  944. }
  945. <-ctx.Done()
  946. return nil
  947. }
  948. func waitForStream(c *gin.Context, ch chan interface{}) {
  949. c.Header("Content-Type", "application/json")
  950. for resp := range ch {
  951. switch r := resp.(type) {
  952. case api.ProgressResponse:
  953. if r.Status == "success" {
  954. c.JSON(http.StatusOK, r)
  955. return
  956. }
  957. case gin.H:
  958. if errorMsg, ok := r["error"].(string); ok {
  959. c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
  960. return
  961. } else {
  962. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
  963. return
  964. }
  965. default:
  966. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
  967. return
  968. }
  969. }
  970. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
  971. }
  972. func streamResponse(c *gin.Context, ch chan any) {
  973. c.Header("Content-Type", "application/x-ndjson")
  974. c.Stream(func(w io.Writer) bool {
  975. val, ok := <-ch
  976. if !ok {
  977. return false
  978. }
  979. bts, err := json.Marshal(val)
  980. if err != nil {
  981. slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
  982. return false
  983. }
  984. // Delineate chunks with new-line delimiter
  985. bts = append(bts, '\n')
  986. if _, err := w.Write(bts); err != nil {
  987. slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
  988. return false
  989. }
  990. return true
  991. })
  992. }
  993. func (s *Server) ProcessHandler(c *gin.Context) {
  994. models := []api.ProcessModelResponse{}
  995. for _, v := range s.sched.loaded {
  996. model := v.model
  997. modelDetails := api.ModelDetails{
  998. Format: model.Config.ModelFormat,
  999. Family: model.Config.ModelFamily,
  1000. Families: model.Config.ModelFamilies,
  1001. ParameterSize: model.Config.ModelType,
  1002. QuantizationLevel: model.Config.FileType,
  1003. }
  1004. mr := api.ProcessModelResponse{
  1005. Model: model.ShortName,
  1006. Name: model.ShortName,
  1007. Size: int64(v.estimatedTotal),
  1008. SizeVRAM: int64(v.estimatedVRAM),
  1009. Digest: model.Digest,
  1010. Details: modelDetails,
  1011. ExpiresAt: v.expiresAt,
  1012. }
  1013. // The scheduler waits to set expiresAt, so if a model is loading it's
  1014. // possible that it will be set to the unix epoch. For those cases, just
  1015. // calculate the time w/ the sessionDuration instead.
  1016. var epoch time.Time
  1017. if v.expiresAt == epoch {
  1018. mr.ExpiresAt = time.Now().Add(v.sessionDuration)
  1019. }
  1020. models = append(models, mr)
  1021. }
  1022. slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
  1023. // longest duration remaining listed first
  1024. return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
  1025. })
  1026. c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
  1027. }
  1028. func (s *Server) ChatHandler(c *gin.Context) {
  1029. var req api.ChatRequest
  1030. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  1031. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  1032. return
  1033. } else if err != nil {
  1034. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1035. return
  1036. }
  1037. caps := []Capability{CapabilityCompletion}
  1038. r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
  1039. if errors.Is(err, errCapabilityCompletion) {
  1040. c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
  1041. return
  1042. } else if err != nil {
  1043. handleScheduleError(c, req.Model, err)
  1044. return
  1045. }
  1046. if len(req.Messages) == 0 {
  1047. c.JSON(http.StatusOK, api.ChatResponse{
  1048. Model: req.Model,
  1049. CreatedAt: time.Now().UTC(),
  1050. Message: api.Message{Role: "assistant"},
  1051. Done: true,
  1052. DoneReason: "load",
  1053. })
  1054. return
  1055. }
  1056. prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
  1057. if err != nil {
  1058. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1059. return
  1060. }
  1061. slog.Debug("chat request", "images", len(images), "prompt", prompt)
  1062. ch := make(chan any)
  1063. go func() {
  1064. defer close(ch)
  1065. if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  1066. Prompt: prompt,
  1067. Images: images,
  1068. Format: req.Format,
  1069. Options: opts,
  1070. }, func(r llm.CompletionResponse) {
  1071. ch <- api.ChatResponse{
  1072. Model: req.Model,
  1073. CreatedAt: time.Now().UTC(),
  1074. Message: api.Message{Role: "assistant", Content: r.Content},
  1075. Done: r.Done,
  1076. DoneReason: r.DoneReason,
  1077. Metrics: api.Metrics{
  1078. PromptEvalCount: r.PromptEvalCount,
  1079. PromptEvalDuration: r.PromptEvalDuration,
  1080. EvalCount: r.EvalCount,
  1081. EvalDuration: r.EvalDuration,
  1082. },
  1083. }
  1084. }); err != nil {
  1085. ch <- gin.H{"error": err.Error()}
  1086. }
  1087. }()
  1088. if req.Stream != nil && !*req.Stream {
  1089. var r api.ChatResponse
  1090. var sb strings.Builder
  1091. for rr := range ch {
  1092. switch t := rr.(type) {
  1093. case api.ChatResponse:
  1094. sb.WriteString(t.Message.Content)
  1095. r = t
  1096. case gin.H:
  1097. msg, ok := t["error"].(string)
  1098. if !ok {
  1099. msg = "unexpected error format in response"
  1100. }
  1101. c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  1102. return
  1103. default:
  1104. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  1105. return
  1106. }
  1107. }
  1108. r.Message.Content = sb.String()
  1109. c.JSON(http.StatusOK, r)
  1110. return
  1111. }
  1112. streamResponse(c, ch)
  1113. }
  1114. func handleScheduleError(c *gin.Context, name string, err error) {
  1115. switch {
  1116. case errors.Is(err, errRequired):
  1117. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1118. case errors.Is(err, context.Canceled):
  1119. c.JSON(499, gin.H{"error": "request canceled"})
  1120. case errors.Is(err, ErrMaxQueue):
  1121. c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
  1122. case errors.Is(err, os.ErrNotExist):
  1123. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
  1124. default:
  1125. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1126. }
  1127. }