images.go 35 KB


  1. package server
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/sha256"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "html/template"
  11. "io"
  12. "log"
  13. "net/http"
  14. "os"
  15. "path"
  16. "path/filepath"
  17. "reflect"
  18. "strconv"
  19. "strings"
  20. "github.com/jmorganca/ollama/api"
  21. "github.com/jmorganca/ollama/llm"
  22. "github.com/jmorganca/ollama/parser"
  23. "github.com/jmorganca/ollama/vector"
  24. )
// MaxRetries is a package-level retry limit.
// NOTE(review): not referenced in the visible part of this file — presumably
// consumed by makeRequestWithRetry; confirm.
const MaxRetries = 3
// RegistryOptions carries per-request settings for talking to a model registry
// (see PushModel and PullModel).
type RegistryOptions struct {
	// Insecure is forwarded to ParseModelPath when pushing or pulling.
	// NOTE(review): presumably selects plain HTTP / relaxed TLS — confirm in ParseModelPath.
	Insecure bool
	// Credentials; not read in the visible part of this file — presumably
	// consumed by the request helpers.
	Username string
	Password string
	Token    string
}
// Model is a model assembled from its manifest layers (see GetModel).
type Model struct {
	// Name is the full tag name of the model (ModelPath.GetFullTagname).
	Name string `json:"name"`
	// ModelPath is the path to the model weights: a blob path when loaded by
	// GetModel, or a plain file path when built directly from one.
	ModelPath string
	// AdapterPaths are the blob paths of adapter layers, in manifest order.
	AdapterPaths []string
	// Template is the prompt template text (from a "template" or "prompt" layer).
	Template string
	// System is the default system prompt.
	System string
	// Digest is the digest of the manifest's config blob.
	Digest string
	// Options is the decoded params layer, keyed by parameter name, so callers
	// can tell which options were specified explicitly.
	Options map[string]interface{}
	// Embeddings is the decoded content of an embed layer.
	Embeddings []vector.Embedding
}
  42. func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
  43. t := m.Template
  44. if request.Template != "" {
  45. t = request.Template
  46. }
  47. tmpl, err := template.New("").Parse(t)
  48. if err != nil {
  49. return "", err
  50. }
  51. var vars struct {
  52. First bool
  53. System string
  54. Prompt string
  55. Embed string
  56. // deprecated: versions <= 0.0.7 used this to omit the system prompt
  57. Context []int
  58. }
  59. vars.First = len(request.Context) == 0
  60. vars.System = m.System
  61. vars.Prompt = request.Prompt
  62. vars.Context = request.Context
  63. vars.Embed = embedding
  64. if request.System != "" {
  65. vars.System = request.System
  66. }
  67. var sb strings.Builder
  68. if err := tmpl.Execute(&sb, vars); err != nil {
  69. return "", err
  70. }
  71. return sb.String(), nil
  72. }
