images.go 35 KB


  1. package server
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/sha256"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "html/template"
  11. "io"
  12. "log"
  13. "net/http"
  14. "os"
  15. "path"
  16. "path/filepath"
  17. "reflect"
  18. "strconv"
  19. "strings"
  20. "github.com/jmorganca/ollama/api"
  21. "github.com/jmorganca/ollama/llm"
  22. "github.com/jmorganca/ollama/parser"
  23. "github.com/jmorganca/ollama/vector"
  24. )
// MaxRetries is a retry limit for this package. It is not referenced in this
// part of the file; presumably the registry request helpers consume it —
// TODO confirm against makeRequestWithRetry.
const MaxRetries = 3
// RegistryOptions carries the per-call settings used when talking to a model
// registry.
type RegistryOptions struct {
	// Insecure allows model paths whose protocol scheme is plain "http";
	// PushModel and PullModel reject such paths unless it is set.
	Insecure bool
	// Username, Password and Token are credentials. How they are applied is
	// decided by the request helpers, which are outside this section.
	Username string
	Password string
	Token    string
}
// Model is the in-memory view of a locally stored model that GetModel
// assembles from the layers listed in the model's manifest.
type Model struct {
	// Name is the model's full tag name.
	Name string `json:"name"`
	// ModelPath is the blob path of the "model" layer.
	ModelPath string
	// AdapterPaths holds the blob path of every "adapter" layer.
	AdapterPaths []string
	// Template is the prompt template text; Prompt parses it unless the
	// request supplies its own template.
	Template string
	// System is the default system prompt, used when the request has none.
	System string
	// Digest is the digest of the manifest's config layer.
	Digest string
	// Options holds the decoded "params" layer, kept as a map so explicitly
	// specified fields can be told apart from defaults.
	Options map[string]interface{}
	// Embeddings holds the vectors decoded from "embed" layers.
	Embeddings []vector.Embedding
}
  42. func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
  43. t := m.Template
  44. if request.Template != "" {
  45. t = request.Template
  46. }
  47. tmpl, err := template.New("").Parse(t)
  48. if err != nil {
  49. return "", err
  50. }
  51. var vars struct {
  52. First bool
  53. System string
  54. Prompt string
  55. Embed string
  56. // deprecated: versions <= 0.0.7 used this to omit the system prompt
  57. Context []int
  58. }
  59. vars.First = len(request.Context) == 0
  60. vars.System = m.System
  61. vars.Prompt = request.Prompt
  62. vars.Context = request.Context
  63. vars.Embed = embedding
  64. if request.System != "" {
  65. vars.System = request.System
  66. }
  67. var sb strings.Builder
  68. if err := tmpl.Execute(&sb, vars); err != nil {
  69. return "", err
  70. }
  71. return sb.String(), nil
  72. }
