images.go 32 KB


  1. package server
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/sha256"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "html/template"
  11. "io"
  12. "log"
  13. "net/http"
  14. "os"
  15. "path/filepath"
  16. "reflect"
  17. "strconv"
  18. "strings"
  19. "github.com/jmorganca/ollama/api"
  20. "github.com/jmorganca/ollama/llm"
  21. "github.com/jmorganca/ollama/parser"
  22. "github.com/jmorganca/ollama/vector"
  23. )
  // MaxRetries is an exported, untyped constant with value 3. Its consumer is
  // not visible in this part of the file — presumably the retry handling in
  // the registry request helper; TODO confirm.
  24. const MaxRetries = 3
  25. type RegistryOptions struct {
  26. Insecure bool
  27. Username string
  28. Password string
  29. Token string
  30. }
  31. type Model struct {
  32. Name string `json:"name"`
  33. ModelPath string
  34. AdapterPaths []string
  35. Template string
  36. System string
  37. Digest string
  38. Options map[string]interface{}
  39. Embeddings []vector.Embedding
  40. }
  41. func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
  42. t := m.Template
  43. if request.Template != "" {
  44. t = request.Template
  45. }
  46. tmpl, err := template.New("").Parse(t)
  47. if err != nil {
  48. return "", err
  49. }
  50. var vars struct {
  51. First bool
  52. System string
  53. Prompt string
  54. Embed string
  55. // deprecated: versions <= 0.0.7 used this to omit the system prompt
  56. Context []int
  57. }
  58. vars.First = len(request.Context) == 0
  59. vars.System = m.System
  60. vars.Prompt = request.Prompt
  61. vars.Context = request.Context
  62. vars.Embed = embedding
  63. if request.System != "" {
  64. vars.System = request.System
  65. }
  66. var sb strings.Builder
  67. if err := tmpl.Execute(&sb, vars); err != nil {
  68. return "", err
  69. }
  70. return sb.String(), nil
  71. }
  72. type ManifestV2 struct {
  73. SchemaVersion int `json:"schemaVersion"`
  74. MediaType string `json:"mediaType"`
  75. Config Layer `json:"config"`
  76. Layers []*Layer `json:"layers"`
  77. }
  78. type Layer struct {
  79. MediaType string `json:"mediaType"`
  80. Digest string `json:"digest"`
  81. Size int `json:"size"`
  82. }
  83. type LayerReader struct {
  84. Layer
  85. io.Reader
  86. }
  87. type ConfigV2 struct {
  88. ModelFamily llm.ModelFamily `json:"model_family"`
  89. ModelType llm.ModelType `json:"model_type"`
  90. FileType llm.FileType `json:"file_type"`
  91. RootFS RootFS `json:"rootfs"`
  92. // required by spec
  93. Architecture string `json:"architecture"`
  94. OS string `json:"os"`
  95. }
  96. type RootFS struct {
  97. Type string `json:"type"`
  98. DiffIDs []string `json:"diff_ids"`
  99. }
  100. func (m *ManifestV2) GetTotalSize() int {
  101. var total int
  102. for _, layer := range m.Layers {
  103. total += layer.Size
  104. }
  105. total += m.Config.Size
  106. return total
  107. }
  108. func GetManifest(mp ModelPath) (*ManifestV2, error) {
  109. fp, err := mp.GetManifestPath(false)
  110. if err != nil {
  111. return nil, err
  112. }
  113. if _, err = os.Stat(fp); err != nil {
  114. return nil, err
  115. }
  116. var manifest *ManifestV2
  117. bts, err := os.ReadFile(fp)
  118. if err != nil {
  119. return nil, fmt.Errorf("couldn't open file '%s'", fp)
  120. }
  121. if err := json.Unmarshal(bts, &manifest); err != nil {
  122. return nil, err
  123. }
  124. return manifest, nil
  125. }
  126. func GetModel(name string) (*Model, error) {
  127. mp := ParseModelPath(name)
  128. manifest, err := GetManifest(mp)
  129. if err != nil {
  130. return nil, err
  131. }
  132. model := &Model{
  133. Name: mp.GetFullTagname(),
  134. Digest: manifest.Config.Digest,
  135. }
  136. for _, layer := range manifest.Layers {
  137. filename, err := GetBlobsPath(layer.Digest)
  138. if err != nil {
  139. return nil, err
  140. }
  141. switch layer.MediaType {
  142. case "application/vnd.ollama.image.model":
  143. model.ModelPath = filename
  144. case "application/vnd.ollama.image.embed":
  145. file, err := os.Open(filename)
  146. if err != nil {
  147. return nil, fmt.Errorf("failed to open file: %s", filename)
  148. }
  149. defer file.Close()
  150. if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
  151. return nil, err
  152. }
  153. case "application/vnd.ollama.image.adapter":
  154. model.AdapterPaths = append(model.AdapterPaths, filename)
  155. case "application/vnd.ollama.image.template":
  156. bts, err := os.ReadFile(filename)
  157. if err != nil {
  158. return nil, err
  159. }
  160. model.Template = string(bts)
  161. case "application/vnd.ollama.image.system":
  162. bts, err := os.ReadFile(filename)
  163. if err != nil {
  164. return nil, err
  165. }
  166. model.System = string(bts)
  167. case "application/vnd.ollama.image.prompt":
  168. bts, err := os.ReadFile(filename)
  169. if err != nil {
  170. return nil, err
  171. }
  172. model.Template = string(bts)
  173. case "application/vnd.ollama.image.params":
  174. params, err := os.Open(filename)
  175. if err != nil {
  176. return nil, err
  177. }
  178. defer params.Close()
  179. // parse model options parameters into a map so that we can see which fields have been specified explicitly
  180. if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
  181. return nil, err
  182. }
  183. }
  184. }
  185. return model, nil
  186. }
  187. func filenameWithPath(path, f string) (string, error) {
  188. // if filePath starts with ~/, replace it with the user's home directory.
  189. if strings.HasPrefix(f, "~/") {
  190. parts := strings.Split(f, "/")
  191. home, err := os.UserHomeDir()
  192. if err != nil {
  193. return "", fmt.Errorf("failed to open file: %v", err)
  194. }
  195. f = filepath.Join(home, filepath.Join(parts[1:]...))
  196. }
  197. // if filePath is not an absolute path, make it relative to the modelfile path
  198. if !filepath.IsAbs(f) {
  199. f = filepath.Join(filepath.Dir(path), f)
  200. }
  201. return f, nil
  202. }
  203. func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
  204. mf, err := os.Open(path)
  205. if err != nil {
  206. fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
  207. return fmt.Errorf("failed to open file: %w", err)
  208. }
  209. defer mf.Close()
  210. fn(api.ProgressResponse{Status: "parsing modelfile"})
  211. commands, err := parser.Parse(mf)
  212. if err != nil {
  213. return err
  214. }
  215. config := ConfigV2{
  216. Architecture: "amd64",
  217. OS: "linux",
  218. }
  219. var layers []*LayerReader
  220. params := make(map[string][]string)
  221. embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
  222. for _, c := range commands {
  223. log.Printf("[%s] - %s\n", c.Name, c.Args)
  224. switch c.Name {
  225. case "model":
  226. fn(api.ProgressResponse{Status: "looking for model"})
  227. embed.model = c.Args
  228. mf, err := GetManifest(ParseModelPath(c.Args))
  229. if err != nil {
  230. modelFile, err := filenameWithPath(path, c.Args)
  231. if err != nil {
  232. return err
  233. }
  234. if _, err := os.Stat(modelFile); err != nil {
  235. // the model file does not exist, try pulling it
  236. if errors.Is(err, os.ErrNotExist) {
  237. fn(api.ProgressResponse{Status: "pulling model file"})
  238. if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
  239. return err
  240. }
  241. mf, err = GetManifest(ParseModelPath(c.Args))
  242. if err != nil {
  243. return fmt.Errorf("failed to open file after pull: %v", err)
  244. }
  245. } else {
  246. return err
  247. }
  248. } else {
  249. // create a model from this specified file
  250. fn(api.ProgressResponse{Status: "creating model layer"})
  251. file, err := os.Open(modelFile)
  252. if err != nil {
  253. return fmt.Errorf("failed to open file: %v", err)
  254. }
  255. defer file.Close()
  256. ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
  257. if err != nil {
  258. return err
  259. }
  260. config.ModelFamily = ggml.ModelFamily
  261. config.ModelType = ggml.ModelType
  262. config.FileType = ggml.FileType
  263. // reset the file
  264. file.Seek(0, io.SeekStart)
  265. l, err := CreateLayer(file)
  266. if err != nil {
  267. return fmt.Errorf("failed to create layer: %v", err)
  268. }
  269. l.MediaType = "application/vnd.ollama.image.model"
  270. layers = append(layers, l)
  271. }
  272. }
  273. if mf != nil {
  274. log.Printf("manifest = %#v", mf)
  275. for _, l := range mf.Layers {
  276. newLayer, err := GetLayerWithBufferFromLayer(l)
  277. if err != nil {
  278. return err
  279. }
  280. layers = append(layers, newLayer)
  281. }
  282. }
  283. case "embed":
  284. embedFilePath, err := filenameWithPath(path, c.Args)
  285. if err != nil {
  286. return err
  287. }
  288. embed.files = append(embed.files, embedFilePath)
  289. case "adapter":
  290. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  291. fp := c.Args
  292. if strings.HasPrefix(fp, "~/") {
  293. parts := strings.Split(fp, "/")
  294. home, err := os.UserHomeDir()
  295. if err != nil {
  296. return fmt.Errorf("failed to open file: %v", err)
  297. }
  298. fp = filepath.Join(home, filepath.Join(parts[1:]...))
  299. }
  300. // If filePath is not an absolute path, make it relative to the modelfile path
  301. if !filepath.IsAbs(fp) {
  302. fp = filepath.Join(filepath.Dir(path), fp)
  303. }
  304. // create a model from this specified file
  305. fn(api.ProgressResponse{Status: "creating model layer"})
  306. file, err := os.Open(fp)
  307. if err != nil {
  308. return fmt.Errorf("failed to open file: %v", err)
  309. }
  310. defer file.Close()
  311. l, err := CreateLayer(file)
  312. if err != nil {
  313. return fmt.Errorf("failed to create layer: %v", err)
  314. }
  315. l.MediaType = "application/vnd.ollama.image.adapter"
  316. layers = append(layers, l)
  317. case "license":
  318. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  319. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  320. layer, err := CreateLayer(strings.NewReader(c.Args))
  321. if err != nil {
  322. return err
  323. }
  324. layer.MediaType = mediaType
  325. layers = append(layers, layer)
  326. case "template", "system", "prompt":
  327. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  328. // remove the layer if one exists
  329. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  330. layers = removeLayerFromLayers(layers, mediaType)
  331. layer, err := CreateLayer(strings.NewReader(c.Args))
  332. if err != nil {
  333. return err
  334. }
  335. layer.MediaType = mediaType
  336. layers = append(layers, layer)
  337. default:
  338. // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
  339. params[c.Name] = append(params[c.Name], c.Args)
  340. }
  341. }
  342. // Create a single layer for the parameters
  343. if len(params) > 0 {
  344. fn(api.ProgressResponse{Status: "creating parameter layer"})
  345. layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
  346. formattedParams, err := formatParams(params)
  347. if err != nil {
  348. return fmt.Errorf("couldn't create params json: %v", err)
  349. }
  350. bts, err := json.Marshal(formattedParams)
  351. if err != nil {
  352. return err
  353. }
  354. l, err := CreateLayer(bytes.NewReader(bts))
  355. if err != nil {
  356. return fmt.Errorf("failed to create layer: %v", err)
  357. }
  358. l.MediaType = "application/vnd.ollama.image.params"
  359. layers = append(layers, l)
  360. // apply these parameters to the embedding options, in case embeddings need to be generated using this model
  361. embed.opts = api.DefaultOptions()
  362. embed.opts.FromMap(formattedParams)
  363. }
  364. // generate the embedding layers
  365. embeddingLayers, err := embeddingLayers(embed)
  366. if err != nil {
  367. return err
  368. }
  369. layers = append(layers, embeddingLayers...)
  370. digests, err := getLayerDigests(layers)
  371. if err != nil {
  372. return err
  373. }
  374. var manifestLayers []*Layer
  375. for _, l := range layers {
  376. manifestLayers = append(manifestLayers, &l.Layer)
  377. }
  378. // Create a layer for the config object
  379. fn(api.ProgressResponse{Status: "creating config layer"})
  380. cfg, err := createConfigLayer(config, digests)
  381. if err != nil {
  382. return err
  383. }
  384. layers = append(layers, cfg)
  385. err = SaveLayers(layers, fn, false)
  386. if err != nil {
  387. return err
  388. }
  389. // Create the manifest
  390. fn(api.ProgressResponse{Status: "writing manifest"})
  391. err = CreateManifest(name, cfg, manifestLayers)
  392. if err != nil {
  393. return err
  394. }
  395. fn(api.ProgressResponse{Status: "success"})
  396. return nil
  397. }
  398. type EmbeddingParams struct {
  399. model string
  400. opts api.Options
  401. files []string // paths to files to embed
  402. fn func(resp api.ProgressResponse)
  403. }
  404. // embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
  405. func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
  406. layers := []*LayerReader{}
  407. if len(e.files) > 0 {
  408. if _, err := os.Stat(e.model); err != nil {
  409. if os.IsNotExist(err) {
  410. // this is a model name rather than the file
  411. model, err := GetModel(e.model)
  412. if err != nil {
  413. return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
  414. }
  415. e.model = model.ModelPath
  416. } else {
  417. return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
  418. }
  419. }
  420. e.opts.EmbeddingOnly = true
  421. llmModel, err := llm.New(e.model, []string{}, e.opts)
  422. if err != nil {
  423. return nil, fmt.Errorf("load model to generate embeddings: %v", err)
  424. }
  425. defer func() {
  426. if llmModel != nil {
  427. llmModel.Close()
  428. }
  429. }()
  430. addedFiles := make(map[string]bool) // keep track of files that have already been added
  431. for _, filePattern := range e.files {
  432. matchingFiles, err := filepath.Glob(filePattern)
  433. if err != nil {
  434. return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
  435. }
  436. for _, filePath := range matchingFiles {
  437. if addedFiles[filePath] {
  438. continue
  439. }
  440. addedFiles[filePath] = true
  441. // TODO: check file type
  442. f, err := os.Open(filePath)
  443. if err != nil {
  444. return nil, fmt.Errorf("could not open embed file: %w", err)
  445. }
  446. scanner := bufio.NewScanner(f)
  447. scanner.Split(bufio.ScanLines)
  448. data := []string{}
  449. for scanner.Scan() {
  450. data = append(data, scanner.Text())
  451. }
  452. f.Close()
  453. // the digest of the file is set here so that the client knows a new operation is in progress
  454. fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
  455. embeddings := []vector.Embedding{}
  456. for i, d := range data {
  457. if strings.TrimSpace(d) == "" {
  458. continue
  459. }
  460. e.fn(api.ProgressResponse{
  461. Status: fmt.Sprintf("creating embeddings for file %s", filePath),
  462. Digest: fileDigest,
  463. Total: len(data) - 1,
  464. Completed: i,
  465. })
  466. embed, err := llmModel.Embedding(d)
  467. if err != nil {
  468. log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
  469. continue
  470. }
  471. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
  472. }
  473. b, err := json.Marshal(embeddings)
  474. if err != nil {
  475. return nil, fmt.Errorf("failed to encode embeddings: %w", err)
  476. }
  477. r := bytes.NewReader(b)
  478. digest, size := GetSHA256Digest(r)
  479. // Reset the position of the reader after calculating the digest
  480. if _, err := r.Seek(0, io.SeekStart); err != nil {
  481. return nil, fmt.Errorf("could not reset embed reader: %w", err)
  482. }
  483. layer := &LayerReader{
  484. Layer: Layer{
  485. MediaType: "application/vnd.ollama.image.embed",
  486. Digest: digest,
  487. Size: size,
  488. },
  489. Reader: r,
  490. }
  491. layers = append(layers, layer)
  492. }
  493. }
  494. }
  495. return layers, nil
  496. }
  497. func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
  498. j := 0
  499. for _, l := range layers {
  500. if l.MediaType != mediaType {
  501. layers[j] = l
  502. j++
  503. }
  504. }
  505. return layers[:j]
  506. }
  507. func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
  508. // Write each of the layers to disk
  509. for _, layer := range layers {
  510. fp, err := GetBlobsPath(layer.Digest)
  511. if err != nil {
  512. return err
  513. }
  514. _, err = os.Stat(fp)
  515. if os.IsNotExist(err) || force {
  516. fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
  517. out, err := os.Create(fp)
  518. if err != nil {
  519. log.Printf("couldn't create %s", fp)
  520. return err
  521. }
  522. defer out.Close()
  523. if _, err = io.Copy(out, layer.Reader); err != nil {
  524. return err
  525. }
  526. } else {
  527. fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
  528. }
  529. }
  530. return nil
  531. }
  532. func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
  533. mp := ParseModelPath(name)
  534. manifest := ManifestV2{
  535. SchemaVersion: 2,
  536. MediaType: "application/vnd.docker.distribution.manifest.v2+json",
  537. Config: Layer{
  538. MediaType: cfg.MediaType,
  539. Size: cfg.Size,
  540. Digest: cfg.Digest,
  541. },
  542. Layers: layers,
  543. }
  544. manifestJSON, err := json.Marshal(manifest)
  545. if err != nil {
  546. return err
  547. }
  548. fp, err := mp.GetManifestPath(true)
  549. if err != nil {
  550. return err
  551. }
  552. return os.WriteFile(fp, manifestJSON, 0o644)
  553. }
  554. func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
  555. fp, err := GetBlobsPath(layer.Digest)
  556. if err != nil {
  557. return nil, err
  558. }
  559. file, err := os.Open(fp)
  560. if err != nil {
  561. return nil, fmt.Errorf("could not open blob: %w", err)
  562. }
  563. defer file.Close()
  564. newLayer, err := CreateLayer(file)
  565. if err != nil {
  566. return nil, err
  567. }
  568. newLayer.MediaType = layer.MediaType
  569. return newLayer, nil
  570. }
  571. // formatParams converts specified parameter options to their correct types
  572. func formatParams(params map[string][]string) (map[string]interface{}, error) {
  573. opts := api.Options{}
  574. valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
  575. typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
  576. // build map of json struct tags to their types
  577. jsonOpts := make(map[string]reflect.StructField)
  578. for _, field := range reflect.VisibleFields(typeOpts) {
  579. jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
  580. if jsonTag != "" {
  581. jsonOpts[jsonTag] = field
  582. }
  583. }
  584. out := make(map[string]interface{})
  585. // iterate params and set values based on json struct tags
  586. for key, vals := range params {
  587. if opt, ok := jsonOpts[key]; ok {
  588. field := valueOpts.FieldByName(opt.Name)
  589. if field.IsValid() && field.CanSet() {
  590. switch field.Kind() {
  591. case reflect.Float32:
  592. floatVal, err := strconv.ParseFloat(vals[0], 32)
  593. if err != nil {
  594. return nil, fmt.Errorf("invalid float value %s", vals)
  595. }
  596. out[key] = floatVal
  597. case reflect.Int:
  598. intVal, err := strconv.ParseInt(vals[0], 10, 0)
  599. if err != nil {
  600. return nil, fmt.Errorf("invalid int value %s", vals)
  601. }
  602. out[key] = intVal
  603. case reflect.Bool:
  604. boolVal, err := strconv.ParseBool(vals[0])
  605. if err != nil {
  606. return nil, fmt.Errorf("invalid bool value %s", vals)
  607. }
  608. out[key] = boolVal
  609. case reflect.String:
  610. out[key] = vals[0]
  611. case reflect.Slice:
  612. // TODO: only string slices are supported right now
  613. out[key] = vals
  614. default:
  615. return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
  616. }
  617. }
  618. }
  619. }
  620. return out, nil
  621. }
  622. func getLayerDigests(layers []*LayerReader) ([]string, error) {
  623. var digests []string
  624. for _, l := range layers {
  625. if l.Digest == "" {
  626. return nil, fmt.Errorf("layer is missing a digest")
  627. }
  628. digests = append(digests, l.Digest)
  629. }
  630. return digests, nil
  631. }
  632. // CreateLayer creates a Layer object from a given file
  633. func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
  634. digest, size := GetSHA256Digest(f)
  635. f.Seek(0, io.SeekStart)
  636. layer := &LayerReader{
  637. Layer: Layer{
  638. MediaType: "application/vnd.docker.image.rootfs.diff.tar",
  639. Digest: digest,
  640. Size: size,
  641. },
  642. Reader: f,
  643. }
  644. return layer, nil
  645. }
  646. func CopyModel(src, dest string) error {
  647. srcPath, err := ParseModelPath(src).GetManifestPath(false)
  648. if err != nil {
  649. return err
  650. }
  651. destPath, err := ParseModelPath(dest).GetManifestPath(true)
  652. if err != nil {
  653. return err
  654. }
  655. // copy the file
  656. input, err := os.ReadFile(srcPath)
  657. if err != nil {
  658. fmt.Println("Error reading file:", err)
  659. return err
  660. }
  661. err = os.WriteFile(destPath, input, 0o644)
  662. if err != nil {
  663. fmt.Println("Error reading file:", err)
  664. return err
  665. }
  666. return nil
  667. }
  668. func DeleteModel(name string) error {
  669. mp := ParseModelPath(name)
  670. manifest, err := GetManifest(mp)
  671. if err != nil {
  672. return err
  673. }
  674. deleteMap := make(map[string]bool)
  675. for _, layer := range manifest.Layers {
  676. deleteMap[layer.Digest] = true
  677. }
  678. deleteMap[manifest.Config.Digest] = true
  679. fp, err := GetManifestPath()
  680. if err != nil {
  681. return err
  682. }
  683. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  684. if err != nil {
  685. return err
  686. }
  687. if !info.IsDir() {
  688. path := path[len(fp)+1:]
  689. slashIndex := strings.LastIndex(path, "/")
  690. if slashIndex == -1 {
  691. return nil
  692. }
  693. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  694. fmp := ParseModelPath(tag)
  695. // skip the manifest we're trying to delete
  696. if mp.GetFullTagname() == fmp.GetFullTagname() {
  697. return nil
  698. }
  699. // save (i.e. delete from the deleteMap) any files used in other manifests
  700. manifest, err := GetManifest(fmp)
  701. if err != nil {
  702. log.Printf("skipping file: %s", fp)
  703. return nil
  704. }
  705. for _, layer := range manifest.Layers {
  706. delete(deleteMap, layer.Digest)
  707. }
  708. delete(deleteMap, manifest.Config.Digest)
  709. }
  710. return nil
  711. })
  712. if err != nil {
  713. return err
  714. }
  715. // only delete the files which are still in the deleteMap
  716. for k, v := range deleteMap {
  717. if v {
  718. fp, err := GetBlobsPath(k)
  719. if err != nil {
  720. log.Printf("couldn't get file path for '%s': %v", k, err)
  721. continue
  722. }
  723. if err := os.Remove(fp); err != nil {
  724. log.Printf("couldn't remove file '%s': %v", fp, err)
  725. continue
  726. }
  727. }
  728. }
  729. fp, err = mp.GetManifestPath(false)
  730. if err != nil {
  731. return err
  732. }
  733. err = os.Remove(fp)
  734. if err != nil {
  735. log.Printf("couldn't remove manifest file '%s': %v", fp, err)
  736. return err
  737. }
  738. return nil
  739. }
  740. func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  741. mp := ParseModelPath(name)
  742. fn(api.ProgressResponse{Status: "retrieving manifest"})
  743. manifest, err := GetManifest(mp)
  744. if err != nil {
  745. fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
  746. return err
  747. }
  748. var layers []*Layer
  749. layers = append(layers, manifest.Layers...)
  750. layers = append(layers, &manifest.Config)
  751. for _, layer := range layers {
  752. exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
  753. if err != nil {
  754. return err
  755. }
  756. if exists {
  757. fn(api.ProgressResponse{
  758. Status: "using existing layer",
  759. Digest: layer.Digest,
  760. Total: layer.Size,
  761. Completed: layer.Size,
  762. })
  763. log.Printf("Layer %s already exists", layer.Digest)
  764. continue
  765. }
  766. fn(api.ProgressResponse{
  767. Status: "starting upload",
  768. Digest: layer.Digest,
  769. Total: layer.Size,
  770. })
  771. location, err := startUpload(ctx, mp, regOpts)
  772. if err != nil {
  773. log.Printf("couldn't start upload: %v", err)
  774. return err
  775. }
  776. err = uploadBlobChunked(ctx, mp, location, layer, regOpts, fn)
  777. if err != nil {
  778. log.Printf("error uploading blob: %v", err)
  779. return err
  780. }
  781. }
  782. fn(api.ProgressResponse{Status: "pushing manifest"})
  783. url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
  784. headers := map[string]string{
  785. "Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
  786. }
  787. manifestJSON, err := json.Marshal(manifest)
  788. if err != nil {
  789. return err
  790. }
  791. resp, err := makeRequest(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
  792. if err != nil {
  793. return err
  794. }
  795. defer resp.Body.Close()
  796. // Check for success: For a successful upload, the Docker registry will respond with a 201 Created
  797. if resp.StatusCode != http.StatusCreated {
  798. body, _ := io.ReadAll(resp.Body)
  799. return fmt.Errorf("on push registry responded with code %d: %v", resp.StatusCode, string(body))
  800. }
  801. fn(api.ProgressResponse{Status: "success"})
  802. return nil
  803. }
  804. func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  805. mp := ParseModelPath(name)
  806. fn(api.ProgressResponse{Status: "pulling manifest"})
  807. manifest, err := pullModelManifest(ctx, mp, regOpts)
  808. if err != nil {
  809. return fmt.Errorf("pull model manifest: %s", err)
  810. }
  811. var layers []*Layer
  812. layers = append(layers, manifest.Layers...)
  813. layers = append(layers, &manifest.Config)
  814. for _, layer := range layers {
  815. if err := downloadBlob(ctx, mp, layer.Digest, regOpts, fn); err != nil {
  816. return err
  817. }
  818. }
  819. fn(api.ProgressResponse{Status: "verifying sha256 digest"})
  820. for _, layer := range layers {
  821. if err := verifyBlob(layer.Digest); err != nil {
  822. if errors.Is(err, errDigestMismatch) {
  823. // something went wrong, delete the blob
  824. fp, err := GetBlobsPath(layer.Digest)
  825. if err != nil {
  826. return err
  827. }
  828. if err := os.Remove(fp); err != nil {
  829. // log this, but return the original error
  830. log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
  831. }
  832. }
  833. return err
  834. }
  835. }
  836. fn(api.ProgressResponse{Status: "writing manifest"})
  837. manifestJSON, err := json.Marshal(manifest)
  838. if err != nil {
  839. return err
  840. }
  841. fp, err := mp.GetManifestPath(true)
  842. if err != nil {
  843. return err
  844. }
  845. err = os.WriteFile(fp, manifestJSON, 0o644)
  846. if err != nil {
  847. log.Printf("couldn't write to %s", fp)
  848. return err
  849. }
  850. fn(api.ProgressResponse{Status: "success"})
  851. return nil
  852. }
  853. func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
  854. url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
  855. headers := map[string]string{
  856. "Accept": "application/vnd.docker.distribution.manifest.v2+json",
  857. }
  858. resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
  859. if err != nil {
  860. log.Printf("couldn't get manifest: %v", err)
  861. return nil, err
  862. }
  863. defer resp.Body.Close()
  864. // Check for success: For a successful upload, the Docker registry will respond with a 201 Created
  865. if resp.StatusCode != http.StatusOK {
  866. if resp.StatusCode == http.StatusNotFound {
  867. return nil, fmt.Errorf("model not found")
  868. }
  869. body, _ := io.ReadAll(resp.Body)
  870. return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
  871. }
  872. var m *ManifestV2
  873. if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
  874. return nil, err
  875. }
  876. return m, err
  877. }
  878. func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
  879. config.RootFS = RootFS{
  880. Type: "layers",
  881. DiffIDs: layers,
  882. }
  883. configJSON, err := json.Marshal(config)
  884. if err != nil {
  885. return nil, err
  886. }
  887. digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
  888. layer := &LayerReader{
  889. Layer: Layer{
  890. MediaType: "application/vnd.docker.container.image.v1+json",
  891. Digest: digest,
  892. Size: size,
  893. },
  894. Reader: bytes.NewBuffer(configJSON),
  895. }
  896. return layer, nil
  897. }
  898. // GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
  899. func GetSHA256Digest(r io.Reader) (string, int) {
  900. h := sha256.New()
  901. n, err := io.Copy(h, r)
  902. if err != nil {
  903. log.Fatal(err)
  904. }
  905. return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
  906. }
  907. func startUpload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (string, error) {
  908. url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
  909. resp, err := makeRequest(ctx, "POST", url, nil, nil, regOpts)
  910. if err != nil {
  911. log.Printf("couldn't start upload: %v", err)
  912. return "", err
  913. }
  914. defer resp.Body.Close()
  915. // Check for success
  916. if resp.StatusCode != http.StatusAccepted {
  917. body, _ := io.ReadAll(resp.Body)
  918. return "", fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
  919. }
  920. // Extract UUID location from header
  921. location := resp.Header.Get("Location")
  922. if location == "" {
  923. return "", fmt.Errorf("location header is missing in response")
  924. }
  925. return location, nil
  926. }
  927. // Function to check if a blob already exists in the Docker registry
  928. func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
  929. url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
  930. resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
  931. if err != nil {
  932. log.Printf("couldn't check for blob: %v", err)
  933. return false, err
  934. }
  935. defer resp.Body.Close()
  936. // Check for success: If the blob exists, the Docker registry will respond with a 200 OK
  937. return resp.StatusCode == http.StatusOK, nil
  938. }
  939. func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  940. // TODO allow resumability
  941. // TODO allow canceling uploads via DELETE
  942. // TODO allow cross repo blob mount
  943. fp, err := GetBlobsPath(layer.Digest)
  944. if err != nil {
  945. return err
  946. }
  947. f, err := os.Open(fp)
  948. if err != nil {
  949. return err
  950. }
  951. totalUploaded := 0
  952. r, w := io.Pipe()
  953. defer r.Close()
  954. go func() {
  955. defer w.Close()
  956. for {
  957. n, err := io.CopyN(w, f, 1024*1024)
  958. if err != nil && !errors.Is(err, io.EOF) {
  959. fn(api.ProgressResponse{
  960. Status: fmt.Sprintf("error copying pipe: %v", err),
  961. Digest: layer.Digest,
  962. Total: layer.Size,
  963. Completed: totalUploaded,
  964. })
  965. return
  966. }
  967. totalUploaded += int(n)
  968. fn(api.ProgressResponse{
  969. Status: fmt.Sprintf("uploading %s", layer.Digest),
  970. Digest: layer.Digest,
  971. Total: layer.Size,
  972. Completed: totalUploaded,
  973. })
  974. if totalUploaded >= layer.Size {
  975. return
  976. }
  977. }
  978. }()
  979. url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)
  980. headers := make(map[string]string)
  981. headers["Content-Type"] = "application/octet-stream"
  982. headers["Content-Range"] = fmt.Sprintf("0-%d", layer.Size-1)
  983. headers["Content-Length"] = strconv.Itoa(int(layer.Size))
  984. // finish the upload
  985. resp, err := makeRequest(ctx, "PUT", url, headers, r, regOpts)
  986. if err != nil {
  987. log.Printf("couldn't finish upload: %v", err)
  988. return err
  989. }
  990. defer resp.Body.Close()
  991. if resp.StatusCode != http.StatusCreated {
  992. body, _ := io.ReadAll(resp.Body)
  993. return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
  994. }
  995. return nil
  996. }
  997. func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
  998. retryCtx := ctx.Value("retries")
  999. var retries int
  1000. var ok bool
  1001. if retries, ok = retryCtx.(int); ok {
  1002. if retries > MaxRetries {
  1003. return nil, fmt.Errorf("Maximum retries hit; are you sure you have access to this resource?")
  1004. }
  1005. }
  1006. if !strings.HasPrefix(url, "http") {
  1007. if regOpts.Insecure {
  1008. url = "http://" + url
  1009. } else {
  1010. url = "https://" + url
  1011. }
  1012. }
  1013. // make a copy of the body in case we need to try the call to makeRequest again
  1014. var buf bytes.Buffer
  1015. if body != nil {
  1016. _, err := io.Copy(&buf, body)
  1017. if err != nil {
  1018. return nil, err
  1019. }
  1020. }
  1021. bodyCopy := bytes.NewReader(buf.Bytes())
  1022. req, err := http.NewRequest(method, url, bodyCopy)
  1023. if err != nil {
  1024. return nil, err
  1025. }
  1026. if regOpts.Token != "" {
  1027. req.Header.Set("Authorization", "Bearer "+regOpts.Token)
  1028. } else if regOpts.Username != "" && regOpts.Password != "" {
  1029. req.SetBasicAuth(regOpts.Username, regOpts.Password)
  1030. }
  1031. for k, v := range headers {
  1032. req.Header.Set(k, v)
  1033. }
  1034. client := &http.Client{
  1035. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  1036. if len(via) >= 10 {
  1037. return fmt.Errorf("too many redirects")
  1038. }
  1039. log.Printf("redirected to: %s\n", req.URL)
  1040. return nil
  1041. },
  1042. }
  1043. resp, err := client.Do(req)
  1044. if err != nil {
  1045. return nil, err
  1046. }
  1047. // if the request is unauthenticated, try to authenticate and make the request again
  1048. if resp.StatusCode == http.StatusUnauthorized {
  1049. auth := resp.Header.Get("Www-Authenticate")
  1050. authRedir := ParseAuthRedirectString(string(auth))
  1051. token, err := getAuthToken(ctx, authRedir, regOpts)
  1052. if err != nil {
  1053. return nil, err
  1054. }
  1055. regOpts.Token = token
  1056. bodyCopy = bytes.NewReader(buf.Bytes())
  1057. ctx = context.WithValue(ctx, "retries", retries+1)
  1058. return makeRequest(ctx, method, url, headers, bodyCopy, regOpts)
  1059. }
  1060. return resp, nil
  1061. }
  1062. func getValue(header, key string) string {
  1063. startIdx := strings.Index(header, key+"=")
  1064. if startIdx == -1 {
  1065. return ""
  1066. }
  1067. // Move the index to the starting quote after the key.
  1068. startIdx += len(key) + 2
  1069. endIdx := startIdx
  1070. for endIdx < len(header) {
  1071. if header[endIdx] == '"' {
  1072. if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue
  1073. endIdx++
  1074. continue
  1075. }
  1076. break
  1077. }
  1078. endIdx++
  1079. }
  1080. return header[startIdx:endIdx]
  1081. }
  1082. func ParseAuthRedirectString(authStr string) AuthRedirect {
  1083. authStr = strings.TrimPrefix(authStr, "Bearer ")
  1084. return AuthRedirect{
  1085. Realm: getValue(authStr, "realm"),
  1086. Service: getValue(authStr, "service"),
  1087. Scope: getValue(authStr, "scope"),
  1088. }
  1089. }
  1090. var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
  1091. func verifyBlob(digest string) error {
  1092. fp, err := GetBlobsPath(digest)
  1093. if err != nil {
  1094. return err
  1095. }
  1096. f, err := os.Open(fp)
  1097. if err != nil {
  1098. return err
  1099. }
  1100. defer f.Close()
  1101. fileDigest, _ := GetSHA256Digest(f)
  1102. if digest != fileDigest {
  1103. return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
  1104. }
  1105. return nil
  1106. }