// ManifestV2 is an image manifest in the Docker distribution schema-2 format
// (CreateManifest writes SchemaVersion 2 with the
// "application/vnd.docker.distribution.manifest.v2+json" media type): one
// config blob plus an ordered list of layer blobs, each referenced by digest.
type ManifestV2 struct {
	SchemaVersion int      `json:"schemaVersion"`
	MediaType     string   `json:"mediaType"`
	Config        Layer    `json:"config"`
	Layers        []*Layer `json:"layers"`
}
// Layer describes one blob referenced by a manifest.
type Layer struct {
	MediaType string `json:"mediaType"`
	// Digest identifies the blob ("sha256:<hex>" for blobs hashed with
	// GetSHA256Digest); GetBlobsPath maps it to the local blob file.
	Digest    string `json:"digest"`
	// Size is the blob length in bytes.
	Size      int    `json:"size"`
	// From, when set, is the repository this layer was copied from;
	// startUpload uses it to ask the registry to mount the blob from there.
	From      string `json:"from,omitempty"`
}
// LayerReader pairs a layer's metadata with a reader for its content, which
// SaveLayers copies into the blob store. Embedding io.Reader lets a
// LayerReader be used directly wherever an io.Reader is expected.
type LayerReader struct {
	Layer
	io.Reader
}
// ConfigV2 is the JSON config blob of a model. CreateModel fills the model
// metadata either from the GGML header of a local model file or from the base
// model's config, and createConfigLayer sets RootFS to the layer digests.
type ConfigV2 struct {
	ModelFamily llm.ModelFamily `json:"model_family"`
	ModelType   string          `json:"model_type"`
	FileType    string          `json:"file_type"`
	RootFS      RootFS          `json:"rootfs"`

	// required by spec
	Architecture string `json:"architecture"`
	OS           string `json:"os"`
}
// RootFS lists a model's layer digests inside its config blob.
// createConfigLayer sets Type to "layers" and DiffIDs to the layer digests in
// manifest order.
type RootFS struct {
	Type    string   `json:"type"`
	DiffIDs []string `json:"diff_ids"`
}
  102. func (m *ManifestV2) GetTotalSize() int {
  103. var total int
  104. for _, layer := range m.Layers {
  105. total += layer.Size
  106. }
  107. total += m.Config.Size
  108. return total
  109. }
  110. func GetManifest(mp ModelPath) (*ManifestV2, error) {
  111. fp, err := mp.GetManifestPath(false)
  112. if err != nil {
  113. return nil, err
  114. }
  115. if _, err = os.Stat(fp); err != nil {
  116. return nil, err
  117. }
  118. var manifest *ManifestV2
  119. bts, err := os.ReadFile(fp)
  120. if err != nil {
  121. return nil, fmt.Errorf("couldn't open file '%s'", fp)
  122. }
  123. if err := json.Unmarshal(bts, &manifest); err != nil {
  124. return nil, err
  125. }
  126. return manifest, nil
  127. }
  128. func GetModel(name string) (*Model, error) {
  129. mp, err := ParseModelPath(name, false)
  130. if err != nil {
  131. return nil, err
  132. }
  133. manifest, err := GetManifest(mp)
  134. if err != nil {
  135. return nil, err
  136. }
  137. model := &Model{
  138. Name: mp.GetFullTagname(),
  139. Digest: manifest.Config.Digest,
  140. }
  141. for _, layer := range manifest.Layers {
  142. filename, err := GetBlobsPath(layer.Digest)
  143. if err != nil {
  144. return nil, err
  145. }
  146. switch layer.MediaType {
  147. case "application/vnd.ollama.image.model":
  148. model.ModelPath = filename
  149. case "application/vnd.ollama.image.embed":
  150. file, err := os.Open(filename)
  151. if err != nil {
  152. return nil, fmt.Errorf("failed to open file: %s", filename)
  153. }
  154. defer file.Close()
  155. if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
  156. return nil, err
  157. }
  158. case "application/vnd.ollama.image.adapter":
  159. model.AdapterPaths = append(model.AdapterPaths, filename)
  160. case "application/vnd.ollama.image.template":
  161. bts, err := os.ReadFile(filename)
  162. if err != nil {
  163. return nil, err
  164. }
  165. model.Template = string(bts)
  166. case "application/vnd.ollama.image.system":
  167. bts, err := os.ReadFile(filename)
  168. if err != nil {
  169. return nil, err
  170. }
  171. model.System = string(bts)
  172. case "application/vnd.ollama.image.prompt":
  173. bts, err := os.ReadFile(filename)
  174. if err != nil {
  175. return nil, err
  176. }
  177. model.Template = string(bts)
  178. case "application/vnd.ollama.image.params":
  179. params, err := os.Open(filename)
  180. if err != nil {
  181. return nil, err
  182. }
  183. defer params.Close()
  184. // parse model options parameters into a map so that we can see which fields have been specified explicitly
  185. if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
  186. return nil, err
  187. }
  188. }
  189. }
  190. return model, nil
  191. }
  192. func filenameWithPath(path, f string) (string, error) {
  193. // if filePath starts with ~/, replace it with the user's home directory.
  194. if strings.HasPrefix(f, "~/") {
  195. parts := strings.Split(f, "/")
  196. home, err := os.UserHomeDir()
  197. if err != nil {
  198. return "", fmt.Errorf("failed to open file: %v", err)
  199. }
  200. f = filepath.Join(home, filepath.Join(parts[1:]...))
  201. }
  202. // if filePath is not an absolute path, make it relative to the modelfile path
  203. if !filepath.IsAbs(f) {
  204. f = filepath.Join(filepath.Dir(path), f)
  205. }
  206. return f, nil
  207. }
  208. func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
  209. mf, err := os.Open(path)
  210. if err != nil {
  211. fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
  212. return fmt.Errorf("failed to open file: %w", err)
  213. }
  214. defer mf.Close()
  215. fn(api.ProgressResponse{Status: "parsing modelfile"})
  216. commands, err := parser.Parse(mf)
  217. if err != nil {
  218. return err
  219. }
  220. config := ConfigV2{
  221. Architecture: "amd64",
  222. OS: "linux",
  223. }
  224. var layers []*LayerReader
  225. params := make(map[string][]string)
  226. embed := EmbeddingParams{fn: fn}
  227. for _, c := range commands {
  228. log.Printf("[%s] - %s\n", c.Name, c.Args)
  229. switch c.Name {
  230. case "model":
  231. fn(api.ProgressResponse{Status: "looking for model"})
  232. embed.model = c.Args
  233. mp, err := ParseModelPath(c.Args, false)
  234. if err != nil {
  235. return err
  236. }
  237. mf, err := GetManifest(mp)
  238. if err != nil {
  239. modelFile, err := filenameWithPath(path, c.Args)
  240. if err != nil {
  241. return err
  242. }
  243. if _, err := os.Stat(modelFile); err != nil {
  244. // the model file does not exist, try pulling it
  245. if errors.Is(err, os.ErrNotExist) {
  246. fn(api.ProgressResponse{Status: "pulling model file"})
  247. if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
  248. return err
  249. }
  250. mf, err = GetManifest(mp)
  251. if err != nil {
  252. return fmt.Errorf("failed to open file after pull: %v", err)
  253. }
  254. } else {
  255. return err
  256. }
  257. } else {
  258. embed.model = modelFile
  259. // create a model from this specified file
  260. fn(api.ProgressResponse{Status: "creating model layer"})
  261. file, err := os.Open(modelFile)
  262. if err != nil {
  263. return fmt.Errorf("failed to open file: %v", err)
  264. }
  265. defer file.Close()
  266. ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
  267. if err != nil {
  268. return err
  269. }
  270. config.ModelFamily = ggml.ModelFamily()
  271. config.ModelType = ggml.ModelType().String()
  272. config.FileType = ggml.FileType().String()
  273. // reset the file
  274. file.Seek(0, io.SeekStart)
  275. l, err := CreateLayer(file)
  276. if err != nil {
  277. return fmt.Errorf("failed to create layer: %v", err)
  278. }
  279. l.MediaType = "application/vnd.ollama.image.model"
  280. layers = append(layers, l)
  281. }
  282. }
  283. if mf != nil {
  284. sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
  285. if err != nil {
  286. return err
  287. }
  288. sourceBlob, err := os.Open(sourceBlobPath)
  289. if err != nil {
  290. return err
  291. }
  292. defer sourceBlob.Close()
  293. var source ConfigV2
  294. if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
  295. return err
  296. }
  297. // copie the model metadata
  298. config.ModelFamily = source.ModelFamily
  299. config.ModelType = source.ModelType
  300. config.FileType = source.FileType
  301. for _, l := range mf.Layers {
  302. newLayer, err := GetLayerWithBufferFromLayer(l)
  303. if err != nil {
  304. return err
  305. }
  306. newLayer.From = mp.GetNamespaceRepository()
  307. layers = append(layers, newLayer)
  308. }
  309. }
  310. case "embed":
  311. embedFilePath, err := filenameWithPath(path, c.Args)
  312. if err != nil {
  313. return err
  314. }
  315. embed.files = append(embed.files, embedFilePath)
  316. case "adapter":
  317. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  318. fp := c.Args
  319. if strings.HasPrefix(fp, "~/") {
  320. parts := strings.Split(fp, "/")
  321. home, err := os.UserHomeDir()
  322. if err != nil {
  323. return fmt.Errorf("failed to open file: %v", err)
  324. }
  325. fp = filepath.Join(home, filepath.Join(parts[1:]...))
  326. }
  327. // If filePath is not an absolute path, make it relative to the modelfile path
  328. if !filepath.IsAbs(fp) {
  329. fp = filepath.Join(filepath.Dir(path), fp)
  330. }
  331. // create a model from this specified file
  332. fn(api.ProgressResponse{Status: "creating model layer"})
  333. file, err := os.Open(fp)
  334. if err != nil {
  335. return fmt.Errorf("failed to open file: %v", err)
  336. }
  337. defer file.Close()
  338. l, err := CreateLayer(file)
  339. if err != nil {
  340. return fmt.Errorf("failed to create layer: %v", err)
  341. }
  342. l.MediaType = "application/vnd.ollama.image.adapter"
  343. layers = append(layers, l)
  344. case "license":
  345. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  346. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  347. layer, err := CreateLayer(strings.NewReader(c.Args))
  348. if err != nil {
  349. return err
  350. }
  351. layer.MediaType = mediaType
  352. layers = append(layers, layer)
  353. case "template", "system", "prompt":
  354. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  355. // remove the layer if one exists
  356. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  357. layers = removeLayerFromLayers(layers, mediaType)
  358. layer, err := CreateLayer(strings.NewReader(c.Args))
  359. if err != nil {
  360. return err
  361. }
  362. layer.MediaType = mediaType
  363. layers = append(layers, layer)
  364. default:
  365. // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
  366. params[c.Name] = append(params[c.Name], c.Args)
  367. }
  368. }
  369. // Create a single layer for the parameters
  370. if len(params) > 0 {
  371. fn(api.ProgressResponse{Status: "creating parameter layer"})
  372. layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
  373. formattedParams, err := formatParams(params)
  374. if err != nil {
  375. return fmt.Errorf("couldn't create params json: %v", err)
  376. }
  377. bts, err := json.Marshal(formattedParams)
  378. if err != nil {
  379. return err
  380. }
  381. l, err := CreateLayer(bytes.NewReader(bts))
  382. if err != nil {
  383. return fmt.Errorf("failed to create layer: %v", err)
  384. }
  385. l.MediaType = "application/vnd.ollama.image.params"
  386. layers = append(layers, l)
  387. // apply these parameters to the embedding options, in case embeddings need to be generated using this model
  388. embed.opts = formattedParams
  389. }
  390. // generate the embedding layers
  391. embeddingLayers, err := embeddingLayers(embed)
  392. if err != nil {
  393. return err
  394. }
  395. layers = append(layers, embeddingLayers...)
  396. digests, err := getLayerDigests(layers)
  397. if err != nil {
  398. return err
  399. }
  400. var manifestLayers []*Layer
  401. for _, l := range layers {
  402. manifestLayers = append(manifestLayers, &l.Layer)
  403. }
  404. // Create a layer for the config object
  405. fn(api.ProgressResponse{Status: "creating config layer"})
  406. cfg, err := createConfigLayer(config, digests)
  407. if err != nil {
  408. return err
  409. }
  410. layers = append(layers, cfg)
  411. if err := SaveLayers(layers, fn, false); err != nil {
  412. return err
  413. }
  414. // Create the manifest
  415. fn(api.ProgressResponse{Status: "writing manifest"})
  416. err = CreateManifest(name, cfg, manifestLayers)
  417. if err != nil {
  418. return err
  419. }
  420. fn(api.ProgressResponse{Status: "success"})
  421. return nil
  422. }
