images.go 34 KB


  1. package server
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/sha256"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "html/template"
  11. "io"
  12. "log"
  13. "net/http"
  14. "os"
  15. "path"
  16. "path/filepath"
  17. "reflect"
  18. "strconv"
  19. "strings"
  20. "github.com/jmorganca/ollama/api"
  21. "github.com/jmorganca/ollama/llm"
  22. "github.com/jmorganca/ollama/parser"
  23. "github.com/jmorganca/ollama/vector"
  24. )
  // MaxRetries is a retry limit for registry requests.
// NOTE(review): presumably consumed by makeRequestWithRetry, which is defined
// outside this view — confirm.
const MaxRetries = 3

// RegistryOptions carries connection and credential settings for registry
// requests; PullModel and PushModel receive it by pointer and hand it to the
// request helpers.
type RegistryOptions struct {
	// Insecure presumably allows non-TLS / unverified registries — TODO confirm.
	Insecure bool

	// Username and Password presumably form basic-auth credentials — verify
	// against the request helpers.
	Username string
	Password string

	// Token presumably holds a bearer token — verify against the request helpers.
	Token string
}
  32. type Model struct {
  33. Name string `json:"name"`
  34. ModelPath string
  35. AdapterPaths []string
  36. Template string
  37. System string
  38. Digest string
  39. Options map[string]interface{}
  40. Embeddings []vector.Embedding
  41. }
  42. func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
  43. t := m.Template
  44. if request.Template != "" {
  45. t = request.Template
  46. }
  47. tmpl, err := template.New("").Parse(t)
  48. if err != nil {
  49. return "", err
  50. }
  51. var vars struct {
  52. First bool
  53. System string
  54. Prompt string
  55. Embed string
  56. // deprecated: versions <= 0.0.7 used this to omit the system prompt
  57. Context []int
  58. }
  59. vars.First = len(request.Context) == 0
  60. vars.System = m.System
  61. vars.Prompt = request.Prompt
  62. vars.Context = request.Context
  63. vars.Embed = embedding
  64. if request.System != "" {
  65. vars.System = request.System
  66. }
  67. var sb strings.Builder
  68. if err := tmpl.Execute(&sb, vars); err != nil {
  69. return "", err
  70. }
  71. return sb.String(), nil
  72. }
  // ManifestV2 is a Docker image manifest (schema version 2). GetManifest and
// CreateManifest read and write it at the path returned by GetManifestPath,
// and the push/pull code exchanges it with a registry.
type ManifestV2 struct {
	SchemaVersion int      `json:"schemaVersion"`
	MediaType     string   `json:"mediaType"`
	Config        Layer    `json:"config"`
	Layers        []*Layer `json:"layers"`
}

// Layer describes one content-addressed blob referenced by a manifest.
type Layer struct {
	MediaType string `json:"mediaType"`
	Digest    string `json:"digest"`
	Size      int    `json:"size"`
	// From names the repository a layer was inherited from. startUpload passes
	// it as the "from" query parameter of a blob mount request.
	From string `json:"from,omitempty"`
}

// LayerReader pairs layer metadata with a reader over the layer's content.
type LayerReader struct {
	Layer
	io.Reader
}
  89. type ConfigV2 struct {
  90. ModelFamily llm.ModelFamily `json:"model_family"`
  91. ModelType string `json:"model_type"`
  92. FileType string `json:"file_type"`
  93. RootFS RootFS `json:"rootfs"`
  94. // required by spec
  95. Architecture string `json:"architecture"`
  96. OS string `json:"os"`
  97. }
  98. type RootFS struct {
  99. Type string `json:"type"`
  100. DiffIDs []string `json:"diff_ids"`
  101. }
  102. func (m *ManifestV2) GetTotalSize() int {
  103. var total int
  104. for _, layer := range m.Layers {
  105. total += layer.Size
  106. }
  107. total += m.Config.Size
  108. return total
  109. }
  110. func GetManifest(mp ModelPath) (*ManifestV2, error) {
  111. fp, err := mp.GetManifestPath(false)
  112. if err != nil {
  113. return nil, err
  114. }
  115. if _, err = os.Stat(fp); err != nil {
  116. return nil, err
  117. }
  118. var manifest *ManifestV2
  119. bts, err := os.ReadFile(fp)
  120. if err != nil {
  121. return nil, fmt.Errorf("couldn't open file '%s'", fp)
  122. }
  123. if err := json.Unmarshal(bts, &manifest); err != nil {
  124. return nil, err
  125. }
  126. return manifest, nil
  127. }
  128. func GetModel(name string) (*Model, error) {
  129. mp := ParseModelPath(name)
  130. manifest, err := GetManifest(mp)
  131. if err != nil {
  132. return nil, err
  133. }
  134. model := &Model{
  135. Name: mp.GetFullTagname(),
  136. Digest: manifest.Config.Digest,
  137. }
  138. for _, layer := range manifest.Layers {
  139. filename, err := GetBlobsPath(layer.Digest)
  140. if err != nil {
  141. return nil, err
  142. }
  143. switch layer.MediaType {
  144. case "application/vnd.ollama.image.model":
  145. model.ModelPath = filename
  146. case "application/vnd.ollama.image.embed":
  147. file, err := os.Open(filename)
  148. if err != nil {
  149. return nil, fmt.Errorf("failed to open file: %s", filename)
  150. }
  151. defer file.Close()
  152. if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
  153. return nil, err
  154. }
  155. case "application/vnd.ollama.image.adapter":
  156. model.AdapterPaths = append(model.AdapterPaths, filename)
  157. case "application/vnd.ollama.image.template":
  158. bts, err := os.ReadFile(filename)
  159. if err != nil {
  160. return nil, err
  161. }
  162. model.Template = string(bts)
  163. case "application/vnd.ollama.image.system":
  164. bts, err := os.ReadFile(filename)
  165. if err != nil {
  166. return nil, err
  167. }
  168. model.System = string(bts)
  169. case "application/vnd.ollama.image.prompt":
  170. bts, err := os.ReadFile(filename)
  171. if err != nil {
  172. return nil, err
  173. }
  174. model.Template = string(bts)
  175. case "application/vnd.ollama.image.params":
  176. params, err := os.Open(filename)
  177. if err != nil {
  178. return nil, err
  179. }
  180. defer params.Close()
  181. // parse model options parameters into a map so that we can see which fields have been specified explicitly
  182. if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
  183. return nil, err
  184. }
  185. }
  186. }
  187. return model, nil
  188. }
  // filenameWithPath resolves f, a file path written inside the Modelfile at