// ManifestV2 describes a model image: one config layer plus the content
// layers. It is stored on disk and exchanged with the registry as JSON.
type ManifestV2 struct {
	// SchemaVersion is the manifest schema version; CreateManifest writes 2.
	SchemaVersion int `json:"schemaVersion"`
	// MediaType is the media type of the manifest document itself.
	MediaType string `json:"mediaType"`
	// Config is the descriptor of the config blob.
	Config Layer `json:"config"`
	// Layers are the descriptors of the content blobs.
	Layers []*Layer `json:"layers"`
}
// Layer is a blob descriptor as it appears inside a manifest.
type Layer struct {
	// MediaType tells GetModel how to interpret the blob.
	MediaType string `json:"mediaType"`
	// Digest identifies the blob; GetBlobsPath maps it to a file on disk.
	Digest string `json:"digest"`
	// Size is the blob size in bytes.
	Size int `json:"size"`
	// From names the repository the layer was taken from; startUpload uses it
	// to build the mount/from query when it is non-empty.
	From string `json:"from,omitempty"`
}
// LayerReader pairs a layer descriptor with a reader that yields the layer's
// bytes, so that SaveLayers can copy the content into the blob store.
type LayerReader struct {
	Layer
	io.Reader
}
// ConfigV2 is the JSON document stored in a model's config layer.
type ConfigV2 struct {
	// ModelFamily, ModelType and FileType describe the model weights.
	// CreateModel fills them from the decoded GGML file, or copies them from
	// the config of the model being built upon.
	ModelFamily llm.ModelFamily `json:"model_family"`
	ModelType   string          `json:"model_type"`
	FileType    string          `json:"file_type"`
	// RootFS lists the layer digests; createConfigLayer sets it.
	RootFS RootFS `json:"rootfs"`

	// required by spec
	Architecture string `json:"architecture"`
	OS           string `json:"os"`
}
// RootFS is the "rootfs" section of a config document.
type RootFS struct {
	// Type is set to "layers" by createConfigLayer.
	Type string `json:"type"`
	// DiffIDs holds the digests of the model's layers.
	DiffIDs []string `json:"diff_ids"`
}
  102. func (m *ManifestV2) GetTotalSize() int {
  103. var total int
  104. for _, layer := range m.Layers {
  105. total += layer.Size
  106. }
  107. total += m.Config.Size
  108. return total
  109. }
  110. func GetManifest(mp ModelPath) (*ManifestV2, error) {
  111. fp, err := mp.GetManifestPath(false)
  112. if err != nil {
  113. return nil, err
  114. }
  115. if _, err = os.Stat(fp); err != nil {
  116. return nil, err
  117. }
  118. var manifest *ManifestV2
  119. bts, err := os.ReadFile(fp)
  120. if err != nil {
  121. return nil, fmt.Errorf("couldn't open file '%s'", fp)
  122. }
  123. if err := json.Unmarshal(bts, &manifest); err != nil {
  124. return nil, err
  125. }
  126. return manifest, nil
  127. }
  128. func GetModel(name string) (*Model, error) {
  129. mp := ParseModelPath(name)
  130. manifest, err := GetManifest(mp)
  131. if err != nil {
  132. return nil, err
  133. }
  134. model := &Model{
  135. Name: mp.GetFullTagname(),
  136. Digest: manifest.Config.Digest,
  137. }
  138. for _, layer := range manifest.Layers {
  139. filename, err := GetBlobsPath(layer.Digest)
  140. if err != nil {
  141. return nil, err
  142. }
  143. switch layer.MediaType {
  144. case "application/vnd.ollama.image.model":
  145. model.ModelPath = filename
  146. case "application/vnd.ollama.image.embed":
  147. file, err := os.Open(filename)
  148. if err != nil {
  149. return nil, fmt.Errorf("failed to open file: %s", filename)
  150. }
  151. defer file.Close()
  152. if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
  153. return nil, err
  154. }
  155. case "application/vnd.ollama.image.adapter":
  156. model.AdapterPaths = append(model.AdapterPaths, filename)
  157. case "application/vnd.ollama.image.template":
  158. bts, err := os.ReadFile(filename)
  159. if err != nil {
  160. return nil, err
  161. }
  162. model.Template = string(bts)
  163. case "application/vnd.ollama.image.system":
  164. bts, err := os.ReadFile(filename)
  165. if err != nil {
  166. return nil, err
  167. }
  168. model.System = string(bts)
  169. case "application/vnd.ollama.image.prompt":
  170. bts, err := os.ReadFile(filename)
  171. if err != nil {
  172. return nil, err
  173. }
  174. model.Template = string(bts)
  175. case "application/vnd.ollama.image.params":
  176. params, err := os.Open(filename)
  177. if err != nil {
  178. return nil, err
  179. }
  180. defer params.Close()
  181. // parse model options parameters into a map so that we can see which fields have been specified explicitly
  182. if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
  183. return nil, err
  184. }
  185. }
  186. }
  187. return model, nil
  188. }
  189. func filenameWithPath(path, f string) (string, error) {
  190. // if filePath starts with ~/, replace it with the user's home directory.
  191. if strings.HasPrefix(f, "~/") {
  192. parts := strings.Split(f, "/")
  193. home, err := os.UserHomeDir()
  194. if err != nil {
  195. return "", fmt.Errorf("failed to open file: %v", err)
  196. }
  197. f = filepath.Join(home, filepath.Join(parts[1:]...))
  198. }
  199. // if filePath is not an absolute path, make it relative to the modelfile path
  200. if !filepath.IsAbs(f) {
  201. f = filepath.Join(filepath.Dir(path), f)
  202. }
  203. return f, nil
  204. }
  205. func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
  206. mf, err := os.Open(path)
  207. if err != nil {
  208. fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
  209. return fmt.Errorf("failed to open file: %w", err)
  210. }
  211. defer mf.Close()
  212. fn(api.ProgressResponse{Status: "parsing modelfile"})
  213. commands, err := parser.Parse(mf)
  214. if err != nil {
  215. return err
  216. }
  217. config := ConfigV2{
  218. Architecture: "amd64",
  219. OS: "linux",
  220. }
  221. var layers []*LayerReader
  222. params := make(map[string][]string)
  223. embed := EmbeddingParams{fn: fn}
  224. for _, c := range commands {
  225. log.Printf("[%s] - %s\n", c.Name, c.Args)
  226. switch c.Name {
  227. case "model":
  228. fn(api.ProgressResponse{Status: "looking for model"})
  229. embed.model = c.Args
  230. mp := ParseModelPath(c.Args)
  231. mf, err := GetManifest(mp)
  232. if err != nil {
  233. modelFile, err := filenameWithPath(path, c.Args)
  234. if err != nil {
  235. return err
  236. }
  237. if _, err := os.Stat(modelFile); err != nil {
  238. // the model file does not exist, try pulling it
  239. if errors.Is(err, os.ErrNotExist) {
  240. fn(api.ProgressResponse{Status: "pulling model file"})
  241. if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
  242. return err
  243. }
  244. mf, err = GetManifest(mp)
  245. if err != nil {
  246. return fmt.Errorf("failed to open file after pull: %v", err)
  247. }
  248. } else {
  249. return err
  250. }
  251. } else {
  252. embed.model = modelFile
  253. // create a model from this specified file
  254. fn(api.ProgressResponse{Status: "creating model layer"})
  255. file, err := os.Open(modelFile)
  256. if err != nil {
  257. return fmt.Errorf("failed to open file: %v", err)
  258. }
  259. defer file.Close()
  260. ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
  261. if err != nil {
  262. return err
  263. }
  264. config.ModelFamily = ggml.ModelFamily()
  265. config.ModelType = ggml.ModelType().String()
  266. config.FileType = ggml.FileType().String()
  267. // reset the file
  268. file.Seek(0, io.SeekStart)
  269. l, err := CreateLayer(file)
  270. if err != nil {
  271. return fmt.Errorf("failed to create layer: %v", err)
  272. }
  273. l.MediaType = "application/vnd.ollama.image.model"
  274. layers = append(layers, l)
  275. }
  276. }
  277. if mf != nil {
  278. sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
  279. if err != nil {
  280. return err
  281. }
  282. sourceBlob, err := os.Open(sourceBlobPath)
  283. if err != nil {
  284. return err
  285. }
  286. defer sourceBlob.Close()
  287. var source ConfigV2
  288. if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
  289. return err
  290. }
  291. // copie the model metadata
  292. config.ModelFamily = source.ModelFamily
  293. config.ModelType = source.ModelType
  294. config.FileType = source.FileType
  295. for _, l := range mf.Layers {
  296. newLayer, err := GetLayerWithBufferFromLayer(l)
  297. if err != nil {
  298. return err
  299. }
  300. newLayer.From = mp.GetNamespaceRepository()
  301. layers = append(layers, newLayer)
  302. }
  303. }
  304. case "embed":
  305. embedFilePath, err := filenameWithPath(path, c.Args)
  306. if err != nil {
  307. return err
  308. }
  309. embed.files = append(embed.files, embedFilePath)
  310. case "adapter":
  311. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  312. fp := c.Args
  313. if strings.HasPrefix(fp, "~/") {
  314. parts := strings.Split(fp, "/")
  315. home, err := os.UserHomeDir()
  316. if err != nil {
  317. return fmt.Errorf("failed to open file: %v", err)
  318. }
  319. fp = filepath.Join(home, filepath.Join(parts[1:]...))
  320. }
  321. // If filePath is not an absolute path, make it relative to the modelfile path
  322. if !filepath.IsAbs(fp) {
  323. fp = filepath.Join(filepath.Dir(path), fp)
  324. }
  325. // create a model from this specified file
  326. fn(api.ProgressResponse{Status: "creating model layer"})
  327. file, err := os.Open(fp)
  328. if err != nil {
  329. return fmt.Errorf("failed to open file: %v", err)
  330. }
  331. defer file.Close()
  332. l, err := CreateLayer(file)
  333. if err != nil {
  334. return fmt.Errorf("failed to create layer: %v", err)
  335. }
  336. l.MediaType = "application/vnd.ollama.image.adapter"
  337. layers = append(layers, l)
  338. case "license":
  339. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  340. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  341. layer, err := CreateLayer(strings.NewReader(c.Args))
  342. if err != nil {
  343. return err
  344. }
  345. layer.MediaType = mediaType
  346. layers = append(layers, layer)
  347. case "template", "system", "prompt":
  348. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  349. // remove the layer if one exists
  350. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  351. layers = removeLayerFromLayers(layers, mediaType)
  352. layer, err := CreateLayer(strings.NewReader(c.Args))
  353. if err != nil {
  354. return err
  355. }
  356. layer.MediaType = mediaType
  357. layers = append(layers, layer)
  358. default:
  359. // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
  360. params[c.Name] = append(params[c.Name], c.Args)
  361. }
  362. }
  363. // Create a single layer for the parameters
  364. if len(params) > 0 {
  365. fn(api.ProgressResponse{Status: "creating parameter layer"})
  366. layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
  367. formattedParams, err := formatParams(params)
  368. if err != nil {
  369. return fmt.Errorf("couldn't create params json: %v", err)
  370. }
  371. bts, err := json.Marshal(formattedParams)
  372. if err != nil {
  373. return err
  374. }
  375. l, err := CreateLayer(bytes.NewReader(bts))
  376. if err != nil {
  377. return fmt.Errorf("failed to create layer: %v", err)
  378. }
  379. l.MediaType = "application/vnd.ollama.image.params"
  380. layers = append(layers, l)
  381. // apply these parameters to the embedding options, in case embeddings need to be generated using this model
  382. embed.opts = formattedParams
  383. }
  384. // generate the embedding layers
  385. embeddingLayers, err := embeddingLayers(embed)
  386. if err != nil {
  387. return err
  388. }
  389. layers = append(layers, embeddingLayers...)
  390. digests, err := getLayerDigests(layers)
  391. if err != nil {
  392. return err
  393. }
  394. var manifestLayers []*Layer
  395. for _, l := range layers {
  396. manifestLayers = append(manifestLayers, &l.Layer)
  397. }
  398. // Create a layer for the config object
  399. fn(api.ProgressResponse{Status: "creating config layer"})
  400. cfg, err := createConfigLayer(config, digests)
  401. if err != nil {
  402. return err
  403. }
  404. layers = append(layers, cfg)
  405. if err := SaveLayers(layers, fn, false); err != nil {
  406. return err
  407. }
  408. // Create the manifest
  409. fn(api.ProgressResponse{Status: "writing manifest"})
  410. err = CreateManifest(name, cfg, manifestLayers)
  411. if err != nil {
  412. return err
  413. }
  414. fn(api.ProgressResponse{Status: "success"})
  415. return nil
  416. }