// EmbeddingParams collects, while CreateModel processes a Modelfile, what
// embeddingLayers needs to generate embeddings.
type EmbeddingParams struct {
	model string                 // model name, or path to a model file, from the "model" command
	opts  map[string]interface{} // formatted runtime parameters used when loading the model
	files []string               // paths (or glob patterns, expanded with filepath.Glob) of files to embed
	fn    func(resp api.ProgressResponse) // progress callback
}
  429. // embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
  430. func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
  431. layers := []*LayerReader{}
  432. if len(e.files) > 0 {
  433. // check if the model is a file path or a model name
  434. model, err := GetModel(e.model)
  435. if err != nil {
  436. if !strings.Contains(err.Error(), "couldn't open file") {
  437. return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
  438. }
  439. // the model may be a file path, create a model from this file
  440. model = &Model{ModelPath: e.model}
  441. }
  442. if err := load(model, e.opts, defaultSessionDuration); err != nil {
  443. return nil, fmt.Errorf("load model to generate embeddings: %v", err)
  444. }
  445. // this will be used to check if we already have embeddings for a file
  446. modelInfo, err := os.Stat(model.ModelPath)
  447. if err != nil {
  448. return nil, fmt.Errorf("failed to get model file info: %v", err)
  449. }
  450. addedFiles := make(map[string]bool) // keep track of files that have already been added
  451. for _, filePattern := range e.files {
  452. matchingFiles, err := filepath.Glob(filePattern)
  453. if err != nil {
  454. return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
  455. }
  456. for _, filePath := range matchingFiles {
  457. if addedFiles[filePath] {
  458. continue
  459. }
  460. addedFiles[filePath] = true
  461. // check if we already have embeddings for this file path
  462. layerIdentifier := fmt.Sprintf("%s:%s:%s:%d", filePath, e.model, modelInfo.ModTime().Format("2006-01-02 15:04:05"), modelInfo.Size())
  463. digest, _ := GetSHA256Digest(strings.NewReader(layerIdentifier))
  464. existing, err := existingFileEmbeddings(digest)
  465. if err != nil {
  466. return nil, fmt.Errorf("failed to check existing embeddings for file %s: %v", filePath, err)
  467. }
  468. // TODO: check file type
  469. f, err := os.Open(filePath)
  470. if err != nil {
  471. return nil, fmt.Errorf("could not open embed file: %w", err)
  472. }
  473. scanner := bufio.NewScanner(f)
  474. scanner.Split(bufio.ScanLines)
  475. data := []string{}
  476. for scanner.Scan() {
  477. data = append(data, scanner.Text())
  478. }
  479. f.Close()
  480. // the digest of the file is set here so that the client knows a new operation is in progress
  481. fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
  482. embeddings := []vector.Embedding{}
  483. for i, d := range data {
  484. if strings.TrimSpace(d) == "" {
  485. continue
  486. }
  487. e.fn(api.ProgressResponse{
  488. Status: fmt.Sprintf("creating embeddings for file %s", filePath),
  489. Digest: fileDigest,
  490. Total: len(data) - 1,
  491. Completed: i,
  492. })
  493. if len(existing[d]) > 0 {
  494. // already have an embedding for this line
  495. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
  496. continue
  497. }
  498. embed, err := loaded.llm.Embedding(d)
  499. if err != nil {
  500. log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
  501. continue
  502. }
  503. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
  504. }
  505. b, err := json.Marshal(embeddings)
  506. if err != nil {
  507. return nil, fmt.Errorf("failed to encode embeddings: %w", err)
  508. }
  509. r := bytes.NewReader(b)
  510. layer := &LayerReader{
  511. Layer: Layer{
  512. MediaType: "application/vnd.ollama.image.embed",
  513. Digest: digest,
  514. Size: r.Len(),
  515. },
  516. Reader: r,
  517. }
  518. layers = append(layers, layer)
  519. }
  520. }
  521. }
  522. return layers, nil
  523. }
  524. // existingFileEmbeddings checks if we already have embeddings for a file and loads them into a look-up map
  525. func existingFileEmbeddings(digest string) (map[string][]float64, error) {
  526. path, err := GetBlobsPath(digest)
  527. if err != nil {
  528. return nil, fmt.Errorf("embeddings blobs path: %w", err)
  529. }
  530. existingFileEmbeddings := make(map[string][]float64)
  531. if _, err := os.Stat(path); err == nil {
  532. // already have some embeddings for this file, load embeddings previously generated
  533. file, err := os.Open(path)
  534. if err != nil {
  535. return nil, fmt.Errorf("failed to open existing embedding file: %s", err)
  536. }
  537. defer file.Close()
  538. existing := []vector.Embedding{}
  539. if err = json.NewDecoder(file).Decode(&existing); err != nil {
  540. return nil, err
  541. }
  542. for _, e := range existing {
  543. existingFileEmbeddings[e.Data] = e.Vector
  544. }
  545. }
  546. return existingFileEmbeddings, nil
  547. }
  548. func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
  549. j := 0
  550. for _, l := range layers {
  551. if l.MediaType != mediaType {
  552. layers[j] = l
  553. j++
  554. }
  555. }
  556. return layers[:j]
  557. }
  558. func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
  559. // Write each of the layers to disk
  560. for _, layer := range layers {
  561. fp, err := GetBlobsPath(layer.Digest)
  562. if err != nil {
  563. return err
  564. }
  565. _, err = os.Stat(fp)
  566. // note: embed layers are always written since their digest doesnt indicate anything about the contents
  567. if os.IsNotExist(err) || force || layer.MediaType == "application/vnd.ollama.image.embed" {
  568. fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
  569. out, err := os.Create(fp)
  570. if err != nil {
  571. log.Printf("couldn't create %s", fp)
  572. return err
  573. }
  574. defer out.Close()
  575. if _, err = io.Copy(out, layer.Reader); err != nil {
  576. return err
  577. }
  578. } else {
  579. fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
  580. }
  581. }
  582. return nil
  583. }
  584. func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
  585. mp, err := ParseModelPath(name, false)
  586. if err != nil {
  587. return err
  588. }
  589. manifest := ManifestV2{
  590. SchemaVersion: 2,
  591. MediaType: "application/vnd.docker.distribution.manifest.v2+json",
  592. Config: Layer{
  593. MediaType: cfg.MediaType,
  594. Size: cfg.Size,
  595. Digest: cfg.Digest,
  596. },
  597. Layers: layers,
  598. }
  599. manifestJSON, err := json.Marshal(manifest)
  600. if err != nil {
  601. return err
  602. }
  603. fp, err := mp.GetManifestPath(true)
  604. if err != nil {
  605. return err
  606. }
  607. return os.WriteFile(fp, manifestJSON, 0o644)
  608. }
  609. func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
  610. fp, err := GetBlobsPath(layer.Digest)
  611. if err != nil {
  612. return nil, err
  613. }
  614. file, err := os.Open(fp)
  615. if err != nil {
  616. return nil, fmt.Errorf("could not open blob: %w", err)
  617. }
  618. defer file.Close()
  619. newLayer, err := CreateLayer(file)
  620. if err != nil {
  621. return nil, err
  622. }
  623. newLayer.MediaType = layer.MediaType
  624. return newLayer, nil
  625. }
  626. // formatParams converts specified parameter options to their correct types
  627. func formatParams(params map[string][]string) (map[string]interface{}, error) {
  628. opts := api.Options{}
  629. valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
  630. typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
  631. // build map of json struct tags to their types
  632. jsonOpts := make(map[string]reflect.StructField)
  633. for _, field := range reflect.VisibleFields(typeOpts) {
  634. jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
  635. if jsonTag != "" {
  636. jsonOpts[jsonTag] = field
  637. }
  638. }
  639. out := make(map[string]interface{})
  640. // iterate params and set values based on json struct tags
  641. for key, vals := range params {
  642. if opt, ok := jsonOpts[key]; ok {
  643. field := valueOpts.FieldByName(opt.Name)
  644. if field.IsValid() && field.CanSet() {
  645. switch field.Kind() {
  646. case reflect.Float32:
  647. floatVal, err := strconv.ParseFloat(vals[0], 32)
  648. if err != nil {
  649. return nil, fmt.Errorf("invalid float value %s", vals)
  650. }
  651. out[key] = floatVal
  652. case reflect.Int:
  653. intVal, err := strconv.ParseInt(vals[0], 10, 0)
  654. if err != nil {
  655. return nil, fmt.Errorf("invalid int value %s", vals)
  656. }
  657. out[key] = intVal
  658. case reflect.Bool:
  659. boolVal, err := strconv.ParseBool(vals[0])
  660. if err != nil {
  661. return nil, fmt.Errorf("invalid bool value %s", vals)
  662. }
  663. out[key] = boolVal
  664. case reflect.String:
  665. out[key] = vals[0]
  666. case reflect.Slice:
  667. // TODO: only string slices are supported right now
  668. out[key] = vals
  669. default:
  670. return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
  671. }
  672. }
  673. }
  674. }
  675. return out, nil
  676. }
  677. func getLayerDigests(layers []*LayerReader) ([]string, error) {
  678. var digests []string
  679. for _, l := range layers {
  680. if l.Digest == "" {
  681. return nil, fmt.Errorf("layer is missing a digest")
  682. }
  683. digests = append(digests, l.Digest)
  684. }
  685. return digests, nil
  686. }
  687. // CreateLayer creates a Layer object from a given file
  688. func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
  689. digest, size := GetSHA256Digest(f)
  690. f.Seek(0, io.SeekStart)
  691. layer := &LayerReader{
  692. Layer: Layer{
  693. MediaType: "application/vnd.docker.image.rootfs.diff.tar",
  694. Digest: digest,
  695. Size: size,
  696. },
  697. Reader: f,
  698. }
  699. return layer, nil
  700. }
  701. func CopyModel(src, dest string) error {
  702. srcModelPath, err := ParseModelPath(src, false)
  703. if err != nil {
  704. return err
  705. }
  706. srcPath, err := srcModelPath.GetManifestPath(false)
  707. if err != nil {
  708. return err
  709. }
  710. destModelPath, err := ParseModelPath(dest, false)
  711. if err != nil {
  712. return err
  713. }
  714. destPath, err := destModelPath.GetManifestPath(true)
  715. if err != nil {
  716. return err
  717. }
  718. // copy the file
  719. input, err := os.ReadFile(srcPath)
  720. if err != nil {
  721. fmt.Println("Error reading file:", err)
  722. return err
  723. }
  724. err = os.WriteFile(destPath, input, 0o644)
  725. if err != nil {
  726. fmt.Println("Error reading file:", err)
  727. return err
  728. }
  729. return nil
  730. }
  731. func DeleteModel(name string) error {
  732. mp, err := ParseModelPath(name, false)
  733. if err != nil {
  734. return err
  735. }
  736. manifest, err := GetManifest(mp)
  737. if err != nil {
  738. return err
  739. }
  740. deleteMap := make(map[string]bool)
  741. for _, layer := range manifest.Layers {
  742. deleteMap[layer.Digest] = true
  743. }
  744. deleteMap[manifest.Config.Digest] = true
  745. fp, err := GetManifestPath()
  746. if err != nil {
  747. return err
  748. }
  749. err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
  750. if err != nil {
  751. return err
  752. }
  753. if !info.IsDir() {
  754. path := path[len(fp)+1:]
  755. slashIndex := strings.LastIndex(path, "/")
  756. if slashIndex == -1 {
  757. return nil
  758. }
  759. tag := path[:slashIndex] + ":" + path[slashIndex+1:]
  760. fmp, err := ParseModelPath(tag, false)
  761. if err != nil {
  762. return err
  763. }
  764. // skip the manifest we're trying to delete
  765. if mp.GetFullTagname() == fmp.GetFullTagname() {
  766. return nil
  767. }
  768. // save (i.e. delete from the deleteMap) any files used in other manifests
  769. manifest, err := GetManifest(fmp)
  770. if err != nil {
  771. log.Printf("skipping file: %s", fp)
  772. return nil
  773. }
  774. for _, layer := range manifest.Layers {
  775. delete(deleteMap, layer.Digest)
  776. }
  777. delete(deleteMap, manifest.Config.Digest)
  778. }
  779. return nil
  780. })
  781. if err != nil {
  782. return err
  783. }
  784. // only delete the files which are still in the deleteMap
  785. for k, v := range deleteMap {
  786. if v {
  787. fp, err := GetBlobsPath(k)
  788. if err != nil {
  789. log.Printf("couldn't get file path for '%s': %v", k, err)
  790. continue
  791. }
  792. if err := os.Remove(fp); err != nil {
  793. log.Printf("couldn't remove file '%s': %v", fp, err)
  794. continue
  795. }
  796. }
  797. }
  798. fp, err = mp.GetManifestPath(false)
  799. if err != nil {
  800. return err
  801. }
  802. err = os.Remove(fp)
  803. if err != nil {
  804. log.Printf("couldn't remove manifest file '%s': %v", fp, err)
  805. return err
  806. }
  807. return nil
  808. }
  809. func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  810. mp, err := ParseModelPath(name, regOpts.Insecure)
  811. if err != nil {
  812. return err
  813. }
  814. fn(api.ProgressResponse{Status: "retrieving manifest"})
  815. manifest, err := GetManifest(mp)
  816. if err != nil {
  817. fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
  818. return err
  819. }
  820. var layers []*Layer
  821. layers = append(layers, manifest.Layers...)
  822. layers = append(layers, &manifest.Config)
  823. for _, layer := range layers {
  824. exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
  825. if err != nil {
  826. return err
  827. }
  828. if exists {
  829. fn(api.ProgressResponse{
  830. Status: "using existing layer",
  831. Digest: layer.Digest,
  832. Total: layer.Size,
  833. Completed: layer.Size,
  834. })
  835. log.Printf("Layer %s already exists", layer.Digest)
  836. continue
  837. }
  838. fn(api.ProgressResponse{
  839. Status: "starting upload",
  840. Digest: layer.Digest,
  841. Total: layer.Size,
  842. })
  843. location, err := startUpload(ctx, mp, layer, regOpts)
  844. if err != nil {
  845. log.Printf("couldn't start upload: %v", err)
  846. return err
  847. }
  848. if strings.HasPrefix(path.Base(location), "sha256:") {
  849. layer.Digest = path.Base(location)
  850. fn(api.ProgressResponse{
  851. Status: "using existing layer",
  852. Digest: layer.Digest,
  853. Total: layer.Size,
  854. Completed: layer.Size,
  855. })
  856. continue
  857. }
  858. if err := uploadBlobChunked(ctx, mp, location, layer, regOpts, fn); err != nil {
  859. log.Printf("error uploading blob: %v", err)
  860. return err
  861. }
  862. }
  863. fn(api.ProgressResponse{Status: "pushing manifest"})
  864. url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
  865. headers := map[string]string{
  866. "Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
  867. }
  868. manifestJSON, err := json.Marshal(manifest)
  869. if err != nil {
  870. return err
  871. }
  872. resp, err := makeRequestWithRetry(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
  873. if err != nil {
  874. return err
  875. }
  876. defer resp.Body.Close()
  877. fn(api.ProgressResponse{Status: "success"})
  878. return nil
  879. }
  880. func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  881. mp, err := ParseModelPath(name, regOpts.Insecure)
  882. if err != nil {
  883. return err
  884. }
  885. fn(api.ProgressResponse{Status: "pulling manifest"})
  886. manifest, err := pullModelManifest(ctx, mp, regOpts)
  887. if err != nil {
  888. return fmt.Errorf("pull model manifest: %s", err)
  889. }
  890. var layers []*Layer
  891. layers = append(layers, manifest.Layers...)
  892. layers = append(layers, &manifest.Config)
  893. for _, layer := range layers {
  894. if err := downloadBlob(
  895. ctx,
  896. downloadOpts{
  897. mp: mp,
  898. digest: layer.Digest,
  899. regOpts: regOpts,
  900. fn: fn,
  901. }); err != nil {
  902. return err
  903. }
  904. }
  905. fn(api.ProgressResponse{Status: "verifying sha256 digest"})
  906. for _, layer := range layers {
  907. if err := verifyBlob(layer.Digest); err != nil {
  908. if errors.Is(err, errDigestMismatch) {
  909. // something went wrong, delete the blob
  910. fp, err := GetBlobsPath(layer.Digest)
  911. if err != nil {
  912. return err
  913. }
  914. if err := os.Remove(fp); err != nil {
  915. // log this, but return the original error
  916. log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
  917. }
  918. }
  919. return err
  920. }
  921. }
  922. fn(api.ProgressResponse{Status: "writing manifest"})
  923. manifestJSON, err := json.Marshal(manifest)
  924. if err != nil {
  925. return err
  926. }
  927. fp, err := mp.GetManifestPath(true)
  928. if err != nil {
  929. return err
  930. }
  931. err = os.WriteFile(fp, manifestJSON, 0o644)
  932. if err != nil {
  933. log.Printf("couldn't write to %s", fp)
  934. return err
  935. }
  936. fn(api.ProgressResponse{Status: "success"})
  937. return nil
  938. }
  939. func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
  940. url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
  941. headers := map[string]string{
  942. "Accept": "application/vnd.docker.distribution.manifest.v2+json",
  943. }
  944. resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
  945. if err != nil {
  946. log.Printf("couldn't get manifest: %v", err)
  947. return nil, err
  948. }
  949. defer resp.Body.Close()
  950. // Check for success: For a successful upload, the Docker registry will respond with a 201 Created
  951. if resp.StatusCode != http.StatusOK {
  952. if resp.StatusCode == http.StatusNotFound {
  953. return nil, fmt.Errorf("model not found")
  954. }
  955. body, _ := io.ReadAll(resp.Body)
  956. return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
  957. }
  958. var m *ManifestV2
  959. if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
  960. return nil, err
  961. }
  962. return m, err
  963. }
  964. func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
  965. config.RootFS = RootFS{
  966. Type: "layers",
  967. DiffIDs: layers,
  968. }
  969. configJSON, err := json.Marshal(config)
  970. if err != nil {
  971. return nil, err
  972. }
  973. digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
  974. layer := &LayerReader{
  975. Layer: Layer{
  976. MediaType: "application/vnd.docker.container.image.v1+json",
  977. Digest: digest,
  978. Size: size,
  979. },
  980. Reader: bytes.NewBuffer(configJSON),
  981. }
  982. return layer, nil
  983. }
  984. // GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
  985. func GetSHA256Digest(r io.Reader) (string, int) {
  986. h := sha256.New()
  987. n, err := io.Copy(h, r)
  988. if err != nil {
  989. log.Fatal(err)
  990. }
  991. return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
  992. }