// path. A leading "~/" is expanded to the user's home directory, and a path
// that is still relative afterwards is taken relative to the directory
// containing the Modelfile.
func filenameWithPath(path, f string) (string, error) {
	const homePrefix = "~/"

	if strings.HasPrefix(f, homePrefix) {
		home, err := os.UserHomeDir()
		if err != nil {
			return "", fmt.Errorf("failed to open file: %v", err)
		}
		f = filepath.Join(home, f[len(homePrefix):])
	}

	if filepath.IsAbs(f) {
		return f, nil
	}
	return filepath.Join(filepath.Dir(path), f), nil
}
  205. func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
  206. mf, err := os.Open(path)
  207. if err != nil {
  208. fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
  209. return fmt.Errorf("failed to open file: %w", err)
  210. }
  211. defer mf.Close()
  212. fn(api.ProgressResponse{Status: "parsing modelfile"})
  213. commands, err := parser.Parse(mf)
  214. if err != nil {
  215. return err
  216. }
  217. config := ConfigV2{
  218. Architecture: "amd64",
  219. OS: "linux",
  220. }
  221. var layers []*LayerReader
  222. params := make(map[string][]string)
  223. embed := EmbeddingParams{fn: fn}
  224. for _, c := range commands {
  225. log.Printf("[%s] - %s\n", c.Name, c.Args)
  226. switch c.Name {
  227. case "model":
  228. fn(api.ProgressResponse{Status: "looking for model"})
  229. embed.model = c.Args
  230. mp := ParseModelPath(c.Args)
  231. mf, err := GetManifest(mp)
  232. if err != nil {
  233. modelFile, err := filenameWithPath(path, c.Args)
  234. if err != nil {
  235. return err
  236. }
  237. if _, err := os.Stat(modelFile); err != nil {
  238. // the model file does not exist, try pulling it
  239. if errors.Is(err, os.ErrNotExist) {
  240. fn(api.ProgressResponse{Status: "pulling model file"})
  241. if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
  242. return err
  243. }
  244. mf, err = GetManifest(ParseModelPath(c.Args))
  245. if err != nil {
  246. return fmt.Errorf("failed to open file after pull: %v", err)
  247. }
  248. } else {
  249. return err
  250. }
  251. } else {
  252. embed.model = modelFile
  253. // create a model from this specified file
  254. fn(api.ProgressResponse{Status: "creating model layer"})
  255. file, err := os.Open(modelFile)
  256. if err != nil {
  257. return fmt.Errorf("failed to open file: %v", err)
  258. }
  259. defer file.Close()
  260. ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
  261. if err != nil {
  262. return err
  263. }
  264. config.ModelFamily = ggml.ModelFamily()
  265. config.ModelType = ggml.ModelType().String()
  266. config.FileType = ggml.FileType().String()
  267. // reset the file
  268. file.Seek(0, io.SeekStart)
  269. l, err := CreateLayer(file)
  270. if err != nil {
  271. return fmt.Errorf("failed to create layer: %v", err)
  272. }
  273. l.MediaType = "application/vnd.ollama.image.model"
  274. layers = append(layers, l)
  275. }
  276. }
  277. if mf != nil {
  278. log.Printf("manifest = %#v", mf)
  279. for _, l := range mf.Layers {
  280. newLayer, err := GetLayerWithBufferFromLayer(l)
  281. if err != nil {
  282. return err
  283. }
  284. newLayer.From = mp.GetNamespaceRepository()
  285. layers = append(layers, newLayer)
  286. }
  287. }
  288. case "embed":
  289. embedFilePath, err := filenameWithPath(path, c.Args)
  290. if err != nil {
  291. return err
  292. }
  293. embed.files = append(embed.files, embedFilePath)
  294. case "adapter":
  295. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  296. fp := c.Args
  297. if strings.HasPrefix(fp, "~/") {
  298. parts := strings.Split(fp, "/")
  299. home, err := os.UserHomeDir()
  300. if err != nil {
  301. return fmt.Errorf("failed to open file: %v", err)
  302. }
  303. fp = filepath.Join(home, filepath.Join(parts[1:]...))
  304. }
  305. // If filePath is not an absolute path, make it relative to the modelfile path
  306. if !filepath.IsAbs(fp) {
  307. fp = filepath.Join(filepath.Dir(path), fp)
  308. }
  309. // create a model from this specified file
  310. fn(api.ProgressResponse{Status: "creating model layer"})
  311. file, err := os.Open(fp)
  312. if err != nil {
  313. return fmt.Errorf("failed to open file: %v", err)
  314. }
  315. defer file.Close()
  316. l, err := CreateLayer(file)
  317. if err != nil {
  318. return fmt.Errorf("failed to create layer: %v", err)
  319. }
  320. l.MediaType = "application/vnd.ollama.image.adapter"
  321. layers = append(layers, l)
  322. case "license":
  323. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  324. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  325. layer, err := CreateLayer(strings.NewReader(c.Args))
  326. if err != nil {
  327. return err
  328. }
  329. layer.MediaType = mediaType
  330. layers = append(layers, layer)
  331. case "template", "system", "prompt":
  332. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  333. // remove the layer if one exists
  334. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  335. layers = removeLayerFromLayers(layers, mediaType)
  336. layer, err := CreateLayer(strings.NewReader(c.Args))
  337. if err != nil {
  338. return err
  339. }
  340. layer.MediaType = mediaType
  341. layers = append(layers, layer)
  342. default:
  343. // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
  344. params[c.Name] = append(params[c.Name], c.Args)
  345. }
  346. }
  347. // Create a single layer for the parameters
  348. if len(params) > 0 {
  349. fn(api.ProgressResponse{Status: "creating parameter layer"})
  350. layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
  351. formattedParams, err := formatParams(params)
  352. if err != nil {
  353. return fmt.Errorf("couldn't create params json: %v", err)
  354. }
  355. bts, err := json.Marshal(formattedParams)
  356. if err != nil {
  357. return err
  358. }
  359. l, err := CreateLayer(bytes.NewReader(bts))
  360. if err != nil {
  361. return fmt.Errorf("failed to create layer: %v", err)
  362. }
  363. l.MediaType = "application/vnd.ollama.image.params"
  364. layers = append(layers, l)
  365. // apply these parameters to the embedding options, in case embeddings need to be generated using this model
  366. embed.opts = formattedParams
  367. }
  368. // generate the embedding layers
  369. embeddingLayers, err := embeddingLayers(embed)
  370. if err != nil {
  371. return err
  372. }
  373. layers = append(layers, embeddingLayers...)
  374. digests, err := getLayerDigests(layers)
  375. if err != nil {
  376. return err
  377. }
  378. var manifestLayers []*Layer
  379. for _, l := range layers {
  380. manifestLayers = append(manifestLayers, &l.Layer)
  381. }
  382. // Create a layer for the config object
  383. fn(api.ProgressResponse{Status: "creating config layer"})
  384. cfg, err := createConfigLayer(config, digests)
  385. if err != nil {
  386. return err
  387. }
  388. layers = append(layers, cfg)
  389. if err := SaveLayers(layers, fn, false); err != nil {
  390. return err
  391. }
  392. // Create the manifest
  393. fn(api.ProgressResponse{Status: "writing manifest"})
  394. err = CreateManifest(name, cfg, manifestLayers)
  395. if err != nil {
  396. return err
  397. }
  398. fn(api.ProgressResponse{Status: "success"})
  399. return nil
  400. }
  401. type EmbeddingParams struct {
  402. model string
  403. opts map[string]interface{}
  404. files []string // paths to files to embed
  405. fn func(resp api.ProgressResponse)
  406. }
  407. // embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
  408. func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
  409. layers := []*LayerReader{}
  410. if len(e.files) > 0 {
  411. // check if the model is a file path or a model name
  412. model, err := GetModel(e.model)
  413. if err != nil {
  414. if !strings.Contains(err.Error(), "couldn't open file") {
  415. return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
  416. }
  417. // the model may be a file path, create a model from this file
  418. model = &Model{ModelPath: e.model}
  419. }
  420. if err := load(model, e.opts, defaultSessionDuration); err != nil {
  421. return nil, fmt.Errorf("load model to generate embeddings: %v", err)
  422. }
  423. // this will be used to check if we already have embeddings for a file
  424. modelInfo, err := os.Stat(model.ModelPath)
  425. if err != nil {
  426. return nil, fmt.Errorf("failed to get model file info: %v", err)
  427. }
  428. addedFiles := make(map[string]bool) // keep track of files that have already been added
  429. for _, filePattern := range e.files {
  430. matchingFiles, err := filepath.Glob(filePattern)
  431. if err != nil {
  432. return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
  433. }
  434. for _, filePath := range matchingFiles {
  435. if addedFiles[filePath] {
  436. continue
  437. }
  438. addedFiles[filePath] = true
  439. // check if we already have embeddings for this file path
  440. layerIdentifier := fmt.Sprintf("%s:%s:%s:%d", filePath, e.model, modelInfo.ModTime().Format("2006-01-02 15:04:05"), modelInfo.Size())
  441. digest, _ := GetSHA256Digest(strings.NewReader(layerIdentifier))
  442. existing, err := existingFileEmbeddings(digest)
  443. if err != nil {
  444. return nil, fmt.Errorf("failed to check existing embeddings for file %s: %v", filePath, err)
  445. }
  446. // TODO: check file type
  447. f, err := os.Open(filePath)
  448. if err != nil {
  449. return nil, fmt.Errorf("could not open embed file: %w", err)
  450. }
  451. scanner := bufio.NewScanner(f)
  452. scanner.Split(bufio.ScanLines)
  453. data := []string{}
  454. for scanner.Scan() {
  455. data = append(data, scanner.Text())
  456. }
  457. f.Close()
  458. // the digest of the file is set here so that the client knows a new operation is in progress
  459. fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
  460. embeddings := []vector.Embedding{}
  461. for i, d := range data {
  462. if strings.TrimSpace(d) == "" {
  463. continue
  464. }
  465. e.fn(api.ProgressResponse{
  466. Status: fmt.Sprintf("creating embeddings for file %s", filePath),
  467. Digest: fileDigest,
  468. Total: len(data) - 1,
  469. Completed: i,
  470. })
  471. if len(existing[d]) > 0 {
  472. // already have an embedding for this line
  473. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
  474. continue
  475. }
  476. embed, err := loaded.llm.Embedding(d)
  477. if err != nil {
  478. log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
  479. continue
  480. }
  481. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
  482. }
  483. b, err := json.Marshal(embeddings)
  484. if err != nil {
  485. return nil, fmt.Errorf("failed to encode embeddings: %w", err)
  486. }
  487. r := bytes.NewReader(b)
  488. layer := &LayerReader{
  489. Layer: Layer{
  490. MediaType: "application/vnd.ollama.image.embed",
  491. Digest: digest,
  492. Size: r.Len(),
  493. },
  494. Reader: r,
  495. }
  496. layers = append(layers, layer)
  497. }
  498. }
  499. }
  500. return layers, nil
  501. }
  502. // existingFileEmbeddings checks if we already have embeddings for a file and loads them into a look-up map
  503. func existingFileEmbeddings(digest string) (map[string][]float64, error) {
  504. path, err := GetBlobsPath(digest)
  505. if err != nil {
  506. return nil, fmt.Errorf("embeddings blobs path: %w", err)
  507. }
  508. existingFileEmbeddings := make(map[string][]float64)
  509. if _, err := os.Stat(path); err == nil {
  510. // already have some embeddings for this file, load embeddings previously generated
  511. file, err := os.Open(path)
  512. if err != nil {
  513. return nil, fmt.Errorf("failed to open existing embedding file: %s", err)
  514. }
  515. defer file.Close()
  516. existing := []vector.Embedding{}
  517. if err = json.NewDecoder(file).Decode(&existing); err != nil {
  518. return nil, err
  519. }
  520. for _, e := range existing {
  521. existingFileEmbeddings[e.Data] = e.Vector
  522. }
  523. }
  524. return existingFileEmbeddings, nil
  525. }
  526. func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
  527. j := 0
  528. for _, l := range layers {
  529. if l.MediaType != mediaType {
  530. layers[j] = l
  531. j++
  532. }
  533. }
  534. return layers[:j]
  535. }
  536. func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
  537. // Write each of the layers to disk
  538. for _, layer := range layers {
  539. fp, err := GetBlobsPath(layer.Digest)
  540. if err != nil {
  541. return err
  542. }
  543. _, err = os.Stat(fp)
  544. // note: embed layers are always written since their digest doesnt indicate anything about the contents
  545. if os.IsNotExist(err) || force || layer.MediaType == "application/vnd.ollama.image.embed" {
  546. fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
  547. out, err := os.Create(fp)
  548. if err != nil {
  549. log.Printf("couldn't create %s", fp)
  550. return err
  551. }
  552. defer out.Close()
  553. if _, err = io.Copy(out, layer.Reader); err != nil {
  554. return err
  555. }
  556. } else {
  557. fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
  558. }
  559. }
  560. return nil
  561. }
  562. func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
  563. mp := ParseModelPath(name)
  564. manifest := ManifestV2{
  565. SchemaVersion: 2,
  566. MediaType: "application/vnd.docker.distribution.manifest.v2+json",
  567. Config: Layer{
  568. MediaType: cfg.MediaType,
  569. Size: cfg.Size,
  570. Digest: cfg.Digest,
  571. },
  572. Layers: layers,
  573. }
  574. manifestJSON, err := json.Marshal(manifest)
  575. if err != nil {
  576. return err
  577. }
  578. fp, err := mp.GetManifestPath(true)
  579. if err != nil {
  580. return err
  581. }
  582. return os.WriteFile(fp, manifestJSON, 0o644)
  583. }
  584. func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
  585. fp, err := GetBlobsPath(layer.Digest)
  586. if err != nil {
  587. return nil, err
  588. }
  589. file, err := os.Open(fp)
  590. if err != nil {
  591. return nil, fmt.Errorf("could not open blob: %w", err)
  592. }
  593. defer file.Close()
  594. newLayer, err := CreateLayer(file)
  595. if err != nil {
  596. return nil, err
  597. }
  598. newLayer.MediaType = layer.MediaType
  599. return newLayer, nil
  600. }
  601. // formatParams converts specified parameter options to their correct types
  602. func formatParams(params map[string][]string) (map[string]interface{}, error) {
  603. opts := api.Options{}
  604. valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
  605. typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
  606. // build map of json struct tags to their types
  607. jsonOpts := make(map[string]reflect.StructField)
  608. for _, field := range reflect.VisibleFields(typeOpts) {
  609. jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
  610. if jsonTag != "" {
  611. jsonOpts[jsonTag] = field
  612. }
  613. }
  614. out := make(map[string]interface{})
  615. // iterate params and set values based on json struct tags
  616. for key, vals := range params {
  617. if opt, ok := jsonOpts[key]; ok {
  618. field := valueOpts.FieldByName(opt.Name)
  619. if field.IsValid() && field.CanSet() {
  620. switch field.Kind() {
  621. case reflect.Float32:
  622. floatVal, err := strconv.ParseFloat(vals[0], 32)
  623. if err != nil {
  624. return nil, fmt.Errorf("invalid float value %s", vals)
  625. }
  626. out[key] = floatVal
  627. case reflect.Int:
  628. intVal, err := strconv.ParseInt(vals[0], 10, 0)
  629. if err != nil {
  630. return nil, fmt.Errorf("invalid int value %s", vals)
  631. }
  632. out[key] = intVal
  633. case reflect.Bool:
  634. boolVal, err := strconv.ParseBool(vals[0])
  635. if err != nil {
  636. return nil, fmt.Errorf("invalid bool value %s", vals)
  637. }
  638. out[key] = boolVal
  639. case reflect.String:
  640. out[key] = vals[0]
  641. case reflect.Slice:
  642. // TODO: only string slices are supported right now
  643. out[key] = vals
  644. default:
  645. return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
  646. }
  647. }
  648. }
  649. }
  650. return out, nil
  651. }
  652. func getLayerDigests(layers []*LayerReader) ([]string, error) {
  653. var digests []string
  654. for _, l := range layers {
  655. if l.Digest == "" {
  656. return nil, fmt.Errorf("layer is missing a digest")
  657. }
  658. digests = append(digests, l.Digest)
  659. }
  660. return digests, nil
  661. }
  662. // CreateLayer creates a Layer object from a given file
  663. func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
  664. digest, size := GetSHA256Digest(f)
  665. f.Seek(0, io.SeekStart)
  666. layer := &LayerReader{
  667. Layer: Layer{
  668. MediaType: "application/vnd.docker.image.rootfs.diff.tar",
  669. Digest: digest,
  670. Size: size,
  671. },
  672. Reader: f,
  673. }
  674. return layer, nil
  675. }
  676. func CopyModel(src, dest string) error {
  677. srcPath, err := ParseModelPath(src).GetManifestPath(false)
  678. if err != nil {
  679. return err
  680. }
  681. destPath, err := ParseModelPath(dest).GetManifestPath(true)
  682. if err != nil {
  683. return err
  684. }
  685. // copy the file
  686. input, err := os.ReadFile(srcPath)
  687. if err != nil {
  688. fmt.Println("Error reading file:", err)
  689. return err
  690. }
  691. err = os.WriteFile(destPath, input, 0o644)
  692. if err != nil {
  693. fmt.Println("Error reading file:", err)
  694. return err
  695. }
  696. return nil
  697. }
  698. func DeleteModel(name string) error {
  699. mp := ParseModelPath(name)
  700. manifest, err := GetManifest(mp)
  701. if err != nil {
  702. return err
  703. }
  704. deleteMap := make(map[string]bool)
  705. for _, layer := range manifest.Layers {
  706. deleteMap[layer.Digest] = true
  707. }
  708. deleteMap[manifest.Config.Digest] = true
  709. fp, err := GetManifestPath()
  710. if err != nil {
  711. return err
  712. }
  713. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  714. if err != nil {
  715. return err
  716. }
  717. if !info.IsDir() {
  718. path := path[len(fp)+1:]
  719. slashIndex := strings.LastIndex(path, "/")
  720. if slashIndex == -1 {
  721. return nil
  722. }
  723. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  724. fmp := ParseModelPath(tag)
  725. // skip the manifest we're trying to delete
  726. if mp.GetFullTagname() == fmp.GetFullTagname() {
  727. return nil
  728. }
  729. // save (i.e. delete from the deleteMap) any files used in other manifests
  730. manifest, err := GetManifest(fmp)
  731. if err != nil {
  732. log.Printf("skipping file: %s", fp)
  733. return nil
  734. }
  735. for _, layer := range manifest.Layers {
  736. delete(deleteMap, layer.Digest)
  737. }
  738. delete(deleteMap, manifest.Config.Digest)
  739. }
  740. return nil
  741. })
  742. if err != nil {
  743. return err
  744. }
  745. // only delete the files which are still in the deleteMap
  746. for k, v := range deleteMap {
  747. if v {
  748. fp, err := GetBlobsPath(k)
  749. if err != nil {
  750. log.Printf("couldn't get file path for '%s': %v", k, err)
  751. continue
  752. }
  753. if err := os.Remove(fp); err != nil {
  754. log.Printf("couldn't remove file '%s': %v", fp, err)
  755. continue
  756. }
  757. }
  758. }
  759. fp, err = mp.GetManifestPath(false)
  760. if err != nil {
  761. return err
  762. }
  763. err = os.Remove(fp)
  764. if err != nil {
  765. log.Printf("couldn't remove manifest file '%s': %v", fp, err)
  766. return err
  767. }
  768. return nil
  769. }
  770. func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  771. mp := ParseModelPath(name)
  772. fn(api.ProgressResponse{Status: "retrieving manifest"})
  773. manifest, err := GetManifest(mp)
  774. if err != nil {
  775. fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
  776. return err
  777. }
  778. var layers []*Layer
  779. layers = append(layers, manifest.Layers...)
  780. layers = append(layers, &manifest.Config)
  781. for _, layer := range layers {
  782. exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
  783. if err != nil {
  784. return err
  785. }
  786. if exists {
  787. fn(api.ProgressResponse{
  788. Status: "using existing layer",
  789. Digest: layer.Digest,
  790. Total: layer.Size,
  791. Completed: layer.Size,
  792. })
  793. log.Printf("Layer %s already exists", layer.Digest)
  794. continue
  795. }
  796. fn(api.ProgressResponse{
  797. Status: "starting upload",
  798. Digest: layer.Digest,
  799. Total: layer.Size,
  800. })
  801. location, err := startUpload(ctx, mp, layer, regOpts)
  802. if err != nil {
  803. log.Printf("couldn't start upload: %v", err)
  804. return err
  805. }
  806. if strings.HasPrefix(path.Base(location), "sha256:") {
  807. layer.Digest = path.Base(location)
  808. fn(api.ProgressResponse{
  809. Status: "using existing layer",
  810. Digest: layer.Digest,
  811. Total: layer.Size,
  812. Completed: layer.Size,
  813. })
  814. continue
  815. }
  816. if err := uploadBlobChunked(ctx, mp, location, layer, regOpts, fn); err != nil {
  817. log.Printf("error uploading blob: %v", err)
  818. return err
  819. }
  820. }
  821. fn(api.ProgressResponse{Status: "pushing manifest"})
  822. url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
  823. headers := map[string]string{
  824. "Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
  825. }
  826. manifestJSON, err := json.Marshal(manifest)
  827. if err != nil {
  828. return err
  829. }
  830. resp, err := makeRequestWithRetry(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
  831. if err != nil {
  832. return err
  833. }
  834. defer resp.Body.Close()
  835. fn(api.ProgressResponse{Status: "success"})
  836. return nil
  837. }
  838. func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  839. mp := ParseModelPath(name)
  840. fn(api.ProgressResponse{Status: "pulling manifest"})
  841. manifest, err := pullModelManifest(ctx, mp, regOpts)
  842. if err != nil {
  843. return fmt.Errorf("pull model manifest: %s", err)
  844. }
  845. var layers []*Layer
  846. layers = append(layers, manifest.Layers...)
  847. layers = append(layers, &manifest.Config)
  848. for _, layer := range layers {
  849. if err := downloadBlob(
  850. ctx,
  851. downloadOpts{
  852. mp: mp,
  853. digest: layer.Digest,
  854. regOpts: regOpts,
  855. fn: fn,
  856. }); err != nil {
  857. return err
  858. }
  859. }
  860. fn(api.ProgressResponse{Status: "verifying sha256 digest"})
  861. for _, layer := range layers {
  862. if err := verifyBlob(layer.Digest); err != nil {
  863. if errors.Is(err, errDigestMismatch) {
  864. // something went wrong, delete the blob
  865. fp, err := GetBlobsPath(layer.Digest)
  866. if err != nil {
  867. return err
  868. }
  869. if err := os.Remove(fp); err != nil {
  870. // log this, but return the original error
  871. log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
  872. }
  873. }
  874. return err
  875. }
  876. }
  877. fn(api.ProgressResponse{Status: "writing manifest"})
  878. manifestJSON, err := json.Marshal(manifest)
  879. if err != nil {
  880. return err
  881. }
  882. fp, err := mp.GetManifestPath(true)
  883. if err != nil {
  884. return err
  885. }
  886. err = os.WriteFile(fp, manifestJSON, 0o644)
  887. if err != nil {
  888. log.Printf("couldn't write to %s", fp)
  889. return err
  890. }
  891. fn(api.ProgressResponse{Status: "success"})
  892. return nil
  893. }
  894. func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
  895. url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
  896. headers := map[string]string{
  897. "Accept": "application/vnd.docker.distribution.manifest.v2+json",
  898. }
  899. resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
  900. if err != nil {
  901. log.Printf("couldn't get manifest: %v", err)
  902. return nil, err
  903. }
  904. defer resp.Body.Close()
  905. // Check for success: For a successful upload, the Docker registry will respond with a 201 Created
  906. if resp.StatusCode != http.StatusOK {
  907. if resp.StatusCode == http.StatusNotFound {
  908. return nil, fmt.Errorf("model not found")
  909. }
  910. body, _ := io.ReadAll(resp.Body)
  911. return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
  912. }
  913. var m *ManifestV2
  914. if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
  915. return nil, err
  916. }
  917. return m, err
  918. }
  919. func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
  920. config.RootFS = RootFS{
  921. Type: "layers",
  922. DiffIDs: layers,
  923. }
  924. configJSON, err := json.Marshal(config)
  925. if err != nil {
  926. return nil, err
  927. }
  928. digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
  929. layer := &LayerReader{
  930. Layer: Layer{
  931. MediaType: "application/vnd.docker.container.image.v1+json",
  932. Digest: digest,
  933. Size: size,
  934. },
  935. Reader: bytes.NewBuffer(configJSON),
  936. }
  937. return layer, nil
  938. }
  // GetSHA256Digest hashes everything r yields and returns the digest formatted
