routes.go 33 KB


  1. package server
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/base64"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "log"
  12. "log/slog"
  13. "net"
  14. "net/http"
  15. "net/netip"
  16. "os"
  17. "os/signal"
  18. "path/filepath"
  19. "slices"
  20. "strings"
  21. "syscall"
  22. "time"
  23. "github.com/gin-contrib/cors"
  24. "github.com/gin-gonic/gin"
  25. "golang.org/x/crypto/ssh"
  26. "github.com/ollama/ollama/api"
  27. "github.com/ollama/ollama/auth"
  28. "github.com/ollama/ollama/envconfig"
  29. "github.com/ollama/ollama/gpu"
  30. "github.com/ollama/ollama/llm"
  31. "github.com/ollama/ollama/openai"
  32. "github.com/ollama/ollama/parser"
  33. "github.com/ollama/ollama/template"
  34. "github.com/ollama/ollama/types/errtypes"
  35. "github.com/ollama/ollama/types/model"
  36. "github.com/ollama/ollama/version"
  37. )
  38. var mode string = gin.DebugMode
  39. type Server struct {
  40. addr net.Addr
  41. sched *Scheduler
  42. }
  43. func init() {
  44. switch mode {
  45. case gin.DebugMode:
  46. case gin.ReleaseMode:
  47. case gin.TestMode:
  48. default:
  49. mode = gin.DebugMode
  50. }
  51. gin.SetMode(mode)
  52. }
  var (
  	// errRequired is wrapped by callers to report a missing required field,
  	// e.g. fmt.Errorf("model %w", errRequired) reads "model is required".
  	errRequired = errors.New("is required")
  )
  54. func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
  55. opts := api.DefaultOptions()
  56. if err := opts.FromMap(model.Options); err != nil {
  57. return api.Options{}, err
  58. }
  59. if err := opts.FromMap(requestOpts); err != nil {
  60. return api.Options{}, err
  61. }
  62. return opts, nil
  63. }
  64. // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
  65. // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
  66. func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
  67. if name == "" {
  68. return nil, nil, nil, fmt.Errorf("model %w", errRequired)
  69. }
  70. model, err := GetModel(name)
  71. if err != nil {
  72. return nil, nil, nil, err
  73. }
  74. if err := model.CheckCapabilities(caps...); err != nil {
  75. return nil, nil, nil, fmt.Errorf("%s %w", name, err)
  76. }
  77. opts, err := modelOptions(model, requestOpts)
  78. if err != nil {
  79. return nil, nil, nil, err
  80. }
  81. runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
  82. var runner *runnerRef
  83. select {
  84. case runner = <-runnerCh:
  85. case err = <-errCh:
  86. return nil, nil, nil, err
  87. }
  88. return runner.llama, model, &opts, nil
  89. }
  90. func (s *Server) GenerateHandler(c *gin.Context) {
  91. var req api.GenerateRequest
  92. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  93. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  94. return
  95. } else if err != nil {
  96. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  97. return
  98. }
  99. if req.Format != "" && req.Format != "json" {
  100. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
  101. return
  102. } else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
  103. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
  104. return
  105. }
  106. caps := []Capability{CapabilityCompletion}
  107. r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
  108. if errors.Is(err, errCapabilityCompletion) {
  109. c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
  110. return
  111. } else if err != nil {
  112. handleScheduleError(c, req.Model, err)
  113. return
  114. }
  115. if req.Prompt == "" {
  116. c.JSON(http.StatusOK, api.GenerateResponse{
  117. Model: req.Model,
  118. CreatedAt: time.Now().UTC(),
  119. Done: true,
  120. DoneReason: "load",
  121. })
  122. return
  123. }
  124. images := make([]llm.ImageData, len(req.Images))
  125. for i := range req.Images {
  126. images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
  127. }
  128. prompt := req.Prompt
  129. if !req.Raw {
  130. var msgs []api.Message
  131. if req.System != "" {
  132. msgs = append(msgs, api.Message{Role: "system", Content: req.System})
  133. } else if m.System != "" {
  134. msgs = append(msgs, api.Message{Role: "system", Content: m.System})
  135. }
  136. for _, i := range images {
  137. msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
  138. }
  139. msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
  140. tmpl := m.Template
  141. if req.Template != "" {
  142. tmpl, err = template.Parse(req.Template)
  143. if err != nil {
  144. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  145. return
  146. }
  147. }
  148. var b bytes.Buffer
  149. if req.Context != nil {
  150. s, err := r.Detokenize(c.Request.Context(), req.Context)
  151. if err != nil {
  152. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  153. return
  154. }
  155. b.WriteString(s)
  156. }
  157. if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
  158. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  159. return
  160. }
  161. prompt = b.String()
  162. }
  163. slog.Debug("generate request", "prompt", prompt, "images", images)
  164. ch := make(chan any)
  165. go func() {
  166. defer close(ch)
  167. if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  168. Prompt: prompt,
  169. Images: images,
  170. Format: req.Format,
  171. Options: opts,
  172. }, func(r llm.CompletionResponse) {
  173. ch <- api.GenerateResponse{
  174. Model: req.Model,
  175. CreatedAt: time.Now().UTC(),
  176. Response: r.Content,
  177. Done: r.Done,
  178. DoneReason: r.DoneReason,
  179. Metrics: api.Metrics{
  180. PromptEvalCount: r.PromptEvalCount,
  181. PromptEvalDuration: r.PromptEvalDuration,
  182. EvalCount: r.EvalCount,
  183. EvalDuration: r.EvalDuration,
  184. },
  185. }
  186. }); err != nil {
  187. ch <- gin.H{"error": err.Error()}
  188. }
  189. }()
  190. if req.Stream != nil && !*req.Stream {
  191. var r api.GenerateResponse
  192. var sb strings.Builder
  193. for rr := range ch {
  194. switch t := rr.(type) {
  195. case api.GenerateResponse:
  196. sb.WriteString(t.Response)
  197. r = t
  198. case gin.H:
  199. msg, ok := t["error"].(string)
  200. if !ok {
  201. msg = "unexpected error format in response"
  202. }
  203. c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  204. return
  205. default:
  206. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  207. return
  208. }
  209. }
  210. r.Response = sb.String()
  211. c.JSON(http.StatusOK, r)
  212. return
  213. }
  214. streamResponse(c, ch)
  215. }
  216. func (s *Server) EmbeddingsHandler(c *gin.Context) {
  217. var req api.EmbeddingRequest
  218. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  219. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  220. return
  221. } else if err != nil {
  222. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  223. return
  224. }
  225. r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
  226. if err != nil {
  227. handleScheduleError(c, req.Model, err)
  228. return
  229. }
  230. // an empty request loads the model
  231. if req.Prompt == "" {
  232. c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
  233. return
  234. }
  235. embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
  236. if err != nil {
  237. slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
  238. c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
  239. return
  240. }
  241. c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
  242. }
  243. func (s *Server) PullModelHandler(c *gin.Context) {
  244. var req api.PullRequest
  245. err := c.ShouldBindJSON(&req)
  246. switch {
  247. case errors.Is(err, io.EOF):
  248. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  249. return
  250. case err != nil:
  251. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  252. return
  253. }
  254. name := model.ParseName(cmp.Or(req.Model, req.Name))
  255. if !name.IsValid() {
  256. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
  257. return
  258. }
  259. if err := checkNameExists(name); err != nil {
  260. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  261. return
  262. }
  263. ch := make(chan any)
  264. go func() {
  265. defer close(ch)
  266. fn := func(r api.ProgressResponse) {
  267. ch <- r
  268. }
  269. regOpts := &registryOptions{
  270. Insecure: req.Insecure,
  271. }
  272. ctx, cancel := context.WithCancel(c.Request.Context())
  273. defer cancel()
  274. if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
  275. ch <- gin.H{"error": err.Error()}
  276. }
  277. }()
  278. if req.Stream != nil && !*req.Stream {
  279. waitForStream(c, ch)
  280. return
  281. }
  282. streamResponse(c, ch)
  283. }
  284. func (s *Server) PushModelHandler(c *gin.Context) {
  285. var req api.PushRequest
  286. err := c.ShouldBindJSON(&req)
  287. switch {
  288. case errors.Is(err, io.EOF):
  289. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  290. return
  291. case err != nil:
  292. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  293. return
  294. }
  295. var model string
  296. if req.Model != "" {
  297. model = req.Model
  298. } else if req.Name != "" {
  299. model = req.Name
  300. } else {
  301. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  302. return
  303. }
  304. ch := make(chan any)
  305. go func() {
  306. defer close(ch)
  307. fn := func(r api.ProgressResponse) {
  308. ch <- r
  309. }
  310. regOpts := &registryOptions{
  311. Insecure: req.Insecure,
  312. }
  313. ctx, cancel := context.WithCancel(c.Request.Context())
  314. defer cancel()
  315. if err := PushModel(ctx, model, regOpts, fn); err != nil {
  316. ch <- gin.H{"error": err.Error()}
  317. }
  318. }()
  319. if req.Stream != nil && !*req.Stream {
  320. waitForStream(c, ch)
  321. return
  322. }
  323. streamResponse(c, ch)
  324. }
  325. func checkNameExists(name model.Name) error {
  326. names, err := Manifests()
  327. if err != nil {
  328. return err
  329. }
  330. for n := range names {
  331. if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
  332. return fmt.Errorf("a model with that name already exists")
  333. }
  334. }
  335. return nil
  336. }
  337. func (s *Server) CreateModelHandler(c *gin.Context) {
  338. var r api.CreateRequest
  339. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  340. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  341. return
  342. } else if err != nil {
  343. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  344. return
  345. }
  346. name := model.ParseName(cmp.Or(r.Model, r.Name))
  347. if !name.IsValid() {
  348. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
  349. return
  350. }
  351. if err := checkNameExists(name); err != nil {
  352. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  353. return
  354. }
  355. if r.Path == "" && r.Modelfile == "" {
  356. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
  357. return
  358. }
  359. var sr io.Reader = strings.NewReader(r.Modelfile)
  360. if r.Path != "" && r.Modelfile == "" {
  361. f, err := os.Open(r.Path)
  362. if err != nil {
  363. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
  364. return
  365. }
  366. defer f.Close()
  367. sr = f
  368. }
  369. f, err := parser.ParseFile(sr)
  370. if err != nil {
  371. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  372. return
  373. }
  374. ch := make(chan any)
  375. go func() {
  376. defer close(ch)
  377. fn := func(resp api.ProgressResponse) {
  378. ch <- resp
  379. }
  380. ctx, cancel := context.WithCancel(c.Request.Context())
  381. defer cancel()
  382. quantization := cmp.Or(r.Quantize, r.Quantization)
  383. if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
  384. ch <- gin.H{"error": err.Error()}
  385. }
  386. }()
  387. if r.Stream != nil && !*r.Stream {
  388. waitForStream(c, ch)
  389. return
  390. }
  391. streamResponse(c, ch)
  392. }
  393. func (s *Server) DeleteModelHandler(c *gin.Context) {
  394. var r api.DeleteRequest
  395. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  396. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  397. return
  398. } else if err != nil {
  399. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  400. return
  401. }
  402. n := model.ParseName(cmp.Or(r.Model, r.Name))
  403. if !n.IsValid() {
  404. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
  405. return
  406. }
  407. m, err := ParseNamedManifest(n)
  408. if err != nil {
  409. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  410. return
  411. }
  412. if err := m.Remove(); err != nil {
  413. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  414. return
  415. }
  416. if err := m.RemoveLayers(); err != nil {
  417. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  418. return
  419. }
  420. }
  421. func (s *Server) ShowModelHandler(c *gin.Context) {
  422. var req api.ShowRequest
  423. err := c.ShouldBindJSON(&req)
  424. switch {
  425. case errors.Is(err, io.EOF):
  426. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  427. return
  428. case err != nil:
  429. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  430. return
  431. }
  432. if req.Model != "" {
  433. // noop
  434. } else if req.Name != "" {
  435. req.Model = req.Name
  436. } else {
  437. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  438. return
  439. }
  440. resp, err := GetModelInfo(req)
  441. if err != nil {
  442. switch {
  443. case os.IsNotExist(err):
  444. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  445. case err.Error() == "invalid model name":
  446. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  447. default:
  448. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  449. }
  450. return
  451. }
  452. c.JSON(http.StatusOK, resp)
  453. }
  454. func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
  455. m, err := GetModel(req.Model)
  456. if err != nil {
  457. return nil, err
  458. }
  459. modelDetails := api.ModelDetails{
  460. ParentModel: m.ParentModel,
  461. Format: m.Config.ModelFormat,
  462. Family: m.Config.ModelFamily,
  463. Families: m.Config.ModelFamilies,
  464. ParameterSize: m.Config.ModelType,
  465. QuantizationLevel: m.Config.FileType,
  466. }
  467. if req.System != "" {
  468. m.System = req.System
  469. }
  470. if req.Template != "" {
  471. m.Template, err = template.Parse(req.Template)
  472. if err != nil {
  473. return nil, err
  474. }
  475. }
  476. msgs := make([]api.Message, len(m.Messages))
  477. for i, msg := range m.Messages {
  478. msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
  479. }
  480. n := model.ParseName(req.Model)
  481. if !n.IsValid() {
  482. return nil, fmt.Errorf("invalid model name")
  483. }
  484. manifest, err := ParseNamedManifest(n)
  485. if err != nil {
  486. return nil, err
  487. }
  488. resp := &api.ShowResponse{
  489. License: strings.Join(m.License, "\n"),
  490. System: m.System,
  491. Template: m.Template.String(),
  492. Details: modelDetails,
  493. Messages: msgs,
  494. ModifiedAt: manifest.fi.ModTime(),
  495. }
  496. var params []string
  497. cs := 30
  498. for k, v := range m.Options {
  499. switch val := v.(type) {
  500. case []interface{}:
  501. for _, nv := range val {
  502. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
  503. }
  504. default:
  505. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
  506. }
  507. }
  508. resp.Parameters = strings.Join(params, "\n")
  509. for k, v := range req.Options {
  510. if _, ok := req.Options[k]; ok {
  511. m.Options[k] = v
  512. }
  513. }
  514. var sb strings.Builder
  515. fmt.Fprintln(&sb, "# Modelfile generated by \"ollama show\"")
  516. fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
  517. fmt.Fprintf(&sb, "# FROM %s\n\n", m.ShortName)
  518. fmt.Fprint(&sb, m.String())
  519. resp.Modelfile = sb.String()
  520. kvData, err := getKVData(m.ModelPath, req.Verbose)
  521. if err != nil {
  522. return nil, err
  523. }
  524. delete(kvData, "general.name")
  525. delete(kvData, "tokenizer.chat_template")
  526. resp.ModelInfo = kvData
  527. if len(m.ProjectorPaths) > 0 {
  528. projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
  529. if err != nil {
  530. return nil, err
  531. }
  532. resp.ProjectorInfo = projectorData
  533. }
  534. return resp, nil
  535. }
  536. func getKVData(digest string, verbose bool) (llm.KV, error) {
  537. maxArraySize := 0
  538. if verbose {
  539. maxArraySize = -1
  540. }
  541. kvData, err := llm.LoadModel(digest, maxArraySize)
  542. if err != nil {
  543. return nil, err
  544. }
  545. kv := kvData.KV()
  546. if !verbose {
  547. for k := range kv {
  548. if t, ok := kv[k].([]any); len(t) > 5 && ok {
  549. kv[k] = []any{}
  550. }
  551. }
  552. }
  553. return kv, nil
  554. }
  555. func (s *Server) ListModelsHandler(c *gin.Context) {
  556. ms, err := Manifests()
  557. if err != nil {
  558. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  559. return
  560. }
  561. models := []api.ListModelResponse{}
  562. for n, m := range ms {
  563. f, err := m.Config.Open()
  564. if err != nil {
  565. slog.Warn("bad manifest filepath", "name", n, "error", err)
  566. continue
  567. }
  568. defer f.Close()
  569. var cf ConfigV2
  570. if err := json.NewDecoder(f).Decode(&cf); err != nil {
  571. slog.Warn("bad manifest config", "name", n, "error", err)
  572. continue
  573. }
  574. // tag should never be masked
  575. models = append(models, api.ListModelResponse{
  576. Model: n.DisplayShortest(),
  577. Name: n.DisplayShortest(),
  578. Size: m.Size(),
  579. Digest: m.digest,
  580. ModifiedAt: m.fi.ModTime(),
  581. Details: api.ModelDetails{
  582. Format: cf.ModelFormat,
  583. Family: cf.ModelFamily,
  584. Families: cf.ModelFamilies,
  585. ParameterSize: cf.ModelType,
  586. QuantizationLevel: cf.FileType,
  587. },
  588. })
  589. }
  590. slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
  591. // most recently modified first
  592. return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
  593. })
  594. c.JSON(http.StatusOK, api.ListResponse{Models: models})
  595. }
  596. func (s *Server) CopyModelHandler(c *gin.Context) {
  597. var r api.CopyRequest
  598. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  599. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  600. return
  601. } else if err != nil {
  602. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  603. return
  604. }
  605. src := model.ParseName(r.Source)
  606. if !src.IsValid() {
  607. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
  608. return
  609. }
  610. dst := model.ParseName(r.Destination)
  611. if !dst.IsValid() {
  612. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
  613. return
  614. }
  615. if err := checkNameExists(dst); err != nil {
  616. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  617. return
  618. }
  619. if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
  620. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
  621. } else if err != nil {
  622. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  623. }
  624. }
  625. func (s *Server) HeadBlobHandler(c *gin.Context) {
  626. path, err := GetBlobsPath(c.Param("digest"))
  627. if err != nil {
  628. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  629. return
  630. }
  631. if _, err := os.Stat(path); err != nil {
  632. c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
  633. return
  634. }
  635. c.Status(http.StatusOK)
  636. }
  637. func (s *Server) CreateBlobHandler(c *gin.Context) {
  638. if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
  639. p, err := GetBlobsPath(ib)
  640. if err != nil {
  641. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  642. return
  643. }
  644. if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
  645. slog.Info("evicting intermediate blob which no longer exists", "digest", ib)
  646. delete(intermediateBlobs, c.Param("digest"))
  647. } else if err != nil {
  648. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  649. return
  650. } else {
  651. c.Status(http.StatusOK)
  652. return
  653. }
  654. }
  655. fmt.Println("path2", c.Param("digest"))
  656. path, err := GetBlobsPath(c.Param("digest"))
  657. if err != nil {
  658. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  659. return
  660. }
  661. fmt.Println("path1", path)
  662. _, err = os.Stat(path)
  663. switch {
  664. case errors.Is(err, os.ErrNotExist):
  665. // noop
  666. case err != nil:
  667. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  668. return
  669. default:
  670. c.Status(http.StatusOK)
  671. return
  672. }
  673. fmt.Println("hello")
  674. fmt.Println(s.IsLocal(c))
  675. if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
  676. fmt.Println("entered redirect")
  677. c.Header("LocalLocation", path)
  678. c.Status(http.StatusTemporaryRedirect)
  679. return
  680. }
  681. layer, err := NewLayer(c.Request.Body, "")
  682. if err != nil {
  683. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  684. return
  685. }
  686. if layer.Digest != c.Param("digest") {
  687. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
  688. return
  689. }
  690. c.Status(http.StatusCreated)
  691. }
  692. func (s *Server) IsLocal(c *gin.Context) bool {
  693. fmt.Println("entered islocal")
  694. fmt.Println(c.GetHeader("Authorization"), " is authorization")
  695. if authz := c.GetHeader("Authorization"); authz != "" {
  696. parts := strings.Split(authz, ":")
  697. if len(parts) != 3 {
  698. fmt.Println("failed at lenParts")
  699. return false
  700. }
  701. clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
  702. if err != nil {
  703. fmt.Println("failed at parseAuthorizedKey")
  704. return false
  705. }
  706. // partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
  707. requestData, err := base64.StdEncoding.DecodeString(parts[1])
  708. if err != nil {
  709. fmt.Println("failed at decodeString")
  710. return false
  711. }
  712. partialRequestDataParts := strings.Split(string(requestData), ",")
  713. if len(partialRequestDataParts) != 3 {
  714. fmt.Println("failed at lenPartialRequestDataParts")
  715. return false
  716. }
  717. /* timestamp, err := strconv.ParseInt(partialRequestDataParts[2], 10, 0)
  718. if err != nil {
  719. return false
  720. }
  721. t := time.Unix(timestamp, 0)
  722. if time.Since(t) > 5*time.Minute || time.Until(t) > 5*time.Minute {
  723. // token is invalid if timestamp +/- 5 minutes from current time
  724. return false
  725. } */
  726. /* nonce := partialRequestDataParts[3]
  727. if nonceCache.has(nonce) {
  728. return false
  729. }
  730. nonceCache.add(nonce, 5*time.Minute) */
  731. signature, err := base64.StdEncoding.DecodeString(parts[2])
  732. if err != nil {
  733. fmt.Println("failed at decodeString stdEncoding")
  734. return false
  735. }
  736. if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
  737. fmt.Println("failed at verify")
  738. fmt.Println(err)
  739. return false
  740. }
  741. serverPublicKey, err := auth.GetPublicKey()
  742. if err != nil {
  743. fmt.Println("failed at getPublicKey")
  744. log.Fatal(err)
  745. }
  746. if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
  747. fmt.Println("true")
  748. return true
  749. }
  750. c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
  751. return false
  752. }
  753. return false
  754. }
  // isLocalIP reports whether ip is assigned to one of this machine's network
  // interfaces. Addresses are compared as netip.Addr values after dropping any
  // IPv6 zone and unmapping IPv4-mapped IPv6, so "::ffff:192.0.2.1" and
  // "fe80::1%eth0" match their plain forms. It returns false if the interface
  // list cannot be read.
  func isLocalIP(ip netip.Addr) bool {
  	target := ip.WithZone("").Unmap()

  	interfaces, err := net.Interfaces()
  	if err != nil {
  		return false
  	}

  	for _, iface := range interfaces {
  		addrs, err := iface.Addrs()
  		if err != nil {
  			continue
  		}
  		for _, a := range addrs {
  			parsed, _, err := net.ParseCIDR(a.String())
  			if err != nil {
  				continue
  			}
  			// net.ParseCIDR yields 16-byte IPv4 values; Unmap normalizes them.
  			if addr, ok := netip.AddrFromSlice(parsed); ok && addr.Unmap() == target {
  				return true
  			}
  		}
  	}
  	return false
  }
  // allowedHost reports whether a Host-header hostname is considered local:
  // the empty host, "localhost", this machine's hostname (exact match), or any
  // name under the .localhost, .local, or .internal suffixes.
  func allowedHost(host string) bool {
  	switch host {
  	case "", "localhost":
  		return true
  	}

  	if hostname, err := os.Hostname(); err == nil && hostname == host {
  		return true
  	}

  	return slices.ContainsFunc([]string{"localhost", "local", "internal"}, func(tld string) bool {
  		return strings.HasSuffix(host, "."+tld)
  	})
  }
  793. func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
  794. return func(c *gin.Context) {
  795. if addr == nil {
  796. c.Next()
  797. return
  798. }
  799. if addr, err := netip.ParseAddrPort(addr.String()); err == nil && !addr.Addr().IsLoopback() {
  800. c.Next()
  801. return
  802. }
  803. host, _, err := net.SplitHostPort(c.Request.Host)
  804. if err != nil {
  805. host = c.Request.Host
  806. }
  807. if addr, err := netip.ParseAddr(host); err == nil {
  808. if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
  809. c.Next()
  810. return
  811. }
  812. }
  813. if allowedHost(host) {
  814. if c.Request.Method == http.MethodOptions {
  815. c.AbortWithStatus(http.StatusNoContent)
  816. return
  817. }
  818. c.Next()
  819. return
  820. }
  821. c.AbortWithStatus(http.StatusForbidden)
  822. }
  823. }
  824. func (s *Server) GenerateRoutes() http.Handler {
  825. config := cors.DefaultConfig()
  826. config.AllowWildcard = true
  827. config.AllowBrowserExtensions = true
  828. config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
  829. openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"}
  830. for _, prop := range openAIProperties {
  831. config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
  832. }
  833. config.AllowOrigins = envconfig.AllowOrigins
  834. r := gin.Default()
  835. r.Use(
  836. cors.New(config),
  837. allowedHostsMiddleware(s.addr),
  838. )
  839. r.POST("/api/pull", s.PullModelHandler)
  840. r.POST("/api/generate", s.GenerateHandler)
  841. r.POST("/api/chat", s.ChatHandler)
  842. r.POST("/api/embeddings", s.EmbeddingsHandler)
  843. r.POST("/api/create", s.CreateModelHandler)
  844. r.POST("/api/push", s.PushModelHandler)
  845. r.POST("/api/copy", s.CopyModelHandler)
  846. r.DELETE("/api/delete", s.DeleteModelHandler)
  847. r.POST("/api/show", s.ShowModelHandler)
  848. r.POST("/api/blobs/:digest", s.CreateBlobHandler)
  849. r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
  850. r.GET("/api/ps", s.ProcessHandler)
  851. // Compatibility endpoints
  852. r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
  853. r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
  854. r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
  855. r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
  856. for _, method := range []string{http.MethodGet, http.MethodHead} {
  857. r.Handle(method, "/", func(c *gin.Context) {
  858. c.String(http.StatusOK, "Ollama is running")
  859. })
  860. r.Handle(method, "/api/tags", s.ListModelsHandler)
  861. r.Handle(method, "/api/version", func(c *gin.Context) {
  862. c.JSON(http.StatusOK, gin.H{"version": version.Version})
  863. })
  864. }
  865. return r
  866. }
  867. func Serve(ln net.Listener) error {
  868. level := slog.LevelInfo
  869. if envconfig.Debug {
  870. level = slog.LevelDebug
  871. }
  872. slog.Info("server config", "env", envconfig.Values())
  873. handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
  874. Level: level,
  875. AddSource: true,
  876. ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
  877. if attr.Key == slog.SourceKey {
  878. source := attr.Value.Any().(*slog.Source)
  879. source.File = filepath.Base(source.File)
  880. }
  881. return attr
  882. },
  883. })
  884. slog.SetDefault(slog.New(handler))
  885. blobsDir, err := GetBlobsPath("")
  886. if err != nil {
  887. return err
  888. }
  889. if err := fixBlobs(blobsDir); err != nil {
  890. return err
  891. }
  892. if !envconfig.NoPrune {
  893. // clean up unused layers and manifests
  894. if err := PruneLayers(); err != nil {
  895. return err
  896. }
  897. manifestsPath, err := GetManifestPath()
  898. if err != nil {
  899. return err
  900. }
  901. if err := PruneDirectory(manifestsPath); err != nil {
  902. return err
  903. }
  904. }
  905. ctx, done := context.WithCancel(context.Background())
  906. schedCtx, schedDone := context.WithCancel(ctx)
  907. sched := InitScheduler(schedCtx)
  908. s := &Server{addr: ln.Addr(), sched: sched}
  909. http.Handle("/", s.GenerateRoutes())
  910. slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
  911. srvr := &http.Server{
  912. // Use http.DefaultServeMux so we get net/http/pprof for
  913. // free.
  914. //
  915. // TODO(bmizerany): Decide if we want to make this
  916. // configurable so it is not exposed by default, or allow
  917. // users to bind it to a different port. This was a quick
  918. // and easy way to get pprof, but it may not be the best
  919. // way.
  920. Handler: nil,
  921. }
  922. // listen for a ctrl+c and stop any loaded llm
  923. signals := make(chan os.Signal, 1)
  924. signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
  925. go func() {
  926. <-signals
  927. srvr.Close()
  928. schedDone()
  929. sched.unloadAllRunners()
  930. gpu.Cleanup()
  931. done()
  932. }()
  933. if err := llm.Init(); err != nil {
  934. return fmt.Errorf("unable to initialize llm library %w", err)
  935. }
  936. s.sched.Run(schedCtx)
  937. // At startup we retrieve GPU information so we can get log messages before loading a model
  938. // This will log warnings to the log in case we have problems with detected GPUs
  939. gpus := gpu.GetGPUInfo()
  940. gpus.LogDetails()
  941. err = srvr.Serve(ln)
  942. // If server is closed from the signal handler, wait for the ctx to be done
  943. // otherwise error out quickly
  944. if !errors.Is(err, http.ErrServerClosed) {
  945. return err
  946. }
  947. <-ctx.Done()
  948. return nil
  949. }
  950. func waitForStream(c *gin.Context, ch chan interface{}) {
  951. c.Header("Content-Type", "application/json")
  952. for resp := range ch {
  953. switch r := resp.(type) {
  954. case api.ProgressResponse:
  955. if r.Status == "success" {
  956. c.JSON(http.StatusOK, r)
  957. return
  958. }
  959. case gin.H:
  960. if errorMsg, ok := r["error"].(string); ok {
  961. c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
  962. return
  963. } else {
  964. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
  965. return
  966. }
  967. default:
  968. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
  969. return
  970. }
  971. }
  972. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
  973. }
  974. func streamResponse(c *gin.Context, ch chan any) {
  975. c.Header("Content-Type", "application/x-ndjson")
  976. c.Stream(func(w io.Writer) bool {
  977. val, ok := <-ch
  978. if !ok {
  979. return false
  980. }
  981. bts, err := json.Marshal(val)
  982. if err != nil {
  983. slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
  984. return false
  985. }
  986. // Delineate chunks with new-line delimiter
  987. bts = append(bts, '\n')
  988. if _, err := w.Write(bts); err != nil {
  989. slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
  990. return false
  991. }
  992. return true
  993. })
  994. }
  995. func (s *Server) ProcessHandler(c *gin.Context) {
  996. models := []api.ProcessModelResponse{}
  997. for _, v := range s.sched.loaded {
  998. model := v.model
  999. modelDetails := api.ModelDetails{
  1000. Format: model.Config.ModelFormat,
  1001. Family: model.Config.ModelFamily,
  1002. Families: model.Config.ModelFamilies,
  1003. ParameterSize: model.Config.ModelType,
  1004. QuantizationLevel: model.Config.FileType,
  1005. }
  1006. mr := api.ProcessModelResponse{
  1007. Model: model.ShortName,
  1008. Name: model.ShortName,
  1009. Size: int64(v.estimatedTotal),
  1010. SizeVRAM: int64(v.estimatedVRAM),
  1011. Digest: model.Digest,
  1012. Details: modelDetails,
  1013. ExpiresAt: v.expiresAt,
  1014. }
  1015. // The scheduler waits to set expiresAt, so if a model is loading it's
  1016. // possible that it will be set to the unix epoch. For those cases, just
  1017. // calculate the time w/ the sessionDuration instead.
  1018. var epoch time.Time
  1019. if v.expiresAt == epoch {
  1020. mr.ExpiresAt = time.Now().Add(v.sessionDuration)
  1021. }
  1022. models = append(models, mr)
  1023. }
  1024. slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
  1025. // longest duration remaining listed first
  1026. return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
  1027. })
  1028. c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
  1029. }
  1030. func (s *Server) ChatHandler(c *gin.Context) {
  1031. var req api.ChatRequest
  1032. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  1033. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  1034. return
  1035. } else if err != nil {
  1036. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1037. return
  1038. }
  1039. caps := []Capability{CapabilityCompletion}
  1040. r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
  1041. if errors.Is(err, errCapabilityCompletion) {
  1042. c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
  1043. return
  1044. } else if err != nil {
  1045. handleScheduleError(c, req.Model, err)
  1046. return
  1047. }
  1048. if len(req.Messages) == 0 {
  1049. c.JSON(http.StatusOK, api.ChatResponse{
  1050. Model: req.Model,
  1051. CreatedAt: time.Now().UTC(),
  1052. Message: api.Message{Role: "assistant"},
  1053. Done: true,
  1054. DoneReason: "load",
  1055. })
  1056. return
  1057. }
  1058. prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
  1059. if err != nil {
  1060. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1061. return
  1062. }
  1063. slog.Debug("chat request", "images", len(images), "prompt", prompt)
  1064. ch := make(chan any)
  1065. go func() {
  1066. defer close(ch)
  1067. if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  1068. Prompt: prompt,
  1069. Images: images,
  1070. Format: req.Format,
  1071. Options: opts,
  1072. }, func(r llm.CompletionResponse) {
  1073. ch <- api.ChatResponse{
  1074. Model: req.Model,
  1075. CreatedAt: time.Now().UTC(),
  1076. Message: api.Message{Role: "assistant", Content: r.Content},
  1077. Done: r.Done,
  1078. DoneReason: r.DoneReason,
  1079. Metrics: api.Metrics{
  1080. PromptEvalCount: r.PromptEvalCount,
  1081. PromptEvalDuration: r.PromptEvalDuration,
  1082. EvalCount: r.EvalCount,
  1083. EvalDuration: r.EvalDuration,
  1084. },
  1085. }
  1086. }); err != nil {
  1087. ch <- gin.H{"error": err.Error()}
  1088. }
  1089. }()
  1090. if req.Stream != nil && !*req.Stream {
  1091. var r api.ChatResponse
  1092. var sb strings.Builder
  1093. for rr := range ch {
  1094. switch t := rr.(type) {
  1095. case api.ChatResponse:
  1096. sb.WriteString(t.Message.Content)
  1097. r = t
  1098. case gin.H:
  1099. msg, ok := t["error"].(string)
  1100. if !ok {
  1101. msg = "unexpected error format in response"
  1102. }
  1103. c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  1104. return
  1105. default:
  1106. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  1107. return
  1108. }
  1109. }
  1110. r.Message.Content = sb.String()
  1111. c.JSON(http.StatusOK, r)
  1112. return
  1113. }
  1114. streamResponse(c, ch)
  1115. }
  1116. func handleScheduleError(c *gin.Context, name string, err error) {
  1117. switch {
  1118. case errors.Is(err, errRequired):
  1119. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1120. case errors.Is(err, context.Canceled):
  1121. c.JSON(499, gin.H{"error": "request canceled"})
  1122. case errors.Is(err, ErrMaxQueue):
  1123. c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
  1124. case errors.Is(err, os.ErrNotExist):
  1125. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
  1126. default:
  1127. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1128. }
  1129. }