// requestContextKey is a distinct string type, presumably intended as a key
// for values stored in a context.Context (a private key type avoids
// collisions with other packages' keys).
// NOTE(review): it is not referenced in the code visible here — confirm it
// is still used elsewhere before relying on or removing it.
type requestContextKey string
  994. func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
  995. url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
  996. if layer.From != "" {
  997. url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
  998. }
  999. resp, err := makeRequestWithRetry(ctx, "POST", url, nil, nil, regOpts)
  1000. if err != nil {
  1001. log.Printf("couldn't start upload: %v", err)
  1002. return "", err
  1003. }
  1004. defer resp.Body.Close()
  1005. // Extract UUID location from header
  1006. location := resp.Header.Get("Location")
  1007. if location == "" {
  1008. return "", fmt.Errorf("location header is missing in response")
  1009. }
  1010. return location, nil
  1011. }
  1012. // Function to check if a blob already exists in the Docker registry
  1013. func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
  1014. url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
  1015. resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
  1016. if err != nil {
  1017. log.Printf("couldn't check for blob: %v", err)
  1018. return false, err
  1019. }
  1020. defer resp.Body.Close()
  1021. // Check for success: If the blob exists, the Docker registry will respond with a 200 OK
  1022. return resp.StatusCode == http.StatusOK, nil
  1023. }
  1024. func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  1025. // TODO allow resumability
  1026. // TODO allow canceling uploads via DELETE
  1027. fp, err := GetBlobsPath(layer.Digest)
  1028. if err != nil {
  1029. return err
  1030. }
  1031. f, err := os.Open(fp)
  1032. if err != nil {
  1033. return err
  1034. }
  1035. defer f.Close()
  1036. var completed int64
  1037. chunkSize := 10 * 1024 * 1024
  1038. for {
  1039. chunk := int64(layer.Size) - completed
  1040. if chunk > int64(chunkSize) {
  1041. chunk = int64(chunkSize)
  1042. }
  1043. sectionReader := io.NewSectionReader(f, int64(completed), chunk)
  1044. headers := make(map[string]string)
  1045. headers["Content-Type"] = "application/octet-stream"
  1046. headers["Content-Length"] = strconv.Itoa(int(chunk))
  1047. headers["Content-Range"] = fmt.Sprintf("%d-%d", completed, completed+sectionReader.Size()-1)
  1048. resp, err := makeRequestWithRetry(ctx, "PATCH", url, headers, sectionReader, regOpts)
  1049. if err != nil && !errors.Is(err, io.EOF) {
  1050. fn(api.ProgressResponse{
  1051. Status: fmt.Sprintf("error uploading chunk: %v", err),
  1052. Digest: layer.Digest,
  1053. Total: layer.Size,
  1054. Completed: int(completed),
  1055. })
  1056. return err
  1057. }
  1058. defer resp.Body.Close()
  1059. completed += sectionReader.Size()
  1060. fn(api.ProgressResponse{
  1061. Status: fmt.Sprintf("uploading %s", layer.Digest),
  1062. Digest: layer.Digest,
  1063. Total: layer.Size,
  1064. Completed: int(completed),
  1065. })
  1066. url = resp.Header.Get("Location")
  1067. if completed >= int64(layer.Size) {
  1068. break
  1069. }
  1070. }
  1071. url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)
  1072. headers := make(map[string]string)
  1073. headers["Content-Type"] = "application/octet-stream"
  1074. headers["Content-Length"] = "0"
  1075. // finish the upload
  1076. resp, err := makeRequest(ctx, "PUT", url, headers, nil, regOpts)
  1077. if err != nil {
  1078. log.Printf("couldn't finish upload: %v", err)
  1079. return err
  1080. }
  1081. defer resp.Body.Close()
  1082. if resp.StatusCode != http.StatusCreated {
  1083. body, _ := io.ReadAll(resp.Body)
  1084. return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
  1085. }
  1086. return nil
  1087. }
  1088. func makeRequestWithRetry(ctx context.Context, method, url string, headers map[string]string, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
  1089. var status string
  1090. for try := 0; try < MaxRetries; try++ {
  1091. resp, err := makeRequest(ctx, method, url, headers, body, regOpts)
  1092. if err != nil {
  1093. log.Printf("couldn't start upload: %v", err)
  1094. return nil, err
  1095. }
  1096. status = resp.Status
  1097. switch resp.StatusCode {
  1098. case http.StatusAccepted, http.StatusCreated:
  1099. return resp, nil
  1100. case http.StatusUnauthorized:
  1101. auth := resp.Header.Get("www-authenticate")
  1102. authRedir := ParseAuthRedirectString(auth)
  1103. token, err := getAuthToken(ctx, authRedir, regOpts)
  1104. if err != nil {
  1105. return nil, err
  1106. }
  1107. regOpts.Token = token
  1108. if body != nil {
  1109. if _, err := body.Seek(0, io.SeekStart); err != nil {
  1110. return nil, err
  1111. }
  1112. }
  1113. continue
  1114. default:
  1115. body, _ := io.ReadAll(resp.Body)
  1116. return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
  1117. }
  1118. }
  1119. return nil, fmt.Errorf("max retry exceeded: %v", status)
  1120. }
  1121. func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
  1122. if !strings.HasPrefix(url, "http") {
  1123. if regOpts.Insecure {
  1124. url = "http://" + url
  1125. } else {
  1126. url = "https://" + url
  1127. }
  1128. }
  1129. req, err := http.NewRequestWithContext(ctx, method, url, body)
  1130. if err != nil {
  1131. return nil, err
  1132. }
  1133. if regOpts.Token != "" {
  1134. req.Header.Set("Authorization", "Bearer "+regOpts.Token)
  1135. } else if regOpts.Username != "" && regOpts.Password != "" {
  1136. req.SetBasicAuth(regOpts.Username, regOpts.Password)
  1137. }
  1138. for k, v := range headers {
  1139. req.Header.Set(k, v)
  1140. }
  1141. client := &http.Client{
  1142. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  1143. if len(via) >= 10 {
  1144. return fmt.Errorf("too many redirects")
  1145. }
  1146. log.Printf("redirected to: %s\n", req.URL)
  1147. return nil
  1148. },
  1149. }
  1150. resp, err := client.Do(req)
  1151. if err != nil {
  1152. return nil, err
  1153. }
  1154. return resp, nil
  1155. }
  1156. func getValue(header, key string) string {
  1157. startIdx := strings.Index(header, key+"=")
  1158. if startIdx == -1 {
  1159. return ""
  1160. }
  1161. // Move the index to the starting quote after the key.
  1162. startIdx += len(key) + 2
  1163. endIdx := startIdx
  1164. for endIdx < len(header) {
  1165. if header[endIdx] == '"' {
  1166. if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue
  1167. endIdx++
  1168. continue
  1169. }
  1170. break
  1171. }
  1172. endIdx++
  1173. }
  1174. return header[startIdx:endIdx]
  1175. }
  1176. func ParseAuthRedirectString(authStr string) AuthRedirect {
  1177. authStr = strings.TrimPrefix(authStr, "Bearer ")
  1178. return AuthRedirect{
  1179. Realm: getValue(authStr, "realm"),
  1180. Service: getValue(authStr, "service"),
  1181. Scope: getValue(authStr, "scope"),
  1182. }
  1183. }
  1184. var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
  1185. func verifyBlob(digest string) error {
  1186. fp, err := GetBlobsPath(digest)
  1187. if err != nil {
  1188. return err
  1189. }
  1190. f, err := os.Open(fp)
  1191. if err != nil {
  1192. return err
  1193. }
  1194. defer f.Close()
  1195. fileDigest, _ := GetSHA256Digest(f)
  1196. if digest != fileDigest {
  1197. return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
  1198. }
  1199. return nil
  1200. }