// as "sha256:<lowercase hex>", together with the number of bytes read.
//
// A read error is passed to log.Fatal, which terminates the process.
// NOTE(review): fatal exit is harsh inside a server; the (string, int)
// signature leaves no way to report the error — consider changing it.
func GetSHA256Digest(r io.Reader) (string, int) {
	hasher := sha256.New()
	n, err := io.Copy(hasher, r)
	if err != nil {
		log.Fatal(err)
	}

	sum := hasher.Sum(nil)
	return "sha256:" + fmt.Sprintf("%x", sum), int(n)
}
  948. type requestContextKey string
  949. func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
  950. url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
  951. if layer.From != "" {
  952. url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
  953. }
  954. resp, err := makeRequestWithRetry(ctx, "POST", url, nil, nil, regOpts)
  955. if err != nil {
  956. log.Printf("couldn't start upload: %v", err)
  957. return "", err
  958. }
  959. defer resp.Body.Close()
  960. // Extract UUID location from header
  961. location := resp.Header.Get("Location")
  962. if location == "" {
  963. return "", fmt.Errorf("location header is missing in response")
  964. }
  965. return location, nil
  966. }
  967. // Function to check if a blob already exists in the Docker registry
  968. func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
  969. url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
  970. resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
  971. if err != nil {
  972. log.Printf("couldn't check for blob: %v", err)
  973. return false, err
  974. }
  975. defer resp.Body.Close()
  976. // Check for success: If the blob exists, the Docker registry will respond with a 200 OK
  977. return resp.StatusCode == http.StatusOK, nil
  978. }
  979. func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  980. // TODO allow resumability
  981. // TODO allow canceling uploads via DELETE
  982. fp, err := GetBlobsPath(layer.Digest)
  983. if err != nil {
  984. return err
  985. }
  986. f, err := os.Open(fp)
  987. if err != nil {
  988. return err
  989. }
  990. defer f.Close()
  991. completed := 0
  992. chunkSize := 10 * 1024 * 1024
  993. for {
  994. r, w := io.Pipe()
  995. defer r.Close()
  996. limit := completed + chunkSize
  997. if chunkSize >= layer.Size-completed {
  998. limit = layer.Size
  999. chunkSize = layer.Size - completed
  1000. }
  1001. go func() {
  1002. defer w.Close()
  1003. for {
  1004. n, err := io.CopyN(w, f, 1024*1024)
  1005. if err != nil && !errors.Is(err, io.EOF) {
  1006. fn(api.ProgressResponse{
  1007. Status: fmt.Sprintf("error copying pipe: %v", err),
  1008. Digest: layer.Digest,
  1009. Total: layer.Size,
  1010. Completed: completed,
  1011. })
  1012. return
  1013. }
  1014. completed += int(n)
  1015. fn(api.ProgressResponse{
  1016. Status: fmt.Sprintf("uploading %s", layer.Digest),
  1017. Digest: layer.Digest,
  1018. Total: layer.Size,
  1019. Completed: completed,
  1020. })
  1021. if completed >= limit {
  1022. return
  1023. }
  1024. }
  1025. }()
  1026. headers := make(map[string]string)
  1027. headers["Content-Type"] = "application/octet-stream"
  1028. headers["Content-Length"] = strconv.Itoa(chunkSize)
  1029. headers["Content-Range"] = fmt.Sprintf("%d-%d", completed, limit-1)
  1030. resp, err := makeRequest(ctx, "PATCH", url, headers, r, regOpts)
  1031. if err != nil {
  1032. return err
  1033. }
  1034. defer resp.Body.Close()
  1035. if resp.StatusCode != http.StatusAccepted {
  1036. body, _ := io.ReadAll(resp.Body)
  1037. return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
  1038. }
  1039. url = resp.Header.Get("Location")
  1040. if completed >= layer.Size {
  1041. break
  1042. }
  1043. }
  1044. url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)
  1045. headers := make(map[string]string)
  1046. headers["Content-Type"] = "application/octet-stream"
  1047. headers["Content-Length"] = "0"
  1048. // finish the upload
  1049. resp, err := makeRequest(ctx, "PUT", url, headers, nil, regOpts)
  1050. if err != nil {
  1051. log.Printf("couldn't finish upload: %v", err)
  1052. return err
  1053. }
  1054. defer resp.Body.Close()
  1055. if resp.StatusCode != http.StatusCreated {
  1056. body, _ := io.ReadAll(resp.Body)
  1057. return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
  1058. }
  1059. return nil
  1060. }
  1061. func makeRequestWithRetry(ctx context.Context, method, url string, headers map[string]string, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
  1062. var status string
  1063. for try := 0; try < MaxRetries; try++ {
  1064. resp, err := makeRequest(ctx, method, url, headers, body, regOpts)
  1065. if err != nil {
  1066. log.Printf("couldn't start upload: %v", err)
  1067. return nil, err
  1068. }
  1069. status = resp.Status
  1070. switch resp.StatusCode {
  1071. case http.StatusAccepted, http.StatusCreated:
  1072. return resp, nil
  1073. case http.StatusUnauthorized:
  1074. auth := resp.Header.Get("www-authenticate")
  1075. authRedir := ParseAuthRedirectString(auth)
  1076. token, err := getAuthToken(ctx, authRedir, regOpts)
  1077. if err != nil {
  1078. return nil, err
  1079. }
  1080. regOpts.Token = token
  1081. if body != nil {
  1082. if _, err := body.Seek(0, io.SeekStart); err != nil {
  1083. return nil, err
  1084. }
  1085. }
  1086. continue
  1087. default:
  1088. body, _ := io.ReadAll(resp.Body)
  1089. return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
  1090. }
  1091. }
  1092. return nil, fmt.Errorf("max retry exceeded: %v", status)
  1093. }
  1094. func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
  1095. if !strings.HasPrefix(url, "http") {
  1096. if regOpts.Insecure {
  1097. url = "http://" + url
  1098. } else {
  1099. url = "https://" + url
  1100. }
  1101. }
  1102. req, err := http.NewRequestWithContext(ctx, method, url, body)
  1103. if err != nil {
  1104. return nil, err
  1105. }
  1106. if regOpts.Token != "" {
  1107. req.Header.Set("Authorization", "Bearer "+regOpts.Token)
  1108. } else if regOpts.Username != "" && regOpts.Password != "" {
  1109. req.SetBasicAuth(regOpts.Username, regOpts.Password)
  1110. }
  1111. for k, v := range headers {
  1112. req.Header.Set(k, v)
  1113. }
  1114. client := &http.Client{
  1115. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  1116. if len(via) >= 10 {
  1117. return fmt.Errorf("too many redirects")
  1118. }
  1119. log.Printf("redirected to: %s\n", req.URL)
  1120. return nil
  1121. },
  1122. }
  1123. resp, err := client.Do(req)
  1124. if err != nil {
  1125. return nil, err
  1126. }
  1127. return resp, nil
  1128. }
  // getValue returns the value of parameter key from a WWW-Authenticate style