// EmbeddingParams collects what embeddingLayers needs in order to embed
// files while a model is being created.
type EmbeddingParams struct {
	// model is the model name, or the path of a model file, used to compute
	// the embeddings.
	model string
	// opts are the formatted runtime parameters passed on when loading the
	// model.
	opts map[string]interface{}
	// files lists the files to embed; each entry is expanded with
	// filepath.Glob.
	files []string // paths to files to embed
	// fn receives progress updates.
	fn func(resp api.ProgressResponse)
}
  423. // embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
  424. func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
  425. layers := []*LayerReader{}
  426. if len(e.files) > 0 {
  427. // check if the model is a file path or a model name
  428. model, err := GetModel(e.model)
  429. if err != nil {
  430. if !strings.Contains(err.Error(), "couldn't open file") {
  431. return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
  432. }
  433. // the model may be a file path, create a model from this file
  434. model = &Model{ModelPath: e.model}
  435. }
  436. if err := load(model, e.opts, defaultSessionDuration); err != nil {
  437. return nil, fmt.Errorf("load model to generate embeddings: %v", err)
  438. }
  439. // this will be used to check if we already have embeddings for a file
  440. modelInfo, err := os.Stat(model.ModelPath)
  441. if err != nil {
  442. return nil, fmt.Errorf("failed to get model file info: %v", err)
  443. }
  444. addedFiles := make(map[string]bool) // keep track of files that have already been added
  445. for _, filePattern := range e.files {
  446. matchingFiles, err := filepath.Glob(filePattern)
  447. if err != nil {
  448. return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
  449. }
  450. for _, filePath := range matchingFiles {
  451. if addedFiles[filePath] {
  452. continue
  453. }
  454. addedFiles[filePath] = true
  455. // check if we already have embeddings for this file path
  456. layerIdentifier := fmt.Sprintf("%s:%s:%s:%d", filePath, e.model, modelInfo.ModTime().Format("2006-01-02 15:04:05"), modelInfo.Size())
  457. digest, _ := GetSHA256Digest(strings.NewReader(layerIdentifier))
  458. existing, err := existingFileEmbeddings(digest)
  459. if err != nil {
  460. return nil, fmt.Errorf("failed to check existing embeddings for file %s: %v", filePath, err)
  461. }
  462. // TODO: check file type
  463. f, err := os.Open(filePath)
  464. if err != nil {
  465. return nil, fmt.Errorf("could not open embed file: %w", err)
  466. }
  467. scanner := bufio.NewScanner(f)
  468. scanner.Split(bufio.ScanLines)
  469. data := []string{}
  470. for scanner.Scan() {
  471. data = append(data, scanner.Text())
  472. }
  473. f.Close()
  474. // the digest of the file is set here so that the client knows a new operation is in progress
  475. fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
  476. embeddings := []vector.Embedding{}
  477. for i, d := range data {
  478. if strings.TrimSpace(d) == "" {
  479. continue
  480. }
  481. e.fn(api.ProgressResponse{
  482. Status: fmt.Sprintf("creating embeddings for file %s", filePath),
  483. Digest: fileDigest,
  484. Total: len(data) - 1,
  485. Completed: i,
  486. })
  487. if len(existing[d]) > 0 {
  488. // already have an embedding for this line
  489. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
  490. continue
  491. }
  492. embed, err := loaded.llm.Embedding(d)
  493. if err != nil {
  494. log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
  495. continue
  496. }
  497. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
  498. }
  499. b, err := json.Marshal(embeddings)
  500. if err != nil {
  501. return nil, fmt.Errorf("failed to encode embeddings: %w", err)
  502. }
  503. r := bytes.NewReader(b)
  504. layer := &LayerReader{
  505. Layer: Layer{
  506. MediaType: "application/vnd.ollama.image.embed",
  507. Digest: digest,
  508. Size: r.Len(),
  509. },
  510. Reader: r,
  511. }
  512. layers = append(layers, layer)
  513. }
  514. }
  515. }
  516. return layers, nil
  517. }
  518. // existingFileEmbeddings checks if we already have embeddings for a file and loads them into a look-up map
  519. func existingFileEmbeddings(digest string) (map[string][]float64, error) {
  520. path, err := GetBlobsPath(digest)
  521. if err != nil {
  522. return nil, fmt.Errorf("embeddings blobs path: %w", err)
  523. }
  524. existingFileEmbeddings := make(map[string][]float64)
  525. if _, err := os.Stat(path); err == nil {
  526. // already have some embeddings for this file, load embeddings previously generated
  527. file, err := os.Open(path)
  528. if err != nil {
  529. return nil, fmt.Errorf("failed to open existing embedding file: %s", err)
  530. }
  531. defer file.Close()
  532. existing := []vector.Embedding{}
  533. if err = json.NewDecoder(file).Decode(&existing); err != nil {
  534. return nil, err
  535. }
  536. for _, e := range existing {
  537. existingFileEmbeddings[e.Data] = e.Vector
  538. }
  539. }
  540. return existingFileEmbeddings, nil
  541. }
  542. func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
  543. j := 0
  544. for _, l := range layers {
  545. if l.MediaType != mediaType {
  546. layers[j] = l
  547. j++
  548. }
  549. }
  550. return layers[:j]
  551. }
  552. func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
  553. // Write each of the layers to disk
  554. for _, layer := range layers {
  555. fp, err := GetBlobsPath(layer.Digest)
  556. if err != nil {
  557. return err
  558. }
  559. _, err = os.Stat(fp)
  560. // note: embed layers are always written since their digest doesnt indicate anything about the contents
  561. if os.IsNotExist(err) || force || layer.MediaType == "application/vnd.ollama.image.embed" {
  562. fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
  563. out, err := os.Create(fp)
  564. if err != nil {
  565. log.Printf("couldn't create %s", fp)
  566. return err
  567. }
  568. defer out.Close()
  569. if _, err = io.Copy(out, layer.Reader); err != nil {
  570. return err
  571. }
  572. } else {
  573. fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
  574. }
  575. }
  576. return nil
  577. }
  578. func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
  579. mp := ParseModelPath(name)
  580. manifest := ManifestV2{
  581. SchemaVersion: 2,
  582. MediaType: "application/vnd.docker.distribution.manifest.v2+json",
  583. Config: Layer{
  584. MediaType: cfg.MediaType,
  585. Size: cfg.Size,
  586. Digest: cfg.Digest,
  587. },
  588. Layers: layers,
  589. }
  590. manifestJSON, err := json.Marshal(manifest)
  591. if err != nil {
  592. return err
  593. }
  594. fp, err := mp.GetManifestPath(true)
  595. if err != nil {
  596. return err
  597. }
  598. return os.WriteFile(fp, manifestJSON, 0o644)
  599. }
  600. func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
  601. fp, err := GetBlobsPath(layer.Digest)
  602. if err != nil {
  603. return nil, err
  604. }
  605. file, err := os.Open(fp)
  606. if err != nil {
  607. return nil, fmt.Errorf("could not open blob: %w", err)
  608. }
  609. defer file.Close()
  610. newLayer, err := CreateLayer(file)
  611. if err != nil {
  612. return nil, err
  613. }
  614. newLayer.MediaType = layer.MediaType
  615. return newLayer, nil
  616. }
  617. // formatParams converts specified parameter options to their correct types
  618. func formatParams(params map[string][]string) (map[string]interface{}, error) {
  619. opts := api.Options{}
  620. valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
  621. typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
  622. // build map of json struct tags to their types
  623. jsonOpts := make(map[string]reflect.StructField)
  624. for _, field := range reflect.VisibleFields(typeOpts) {
  625. jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
  626. if jsonTag != "" {
  627. jsonOpts[jsonTag] = field
  628. }
  629. }
  630. out := make(map[string]interface{})
  631. // iterate params and set values based on json struct tags
  632. for key, vals := range params {
  633. if opt, ok := jsonOpts[key]; ok {
  634. field := valueOpts.FieldByName(opt.Name)
  635. if field.IsValid() && field.CanSet() {
  636. switch field.Kind() {
  637. case reflect.Float32:
  638. floatVal, err := strconv.ParseFloat(vals[0], 32)
  639. if err != nil {
  640. return nil, fmt.Errorf("invalid float value %s", vals)
  641. }
  642. out[key] = floatVal
  643. case reflect.Int:
  644. intVal, err := strconv.ParseInt(vals[0], 10, 0)
  645. if err != nil {
  646. return nil, fmt.Errorf("invalid int value %s", vals)
  647. }
  648. out[key] = intVal
  649. case reflect.Bool:
  650. boolVal, err := strconv.ParseBool(vals[0])
  651. if err != nil {
  652. return nil, fmt.Errorf("invalid bool value %s", vals)
  653. }
  654. out[key] = boolVal
  655. case reflect.String:
  656. out[key] = vals[0]
  657. case reflect.Slice:
  658. // TODO: only string slices are supported right now
  659. out[key] = vals
  660. default:
  661. return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
  662. }
  663. }
  664. }
  665. }
  666. return out, nil
  667. }
  668. func getLayerDigests(layers []*LayerReader) ([]string, error) {
  669. var digests []string
  670. for _, l := range layers {
  671. if l.Digest == "" {
  672. return nil, fmt.Errorf("layer is missing a digest")
  673. }
  674. digests = append(digests, l.Digest)
  675. }
  676. return digests, nil
  677. }
  678. // CreateLayer creates a Layer object from a given file
  679. func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
  680. digest, size := GetSHA256Digest(f)
  681. f.Seek(0, io.SeekStart)
  682. layer := &LayerReader{
  683. Layer: Layer{
  684. MediaType: "application/vnd.docker.image.rootfs.diff.tar",
  685. Digest: digest,
  686. Size: size,
  687. },
  688. Reader: f,
  689. }
  690. return layer, nil
  691. }
  692. func CopyModel(src, dest string) error {
  693. srcModelPath := ParseModelPath(src)
  694. srcPath, err := srcModelPath.GetManifestPath(false)
  695. if err != nil {
  696. return err
  697. }
  698. destModelPath := ParseModelPath(dest)
  699. destPath, err := destModelPath.GetManifestPath(true)
  700. if err != nil {
  701. return err
  702. }
  703. // copy the file
  704. input, err := os.ReadFile(srcPath)
  705. if err != nil {
  706. fmt.Println("Error reading file:", err)
  707. return err
  708. }
  709. err = os.WriteFile(destPath, input, 0o644)
  710. if err != nil {
  711. fmt.Println("Error reading file:", err)
  712. return err
  713. }
  714. return nil
  715. }
  716. func DeleteModel(name string) error {
  717. mp := ParseModelPath(name)
  718. manifest, err := GetManifest(mp)
  719. if err != nil {
  720. return err
  721. }
  722. deleteMap := make(map[string]bool)
  723. for _, layer := range manifest.Layers {
  724. deleteMap[layer.Digest] = true
  725. }
  726. deleteMap[manifest.Config.Digest] = true
  727. fp, err := GetManifestPath()
  728. if err != nil {
  729. return err
  730. }
  731. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  732. if err != nil {
  733. return err
  734. }
  735. if !info.IsDir() {
  736. path := path[len(fp)+1:]
  737. slashIndex := strings.LastIndex(path, "/")
  738. if slashIndex == -1 {
  739. return nil
  740. }
  741. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  742. fmp := ParseModelPath(tag)
  743. // skip the manifest we're trying to delete
  744. if mp.GetFullTagname() == fmp.GetFullTagname() {
  745. return nil
  746. }
  747. // save (i.e. delete from the deleteMap) any files used in other manifests
  748. manifest, err := GetManifest(fmp)
  749. if err != nil {
  750. log.Printf("skipping file: %s", fp)
  751. return nil
  752. }
  753. for _, layer := range manifest.Layers {
  754. delete(deleteMap, layer.Digest)
  755. }
  756. delete(deleteMap, manifest.Config.Digest)
  757. }
  758. return nil
  759. })
  760. if err != nil {
  761. return err
  762. }
  763. // only delete the files which are still in the deleteMap
  764. for k, v := range deleteMap {
  765. if v {
  766. fp, err := GetBlobsPath(k)
  767. if err != nil {
  768. log.Printf("couldn't get file path for '%s': %v", k, err)
  769. continue
  770. }
  771. if err := os.Remove(fp); err != nil {
  772. log.Printf("couldn't remove file '%s': %v", fp, err)
  773. continue
  774. }
  775. }
  776. }
  777. fp, err = mp.GetManifestPath(false)
  778. if err != nil {
  779. return err
  780. }
  781. err = os.Remove(fp)
  782. if err != nil {
  783. log.Printf("couldn't remove manifest file '%s': %v", fp, err)
  784. return err
  785. }
  786. return nil
  787. }
  788. func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  789. mp := ParseModelPath(name)
  790. fn(api.ProgressResponse{Status: "retrieving manifest"})
  791. if mp.ProtocolScheme == "http" && !regOpts.Insecure {
  792. return fmt.Errorf("insecure protocol http")
  793. }
  794. manifest, err := GetManifest(mp)
  795. if err != nil {
  796. fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
  797. return err
  798. }
  799. var layers []*Layer
  800. layers = append(layers, manifest.Layers...)
  801. layers = append(layers, &manifest.Config)
  802. for _, layer := range layers {
  803. exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
  804. if err != nil {
  805. return err
  806. }
  807. if exists {
  808. fn(api.ProgressResponse{
  809. Status: "using existing layer",
  810. Digest: layer.Digest,
  811. Total: layer.Size,
  812. Completed: layer.Size,
  813. })
  814. log.Printf("Layer %s already exists", layer.Digest)
  815. continue
  816. }
  817. fn(api.ProgressResponse{
  818. Status: "starting upload",
  819. Digest: layer.Digest,
  820. Total: layer.Size,
  821. })
  822. location, err := startUpload(ctx, mp, layer, regOpts)
  823. if err != nil {
  824. log.Printf("couldn't start upload: %v", err)
  825. return err
  826. }
  827. if strings.HasPrefix(path.Base(location), "sha256:") {
  828. layer.Digest = path.Base(location)
  829. fn(api.ProgressResponse{
  830. Status: "using existing layer",
  831. Digest: layer.Digest,
  832. Total: layer.Size,
  833. Completed: layer.Size,
  834. })
  835. continue
  836. }
  837. if err := uploadBlobChunked(ctx, mp, location, layer, regOpts, fn); err != nil {
  838. log.Printf("error uploading blob: %v", err)
  839. return err
  840. }
  841. }
  842. fn(api.ProgressResponse{Status: "pushing manifest"})
  843. url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
  844. headers := map[string]string{
  845. "Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
  846. }
  847. manifestJSON, err := json.Marshal(manifest)
  848. if err != nil {
  849. return err
  850. }
  851. resp, err := makeRequestWithRetry(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
  852. if err != nil {
  853. return err
  854. }
  855. defer resp.Body.Close()
  856. fn(api.ProgressResponse{Status: "success"})
  857. return nil
  858. }
  859. func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  860. mp := ParseModelPath(name)
  861. if mp.ProtocolScheme == "http" && !regOpts.Insecure {
  862. return fmt.Errorf("insecure protocol http")
  863. }
  864. fn(api.ProgressResponse{Status: "pulling manifest"})
  865. manifest, err := pullModelManifest(ctx, mp, regOpts)
  866. if err != nil {
  867. return fmt.Errorf("pull model manifest: %s", err)
  868. }
  869. var layers []*Layer
  870. layers = append(layers, manifest.Layers...)
  871. layers = append(layers, &manifest.Config)
  872. for _, layer := range layers {
  873. if err := downloadBlob(
  874. ctx,
  875. downloadOpts{
  876. mp: mp,
  877. digest: layer.Digest,
  878. regOpts: regOpts,
  879. fn: fn,
  880. }); err != nil {
  881. return err
  882. }
  883. }
  884. fn(api.ProgressResponse{Status: "verifying sha256 digest"})
  885. for _, layer := range layers {
  886. if err := verifyBlob(layer.Digest); err != nil {
  887. if errors.Is(err, errDigestMismatch) {
  888. // something went wrong, delete the blob
  889. fp, err := GetBlobsPath(layer.Digest)
  890. if err != nil {
  891. return err
  892. }
  893. if err := os.Remove(fp); err != nil {
  894. // log this, but return the original error
  895. log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
  896. }
  897. }
  898. return err
  899. }
  900. }
  901. fn(api.ProgressResponse{Status: "writing manifest"})
  902. manifestJSON, err := json.Marshal(manifest)
  903. if err != nil {
  904. return err
  905. }
  906. fp, err := mp.GetManifestPath(true)
  907. if err != nil {
  908. return err
  909. }
  910. err = os.WriteFile(fp, manifestJSON, 0o644)
  911. if err != nil {
  912. log.Printf("couldn't write to %s", fp)
  913. return err
  914. }
  915. fn(api.ProgressResponse{Status: "success"})
  916. return nil
  917. }
  918. func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
  919. url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
  920. headers := map[string]string{
  921. "Accept": "application/vnd.docker.distribution.manifest.v2+json",
  922. }
  923. resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
  924. if err != nil {
  925. log.Printf("couldn't get manifest: %v", err)
  926. return nil, err
  927. }
  928. defer resp.Body.Close()
  929. // Check for success: For a successful upload, the Docker registry will respond with a 201 Created
  930. if resp.StatusCode != http.StatusOK {
  931. if resp.StatusCode == http.StatusNotFound {
  932. return nil, fmt.Errorf("model not found")
  933. }
  934. body, _ := io.ReadAll(resp.Body)
  935. return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
  936. }
  937. var m *ManifestV2
  938. if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
  939. return nil, err
  940. }
  941. return m, err
  942. }
  943. func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
  944. config.RootFS = RootFS{
  945. Type: "layers",
  946. DiffIDs: layers,
  947. }
  948. configJSON, err := json.Marshal(config)
  949. if err != nil {
  950. return nil, err
  951. }
  952. digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
  953. layer := &LayerReader{
  954. Layer: Layer{
  955. MediaType: "application/vnd.docker.container.image.v1+json",
  956. Digest: digest,
  957. Size: size,
  958. },
  959. Reader: bytes.NewBuffer(configJSON),
  960. }
  961. return layer, nil
  962. }
  // GetSHA256Digest reads r to the end and returns the SHA-256 digest of its
  // contents, formatted as "sha256:<lowercase hex>", together with the number of
  // bytes read. A read error is passed to log.Fatal, which ends the process.
  func GetSHA256Digest(r io.Reader) (string, int) {
  	hasher := sha256.New()

  	written, copyErr := io.Copy(hasher, r)
  	if copyErr != nil {
  		log.Fatal(copyErr)
  	}

  	sum := hasher.Sum(nil)
  	return fmt.Sprintf("sha256:%x", sum), int(written)
  }
  // requestContextKey is a named string type.
  // NOTE(review): the name suggests it serves as the key type for values stored
  // in a context.Context, but no use of it is visible in this part of the file —
  // confirm against its callers.
  972. type requestContextKey string
  973. func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
  974. url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
  975. if layer.From != "" {
  976. url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
  977. }
  978. resp, err := makeRequestWithRetry(ctx, "POST", url, nil, nil, regOpts)
  979. if err != nil {
  980. log.Printf("couldn't start upload: %v", err)
  981. return "", err
  982. }
  983. defer resp.Body.Close()
  984. // Extract UUID location from header
  985. location := resp.Header.Get("Location")
  986. if location == "" {
  987. return "", fmt.Errorf("location header is missing in response")
  988. }
  989. return location, nil
  990. }
  991. // Function to check if a blob already exists in the Docker registry
  992. func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
  993. url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
  994. resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
  995. if err != nil {
  996. log.Printf("couldn't check for blob: %v", err)
  997. return false, err
  998. }
  999. defer resp.Body.Close()
  1000. // Check for success: If the blob exists, the Docker registry will respond with a 200 OK
  1001. return resp.StatusCode == http.StatusOK, nil
  1002. }
  1003. func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  1004. // TODO allow resumability
  1005. // TODO allow canceling uploads via DELETE
  1006. fp, err := GetBlobsPath(layer.Digest)
  1007. if err != nil {
  1008. return err
  1009. }
  1010. f, err := os.Open(fp)
  1011. if err != nil {
  1012. return err
  1013. }
  1014. defer f.Close()
  1015. var completed int64
  1016. chunkSize := 10 * 1024 * 1024
  1017. for {
  1018. chunk := int64(layer.Size) - completed
  1019. if chunk > int64(chunkSize) {
  1020. chunk = int64(chunkSize)
  1021. }
  1022. sectionReader := io.NewSectionReader(f, int64(completed), chunk)
  1023. headers := make(map[string]string)
  1024. headers["Content-Type"] = "application/octet-stream"
  1025. headers["Content-Length"] = strconv.Itoa(int(chunk))
  1026. headers["Content-Range"] = fmt.Sprintf("%d-%d", completed, completed+sectionReader.Size()-1)
  1027. resp, err := makeRequestWithRetry(ctx, "PATCH", url, headers, sectionReader, regOpts)
  1028. if err != nil && !errors.Is(err, io.EOF) {
  1029. fn(api.ProgressResponse{
  1030. Status: fmt.Sprintf("error uploading chunk: %v", err),
  1031. Digest: layer.Digest,
  1032. Total: layer.Size,
  1033. Completed: int(completed),
  1034. })
  1035. return err
  1036. }
  1037. defer resp.Body.Close()
  1038. completed += sectionReader.Size()
  1039. fn(api.ProgressResponse{
  1040. Status: fmt.Sprintf("uploading %s", layer.Digest),
  1041. Digest: layer.Digest,
  1042. Total: layer.Size,
  1043. Completed: int(completed),
  1044. })
  1045. url = resp.Header.Get("Location")
  1046. if completed >= int64(layer.Size) {
  1047. break
  1048. }
  1049. }
  1050. url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)
  1051. headers := make(map[string]string)
  1052. headers["Content-Type"] = "application/octet-stream"
  1053. headers["Content-Length"] = "0"
  1054. // finish the upload
  1055. resp, err := makeRequest(ctx, "PUT", url, headers, nil, regOpts)
  1056. if err != nil {
  1057. log.Printf("couldn't finish upload: %v", err)
  1058. return err
  1059. }
  1060. defer resp.Body.Close()
  1061. if resp.StatusCode != http.StatusCreated {
  1062. body, _ := io.ReadAll(resp.Body)
  1063. return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
  1064. }
  1065. return nil
  1066. }
  1067. func makeRequestWithRetry(ctx context.Context, method, url string, headers map[string]string, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
  1068. var status string
  1069. for try := 0; try < MaxRetries; try++ {
  1070. resp, err := makeRequest(ctx, method, url, headers, body, regOpts)
  1071. if err != nil {
  1072. log.Printf("couldn't start upload: %v", err)
  1073. return nil, err
  1074. }
  1075. status = resp.Status
  1076. switch resp.StatusCode {
  1077. case http.StatusAccepted, http.StatusCreated:
  1078. return resp, nil
  1079. case http.StatusUnauthorized:
  1080. auth := resp.Header.Get("www-authenticate")
  1081. authRedir := ParseAuthRedirectString(auth)
  1082. token, err := getAuthToken(ctx, authRedir, regOpts)
  1083. if err != nil {
  1084. return nil, err
  1085. }
  1086. regOpts.Token = token
  1087. if body != nil {
  1088. if _, err := body.Seek(0, io.SeekStart); err != nil {
  1089. return nil, err
  1090. }
  1091. }
  1092. continue
  1093. default:
  1094. body, _ := io.ReadAll(resp.Body)
  1095. return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
  1096. }
  1097. }
  1098. return nil, fmt.Errorf("max retry exceeded: %v", status)
  1099. }
  1100. func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
  1101. if !strings.HasPrefix(url, "http") {
  1102. if regOpts.Insecure {
  1103. url = "http://" + url
  1104. } else {
  1105. url = "https://" + url
  1106. }
  1107. }
  1108. req, err := http.NewRequestWithContext(ctx, method, url, body)
  1109. if err != nil {
  1110. return nil, err
  1111. }
  1112. if regOpts.Token != "" {
  1113. req.Header.Set("Authorization", "Bearer "+regOpts.Token)
  1114. } else if regOpts.Username != "" && regOpts.Password != "" {
  1115. req.SetBasicAuth(regOpts.Username, regOpts.Password)
  1116. }
  1117. for k, v := range headers {
  1118. req.Header.Set(k, v)
  1119. }
  1120. client := &http.Client{
  1121. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  1122. if len(via) >= 10 {
  1123. return fmt.Errorf("too many redirects")
  1124. }
  1125. log.Printf("redirected to: %s\n", req.URL)
  1126. return nil
  1127. },
  1128. }
  1129. resp, err := client.Do(req)
  1130. if err != nil {
  1131. return nil, err
  1132. }
  1133. return resp, nil
  1134. }
  // getValue returns the value of the first "key=" parameter found in header, or
  // "" when header has no such parameter or ends right after the "=".
  //
  // A value that starts with a double quote runs up to the first double quote
  // that is either the last character of header or directly followed by a comma;
  // the surrounding quotes are not included, and quotes followed by anything else
  // are kept as part of the value. If no closing quote is found, the value runs
  // to the end of header. A value that does not start with a quote runs up to the
  // next comma or the end of header.
  func getValue(header, key string) string {
  	prefix := key + "="
  	idx := strings.Index(header, prefix)
  	if idx == -1 {
  		return ""
  	}

  	start := idx + len(prefix)
  	if start >= len(header) {
  		// header ends with "key=": slicing past the end would panic.
  		return ""
  	}

  	if header[start] != '"' {
  		// Unquoted token: do not skip its first character.
  		rest := header[start:]
  		if end := strings.IndexByte(rest, ','); end != -1 {
  			return rest[:end]
  		}
  		return rest
  	}

  	// Step over the opening quote.
  	start++

  	end := start
  	for end < len(header) {
  		if header[end] == '"' && (end+1 >= len(header) || header[end+1] == ',') {
  			break
  		}
  		end++
  	}

  	return header[start:end]
  }
  1155. func ParseAuthRedirectString(authStr string) AuthRedirect {
  1156. authStr = strings.TrimPrefix(authStr, "Bearer ")
  1157. return AuthRedirect{
  1158. Realm: getValue(authStr, "realm"),
  1159. Service: getValue(authStr, "service"),
  1160. Scope: getValue(authStr, "scope"),
  1161. }
  1162. }
  1163. var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
  1164. func verifyBlob(digest string) error {
  1165. fp, err := GetBlobsPath(digest)
  1166. if err != nil {
  1167. return err
  1168. }
  1169. f, err := os.Open(fp)
  1170. if err != nil {
  1171. return err
  1172. }
  1173. defer f.Close()
  1174. fileDigest, _ := GetSHA256Digest(f)
  1175. if digest != fileDigest {
  1176. return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
  1177. }
  1178. return nil
  1179. }