// parameter list such as `realm="https://auth",service="reg",scope="repo:x:pull"`.
//
// The key must start the header or follow a ',' or ' ', so "service=" inside a
// quoted realm URL is not mistaken for the service parameter. A quoted value
// runs to the first '"' that is followed by ',' or ends the header; any other
// '"' is kept as part of the value. An unquoted value runs to the next ','.
// It returns "" when the key is absent or has an empty value.
func getValue(header, key string) string {
	needle := key + "="
	startIdx := -1
	for from := 0; from < len(header); {
		i := strings.Index(header[from:], needle)
		if i == -1 {
			break
		}
		i += from
		if i == 0 || header[i-1] == ',' || header[i-1] == ' ' {
			startIdx = i + len(needle)
			break
		}
		from = i + 1
	}
	// Guard against a header that ends right after "key=", which previously
	// sliced past the end of the string and panicked.
	if startIdx == -1 || startIdx >= len(header) {
		return ""
	}

	if header[startIdx] != '"' {
		if end := strings.IndexByte(header[startIdx:], ','); end != -1 {
			return header[startIdx : startIdx+end]
		}
		return header[startIdx:]
	}

	startIdx++ // skip the opening quote
	endIdx := startIdx
	for endIdx < len(header) {
		if header[endIdx] == '"' && (endIdx+1 == len(header) || header[endIdx+1] == ',') {
			break
		}
		endIdx++
	}
	return header[startIdx:endIdx]
}
  1149. func ParseAuthRedirectString(authStr string) AuthRedirect {
  1150. authStr = strings.TrimPrefix(authStr, "Bearer ")
  1151. return AuthRedirect{
  1152. Realm: getValue(authStr, "realm"),
  1153. Service: getValue(authStr, "service"),
  1154. Scope: getValue(authStr, "scope"),
  1155. }
  1156. }
  // errDigestMismatch is returned (wrapped) by verifyBlob when a blob's contents
// do not hash to the expected digest; detect it with errors.Is.
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
  1158. func verifyBlob(digest string) error {
  1159. fp, err := GetBlobsPath(digest)
  1160. if err != nil {
  1161. return err
  1162. }
  1163. f, err := os.Open(fp)
  1164. if err != nil {
  1165. return err
  1166. }
  1167. defer f.Close()
  1168. fileDigest, _ := GetSHA256Digest(f)
  1169. if digest != fileDigest {
  1170. return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
  1171. }
  1172. return nil
  1173. }