images.go 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284
  1. package server
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/sha256"
  7. "encoding/hex"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "io"
  12. "log"
  13. "net/http"
  14. "net/url"
  15. "os"
  16. "path"
  17. "path/filepath"
  18. "reflect"
  19. "runtime"
  20. "strconv"
  21. "strings"
  22. "text/template"
  23. "github.com/jmorganca/ollama/api"
  24. "github.com/jmorganca/ollama/llm"
  25. "github.com/jmorganca/ollama/parser"
  26. "github.com/jmorganca/ollama/vector"
  27. "github.com/jmorganca/ollama/version"
  28. )
// MaxRetries is a retry limit of 3. Its consumers are not visible in this part
// of the file (presumably the registry request helpers — confirm).
const MaxRetries = 3
// RegistryOptions carries the settings handed to the registry operations in
// this file (PushModel, PullModel).
type RegistryOptions struct {
	// Insecure permits a registry reached over plain "http"; PushModel and
	// PullModel reject that scheme when it is false.
	Insecure bool
	// Username, Password and Token are presumably registry credentials; how
	// they are consumed is not visible in this part of the file.
	Username string
	Password string
	Token    string
}
// Model is the in-memory description of a locally stored model, assembled by
// GetModel from the layers listed in the model's manifest.
type Model struct {
	// Name is the full tag name derived from the parsed model path.
	Name string `json:"name"`
	// ModelPath is the blob path of the model layer.
	ModelPath string
	// AdapterPaths are the blob paths of any adapter layers.
	AdapterPaths []string
	// Template is the prompt template; GetModel defaults it to "{{ .Prompt }}"
	// and overrides it with the contents of a template or prompt layer.
	Template string
	// System is the contents of the system layer, if any.
	System string
	// Digest is the hex-encoded SHA-256 of the manifest file.
	Digest string
	// ConfigDigest is the digest of the manifest's config layer.
	ConfigDigest string
	// Options holds the decoded params layer, kept as a map so that explicitly
	// specified fields can be told apart from defaults.
	Options map[string]interface{}
	// Embeddings holds the decoded embed layer.
	Embeddings []vector.Embedding
}
  47. func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
  48. t := m.Template
  49. if request.Template != "" {
  50. t = request.Template
  51. }
  52. tmpl, err := template.New("").Parse(t)
  53. if err != nil {
  54. return "", err
  55. }
  56. var vars struct {
  57. First bool
  58. System string
  59. Prompt string
  60. Embed string
  61. // deprecated: versions <= 0.0.7 used this to omit the system prompt
  62. Context []int
  63. }
  64. vars.First = len(request.Context) == 0
  65. vars.System = m.System
  66. vars.Prompt = request.Prompt
  67. vars.Context = request.Context
  68. vars.Embed = embedding
  69. if request.System != "" {
  70. vars.System = request.System
  71. }
  72. var sb strings.Builder
  73. if err := tmpl.Execute(&sb, vars); err != nil {
  74. return "", err
  75. }
  76. return sb.String(), nil
  77. }
// ManifestV2 is the JSON manifest stored for each model: one config layer plus
// the list of content layers. CreateManifest writes it with schema version 2
// and the docker distribution manifest v2 media type.
type ManifestV2 struct {
	SchemaVersion int      `json:"schemaVersion"`
	MediaType     string   `json:"mediaType"`
	Config        Layer    `json:"config"`
	Layers        []*Layer `json:"layers"`
}
// Layer describes one blob referenced by a manifest.
type Layer struct {
	// MediaType identifies what the blob holds (model, adapter, template, ...).
	MediaType string `json:"mediaType"`
	// Digest is the blob's digest; GetBlobsPath maps it to a file on disk.
	Digest string `json:"digest"`
	// Size is the blob size as reported when the layer was created.
	Size int `json:"size"`
	// From is set by CreateModel to the namespace/repository a layer was
	// inherited from; it is omitted from JSON when empty.
	From string `json:"from,omitempty"`
}
// LayerReader pairs a Layer's metadata with a reader over its contents, so the
// layer can be written to the blob store by SaveLayers.
type LayerReader struct {
	Layer
	io.Reader
}
// ConfigV2 is the JSON config object stored as a model's config layer. The
// model family, type and file type are taken from the decoded model file or
// copied from the source model's config in CreateModel.
type ConfigV2 struct {
	ModelFamily llm.ModelFamily `json:"model_family"`
	ModelType   string          `json:"model_type"`
	FileType    string          `json:"file_type"`
	RootFS      RootFS          `json:"rootfs"`

	// required by spec
	Architecture string `json:"architecture"`
	OS           string `json:"os"`
}
// RootFS is the "rootfs" section of ConfigV2. DiffIDs presumably lists the
// layer digests passed to createConfigLayer (defined outside this view —
// confirm).
type RootFS struct {
	Type    string   `json:"type"`
	DiffIDs []string `json:"diff_ids"`
}
  107. func (m *ManifestV2) GetTotalSize() int {
  108. var total int
  109. for _, layer := range m.Layers {
  110. total += layer.Size
  111. }
  112. total += m.Config.Size
  113. return total
  114. }
  115. func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
  116. fp, err := mp.GetManifestPath(false)
  117. if err != nil {
  118. return nil, "", err
  119. }
  120. if _, err = os.Stat(fp); err != nil {
  121. return nil, "", err
  122. }
  123. var manifest *ManifestV2
  124. bts, err := os.ReadFile(fp)
  125. if err != nil {
  126. return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
  127. }
  128. shaSum := sha256.Sum256(bts)
  129. shaStr := hex.EncodeToString(shaSum[:])
  130. if err := json.Unmarshal(bts, &manifest); err != nil {
  131. return nil, "", err
  132. }
  133. return manifest, shaStr, nil
  134. }
  135. func GetModel(name string) (*Model, error) {
  136. mp := ParseModelPath(name)
  137. manifest, digest, err := GetManifest(mp)
  138. if err != nil {
  139. return nil, err
  140. }
  141. model := &Model{
  142. Name: mp.GetFullTagname(),
  143. Digest: digest,
  144. ConfigDigest: manifest.Config.Digest,
  145. Template: "{{ .Prompt }}",
  146. }
  147. for _, layer := range manifest.Layers {
  148. filename, err := GetBlobsPath(layer.Digest)
  149. if err != nil {
  150. return nil, err
  151. }
  152. switch layer.MediaType {
  153. case "application/vnd.ollama.image.model":
  154. model.ModelPath = filename
  155. case "application/vnd.ollama.image.embed":
  156. file, err := os.Open(filename)
  157. if err != nil {
  158. return nil, fmt.Errorf("failed to open file: %s", filename)
  159. }
  160. defer file.Close()
  161. if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
  162. return nil, err
  163. }
  164. case "application/vnd.ollama.image.adapter":
  165. model.AdapterPaths = append(model.AdapterPaths, filename)
  166. case "application/vnd.ollama.image.template":
  167. bts, err := os.ReadFile(filename)
  168. if err != nil {
  169. return nil, err
  170. }
  171. model.Template = string(bts)
  172. case "application/vnd.ollama.image.system":
  173. bts, err := os.ReadFile(filename)
  174. if err != nil {
  175. return nil, err
  176. }
  177. model.System = string(bts)
  178. case "application/vnd.ollama.image.prompt":
  179. bts, err := os.ReadFile(filename)
  180. if err != nil {
  181. return nil, err
  182. }
  183. model.Template = string(bts)
  184. case "application/vnd.ollama.image.params":
  185. params, err := os.Open(filename)
  186. if err != nil {
  187. return nil, err
  188. }
  189. defer params.Close()
  190. // parse model options parameters into a map so that we can see which fields have been specified explicitly
  191. if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
  192. return nil, err
  193. }
  194. }
  195. }
  196. return model, nil
  197. }
  198. func filenameWithPath(path, f string) (string, error) {
  199. // if filePath starts with ~/, replace it with the user's home directory.
  200. if strings.HasPrefix(f, fmt.Sprintf("~%s", string(os.PathSeparator))) {
  201. parts := strings.Split(f, string(os.PathSeparator))
  202. home, err := os.UserHomeDir()
  203. if err != nil {
  204. return "", fmt.Errorf("failed to open file: %v", err)
  205. }
  206. f = filepath.Join(home, filepath.Join(parts[1:]...))
  207. }
  208. // if filePath is not an absolute path, make it relative to the modelfile path
  209. if !filepath.IsAbs(f) {
  210. f = filepath.Join(filepath.Dir(path), f)
  211. }
  212. return f, nil
  213. }
  214. func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
  215. mf, err := os.Open(path)
  216. if err != nil {
  217. fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
  218. return fmt.Errorf("failed to open file: %w", err)
  219. }
  220. defer mf.Close()
  221. fn(api.ProgressResponse{Status: "parsing modelfile"})
  222. commands, err := parser.Parse(mf)
  223. if err != nil {
  224. return err
  225. }
  226. config := ConfigV2{
  227. Architecture: "amd64",
  228. OS: "linux",
  229. }
  230. var layers []*LayerReader
  231. params := make(map[string][]string)
  232. embed := EmbeddingParams{fn: fn}
  233. for _, c := range commands {
  234. log.Printf("[%s] - %s\n", c.Name, c.Args)
  235. switch c.Name {
  236. case "model":
  237. fn(api.ProgressResponse{Status: "looking for model"})
  238. embed.model = c.Args
  239. mp := ParseModelPath(c.Args)
  240. mf, _, err := GetManifest(mp)
  241. if err != nil {
  242. modelFile, err := filenameWithPath(path, c.Args)
  243. if err != nil {
  244. return err
  245. }
  246. if _, err := os.Stat(modelFile); err != nil {
  247. // the model file does not exist, try pulling it
  248. if errors.Is(err, os.ErrNotExist) {
  249. fn(api.ProgressResponse{Status: "pulling model file"})
  250. if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
  251. return err
  252. }
  253. mf, _, err = GetManifest(mp)
  254. if err != nil {
  255. return fmt.Errorf("failed to open file after pull: %v", err)
  256. }
  257. } else {
  258. return err
  259. }
  260. } else {
  261. embed.model = modelFile
  262. // create a model from this specified file
  263. fn(api.ProgressResponse{Status: "creating model layer"})
  264. file, err := os.Open(modelFile)
  265. if err != nil {
  266. return fmt.Errorf("failed to open file: %v", err)
  267. }
  268. defer file.Close()
  269. ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
  270. if err != nil {
  271. return err
  272. }
  273. config.ModelFamily = ggml.ModelFamily()
  274. config.ModelType = ggml.ModelType().String()
  275. config.FileType = ggml.FileType().String()
  276. // reset the file
  277. file.Seek(0, io.SeekStart)
  278. l, err := CreateLayer(file)
  279. if err != nil {
  280. return fmt.Errorf("failed to create layer: %v", err)
  281. }
  282. l.MediaType = "application/vnd.ollama.image.model"
  283. layers = append(layers, l)
  284. }
  285. }
  286. if mf != nil {
  287. sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
  288. if err != nil {
  289. return err
  290. }
  291. sourceBlob, err := os.Open(sourceBlobPath)
  292. if err != nil {
  293. return err
  294. }
  295. defer sourceBlob.Close()
  296. var source ConfigV2
  297. if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
  298. return err
  299. }
  300. // copie the model metadata
  301. config.ModelFamily = source.ModelFamily
  302. config.ModelType = source.ModelType
  303. config.FileType = source.FileType
  304. for _, l := range mf.Layers {
  305. newLayer, err := GetLayerWithBufferFromLayer(l)
  306. if err != nil {
  307. return err
  308. }
  309. newLayer.From = mp.GetNamespaceRepository()
  310. layers = append(layers, newLayer)
  311. }
  312. }
  313. case "embed":
  314. embedFilePath, err := filenameWithPath(path, c.Args)
  315. if err != nil {
  316. return err
  317. }
  318. embed.files = append(embed.files, embedFilePath)
  319. case "adapter":
  320. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  321. fp, err := filenameWithPath(path, c.Args)
  322. if err != nil {
  323. return err
  324. }
  325. // create a model from this specified file
  326. fn(api.ProgressResponse{Status: "creating model layer"})
  327. file, err := os.Open(fp)
  328. if err != nil {
  329. return fmt.Errorf("failed to open file: %v", err)
  330. }
  331. defer file.Close()
  332. l, err := CreateLayer(file)
  333. if err != nil {
  334. return fmt.Errorf("failed to create layer: %v", err)
  335. }
  336. l.MediaType = "application/vnd.ollama.image.adapter"
  337. layers = append(layers, l)
  338. case "license":
  339. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  340. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  341. layer, err := CreateLayer(strings.NewReader(c.Args))
  342. if err != nil {
  343. return err
  344. }
  345. layer.MediaType = mediaType
  346. layers = append(layers, layer)
  347. case "template", "system", "prompt":
  348. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  349. // remove the layer if one exists
  350. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  351. layers = removeLayerFromLayers(layers, mediaType)
  352. layer, err := CreateLayer(strings.NewReader(c.Args))
  353. if err != nil {
  354. return err
  355. }
  356. layer.MediaType = mediaType
  357. layers = append(layers, layer)
  358. default:
  359. // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
  360. params[c.Name] = append(params[c.Name], c.Args)
  361. }
  362. }
  363. // Create a single layer for the parameters
  364. if len(params) > 0 {
  365. fn(api.ProgressResponse{Status: "creating parameter layer"})
  366. layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
  367. formattedParams, err := formatParams(params)
  368. if err != nil {
  369. return fmt.Errorf("couldn't create params json: %v", err)
  370. }
  371. bts, err := json.Marshal(formattedParams)
  372. if err != nil {
  373. return err
  374. }
  375. l, err := CreateLayer(bytes.NewReader(bts))
  376. if err != nil {
  377. return fmt.Errorf("failed to create layer: %v", err)
  378. }
  379. l.MediaType = "application/vnd.ollama.image.params"
  380. layers = append(layers, l)
  381. // apply these parameters to the embedding options, in case embeddings need to be generated using this model
  382. embed.opts = formattedParams
  383. }
  384. // generate the embedding layers
  385. embeddingLayers, err := embeddingLayers(embed)
  386. if err != nil {
  387. return err
  388. }
  389. layers = append(layers, embeddingLayers...)
  390. digests, err := getLayerDigests(layers)
  391. if err != nil {
  392. return err
  393. }
  394. var manifestLayers []*Layer
  395. for _, l := range layers {
  396. manifestLayers = append(manifestLayers, &l.Layer)
  397. }
  398. // Create a layer for the config object
  399. fn(api.ProgressResponse{Status: "creating config layer"})
  400. cfg, err := createConfigLayer(config, digests)
  401. if err != nil {
  402. return err
  403. }
  404. layers = append(layers, cfg)
  405. if err := SaveLayers(layers, fn, false); err != nil {
  406. return err
  407. }
  408. // Create the manifest
  409. fn(api.ProgressResponse{Status: "writing manifest"})
  410. err = CreateManifest(name, cfg, manifestLayers)
  411. if err != nil {
  412. return err
  413. }
  414. fn(api.ProgressResponse{Status: "success"})
  415. return nil
  416. }
// EmbeddingParams bundles the inputs embeddingLayers needs to generate
// embedding layers while a model is being created.
type EmbeddingParams struct {
	model string                          // model name or model file path used to generate embeddings
	opts  map[string]interface{}          // formatted runtime parameters passed when loading the model
	files []string                        // paths to files to embed
	fn    func(resp api.ProgressResponse) // progress callback
}
  423. // embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
  424. func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
  425. layers := []*LayerReader{}
  426. if len(e.files) > 0 {
  427. // check if the model is a file path or a model name
  428. model, err := GetModel(e.model)
  429. if err != nil {
  430. if !strings.Contains(err.Error(), "couldn't open file") {
  431. return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
  432. }
  433. // the model may be a file path, create a model from this file
  434. model = &Model{ModelPath: e.model}
  435. }
  436. if err := load(context.Background(), model, e.opts, defaultSessionDuration); err != nil {
  437. return nil, fmt.Errorf("load model to generate embeddings: %v", err)
  438. }
  439. // this will be used to check if we already have embeddings for a file
  440. modelInfo, err := os.Stat(model.ModelPath)
  441. if err != nil {
  442. return nil, fmt.Errorf("failed to get model file info: %v", err)
  443. }
  444. addedFiles := make(map[string]bool) // keep track of files that have already been added
  445. for _, filePattern := range e.files {
  446. matchingFiles, err := filepath.Glob(filePattern)
  447. if err != nil {
  448. return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
  449. }
  450. for _, filePath := range matchingFiles {
  451. if addedFiles[filePath] {
  452. continue
  453. }
  454. addedFiles[filePath] = true
  455. // check if we already have embeddings for this file path
  456. layerIdentifier := fmt.Sprintf("%s:%s:%s:%d", filePath, e.model, modelInfo.ModTime().Format("2006-01-02 15:04:05"), modelInfo.Size())
  457. digest, _ := GetSHA256Digest(strings.NewReader(layerIdentifier))
  458. existing, err := existingFileEmbeddings(digest)
  459. if err != nil {
  460. return nil, fmt.Errorf("failed to check existing embeddings for file %s: %v", filePath, err)
  461. }
  462. // TODO: check file type
  463. f, err := os.Open(filePath)
  464. if err != nil {
  465. return nil, fmt.Errorf("could not open embed file: %w", err)
  466. }
  467. scanner := bufio.NewScanner(f)
  468. scanner.Split(bufio.ScanLines)
  469. data := []string{}
  470. for scanner.Scan() {
  471. data = append(data, scanner.Text())
  472. }
  473. f.Close()
  474. // the digest of the file is set here so that the client knows a new operation is in progress
  475. fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
  476. embeddings := []vector.Embedding{}
  477. for i, d := range data {
  478. if strings.TrimSpace(d) == "" {
  479. continue
  480. }
  481. e.fn(api.ProgressResponse{
  482. Status: fmt.Sprintf("creating embeddings for file %s", filePath),
  483. Digest: fileDigest,
  484. Total: len(data) - 1,
  485. Completed: i,
  486. })
  487. if len(existing[d]) > 0 {
  488. // already have an embedding for this line
  489. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
  490. continue
  491. }
  492. embed, err := loaded.llm.Embedding(context.Background(), d)
  493. if err != nil {
  494. log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
  495. continue
  496. }
  497. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
  498. }
  499. b, err := json.Marshal(embeddings)
  500. if err != nil {
  501. return nil, fmt.Errorf("failed to encode embeddings: %w", err)
  502. }
  503. r := bytes.NewReader(b)
  504. layer := &LayerReader{
  505. Layer: Layer{
  506. MediaType: "application/vnd.ollama.image.embed",
  507. Digest: digest,
  508. Size: r.Len(),
  509. },
  510. Reader: r,
  511. }
  512. layers = append(layers, layer)
  513. }
  514. }
  515. }
  516. return layers, nil
  517. }
  518. // existingFileEmbeddings checks if we already have embeddings for a file and loads them into a look-up map
  519. func existingFileEmbeddings(digest string) (map[string][]float64, error) {
  520. path, err := GetBlobsPath(digest)
  521. if err != nil {
  522. return nil, fmt.Errorf("embeddings blobs path: %w", err)
  523. }
  524. existingFileEmbeddings := make(map[string][]float64)
  525. if _, err := os.Stat(path); err == nil {
  526. // already have some embeddings for this file, load embeddings previously generated
  527. file, err := os.Open(path)
  528. if err != nil {
  529. return nil, fmt.Errorf("failed to open existing embedding file: %s", err)
  530. }
  531. defer file.Close()
  532. existing := []vector.Embedding{}
  533. if err = json.NewDecoder(file).Decode(&existing); err != nil {
  534. return nil, err
  535. }
  536. for _, e := range existing {
  537. existingFileEmbeddings[e.Data] = e.Vector
  538. }
  539. }
  540. return existingFileEmbeddings, nil
  541. }
  542. func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
  543. j := 0
  544. for _, l := range layers {
  545. if l.MediaType != mediaType {
  546. layers[j] = l
  547. j++
  548. }
  549. }
  550. return layers[:j]
  551. }
  552. func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
  553. // Write each of the layers to disk
  554. for _, layer := range layers {
  555. fp, err := GetBlobsPath(layer.Digest)
  556. if err != nil {
  557. return err
  558. }
  559. _, err = os.Stat(fp)
  560. // note: embed layers are always written since their digest doesnt indicate anything about the contents
  561. if os.IsNotExist(err) || force || layer.MediaType == "application/vnd.ollama.image.embed" {
  562. fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
  563. out, err := os.Create(fp)
  564. if err != nil {
  565. log.Printf("couldn't create %s", fp)
  566. return err
  567. }
  568. defer out.Close()
  569. if _, err = io.Copy(out, layer.Reader); err != nil {
  570. return err
  571. }
  572. } else {
  573. fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
  574. }
  575. }
  576. return nil
  577. }
  578. func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
  579. mp := ParseModelPath(name)
  580. manifest := ManifestV2{
  581. SchemaVersion: 2,
  582. MediaType: "application/vnd.docker.distribution.manifest.v2+json",
  583. Config: Layer{
  584. MediaType: cfg.MediaType,
  585. Size: cfg.Size,
  586. Digest: cfg.Digest,
  587. },
  588. Layers: layers,
  589. }
  590. manifestJSON, err := json.Marshal(manifest)
  591. if err != nil {
  592. return err
  593. }
  594. fp, err := mp.GetManifestPath(true)
  595. if err != nil {
  596. return err
  597. }
  598. return os.WriteFile(fp, manifestJSON, 0o644)
  599. }
  600. func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
  601. fp, err := GetBlobsPath(layer.Digest)
  602. if err != nil {
  603. return nil, err
  604. }
  605. file, err := os.Open(fp)
  606. if err != nil {
  607. return nil, fmt.Errorf("could not open blob: %w", err)
  608. }
  609. defer file.Close()
  610. newLayer, err := CreateLayer(file)
  611. if err != nil {
  612. return nil, err
  613. }
  614. newLayer.MediaType = layer.MediaType
  615. return newLayer, nil
  616. }
  617. // formatParams converts specified parameter options to their correct types
  618. func formatParams(params map[string][]string) (map[string]interface{}, error) {
  619. opts := api.Options{}
  620. valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
  621. typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
  622. // build map of json struct tags to their types
  623. jsonOpts := make(map[string]reflect.StructField)
  624. for _, field := range reflect.VisibleFields(typeOpts) {
  625. jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
  626. if jsonTag != "" {
  627. jsonOpts[jsonTag] = field
  628. }
  629. }
  630. out := make(map[string]interface{})
  631. // iterate params and set values based on json struct tags
  632. for key, vals := range params {
  633. if opt, ok := jsonOpts[key]; ok {
  634. field := valueOpts.FieldByName(opt.Name)
  635. if field.IsValid() && field.CanSet() {
  636. switch field.Kind() {
  637. case reflect.Float32:
  638. floatVal, err := strconv.ParseFloat(vals[0], 32)
  639. if err != nil {
  640. return nil, fmt.Errorf("invalid float value %s", vals)
  641. }
  642. out[key] = floatVal
  643. case reflect.Int:
  644. intVal, err := strconv.ParseInt(vals[0], 10, 0)
  645. if err != nil {
  646. return nil, fmt.Errorf("invalid int value %s", vals)
  647. }
  648. out[key] = intVal
  649. case reflect.Bool:
  650. boolVal, err := strconv.ParseBool(vals[0])
  651. if err != nil {
  652. return nil, fmt.Errorf("invalid bool value %s", vals)
  653. }
  654. out[key] = boolVal
  655. case reflect.String:
  656. out[key] = vals[0]
  657. case reflect.Slice:
  658. // TODO: only string slices are supported right now
  659. out[key] = vals
  660. default:
  661. return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
  662. }
  663. }
  664. }
  665. }
  666. return out, nil
  667. }
  668. func getLayerDigests(layers []*LayerReader) ([]string, error) {
  669. var digests []string
  670. for _, l := range layers {
  671. if l.Digest == "" {
  672. return nil, fmt.Errorf("layer is missing a digest")
  673. }
  674. digests = append(digests, l.Digest)
  675. }
  676. return digests, nil
  677. }
  678. // CreateLayer creates a Layer object from a given file
  679. func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
  680. digest, size := GetSHA256Digest(f)
  681. f.Seek(0, io.SeekStart)
  682. layer := &LayerReader{
  683. Layer: Layer{
  684. MediaType: "application/vnd.docker.image.rootfs.diff.tar",
  685. Digest: digest,
  686. Size: size,
  687. },
  688. Reader: f,
  689. }
  690. return layer, nil
  691. }
  692. func CopyModel(src, dest string) error {
  693. srcModelPath := ParseModelPath(src)
  694. srcPath, err := srcModelPath.GetManifestPath(false)
  695. if err != nil {
  696. return err
  697. }
  698. destModelPath := ParseModelPath(dest)
  699. destPath, err := destModelPath.GetManifestPath(true)
  700. if err != nil {
  701. return err
  702. }
  703. // copy the file
  704. input, err := os.ReadFile(srcPath)
  705. if err != nil {
  706. fmt.Println("Error reading file:", err)
  707. return err
  708. }
  709. err = os.WriteFile(destPath, input, 0o644)
  710. if err != nil {
  711. fmt.Println("Error reading file:", err)
  712. return err
  713. }
  714. return nil
  715. }
  716. func DeleteModel(name string) error {
  717. mp := ParseModelPath(name)
  718. manifest, _, err := GetManifest(mp)
  719. if err != nil {
  720. return err
  721. }
  722. deleteMap := make(map[string]bool)
  723. for _, layer := range manifest.Layers {
  724. deleteMap[layer.Digest] = true
  725. }
  726. deleteMap[manifest.Config.Digest] = true
  727. fp, err := GetManifestPath()
  728. if err != nil {
  729. return err
  730. }
  731. walkFunc := func(path string, info os.FileInfo, _ error) error {
  732. if info.IsDir() {
  733. return nil
  734. }
  735. dir, file := filepath.Split(path)
  736. dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
  737. tag := strings.Join([]string{dir, file}, ":")
  738. fmp := ParseModelPath(tag)
  739. // skip the manifest we're trying to delete
  740. if mp.GetFullTagname() == fmp.GetFullTagname() {
  741. return nil
  742. }
  743. // save (i.e. delete from the deleteMap) any files used in other manifests
  744. manifest, _, err := GetManifest(fmp)
  745. if err != nil {
  746. log.Printf("skipping file: %s", fp)
  747. return nil
  748. }
  749. for _, layer := range manifest.Layers {
  750. delete(deleteMap, layer.Digest)
  751. }
  752. delete(deleteMap, manifest.Config.Digest)
  753. return nil
  754. }
  755. if err := filepath.Walk(fp, walkFunc); err != nil {
  756. return err
  757. }
  758. // only delete the files which are still in the deleteMap
  759. for k, v := range deleteMap {
  760. if v {
  761. fp, err := GetBlobsPath(k)
  762. if err != nil {
  763. log.Printf("couldn't get file path for '%s': %v", k, err)
  764. continue
  765. }
  766. if err := os.Remove(fp); err != nil {
  767. log.Printf("couldn't remove file '%s': %v", fp, err)
  768. continue
  769. }
  770. }
  771. }
  772. fp, err = mp.GetManifestPath(false)
  773. if err != nil {
  774. return err
  775. }
  776. err = os.Remove(fp)
  777. if err != nil {
  778. log.Printf("couldn't remove manifest file '%s': %v", fp, err)
  779. return err
  780. }
  781. return nil
  782. }
  783. func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  784. mp := ParseModelPath(name)
  785. fn(api.ProgressResponse{Status: "retrieving manifest"})
  786. if mp.ProtocolScheme == "http" && !regOpts.Insecure {
  787. return fmt.Errorf("insecure protocol http")
  788. }
  789. manifest, _, err := GetManifest(mp)
  790. if err != nil {
  791. fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
  792. return err
  793. }
  794. var layers []*Layer
  795. layers = append(layers, manifest.Layers...)
  796. layers = append(layers, &manifest.Config)
  797. for _, layer := range layers {
  798. exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
  799. if err != nil {
  800. return err
  801. }
  802. if exists {
  803. fn(api.ProgressResponse{
  804. Status: "using existing layer",
  805. Digest: layer.Digest,
  806. Total: layer.Size,
  807. Completed: layer.Size,
  808. })
  809. log.Printf("Layer %s already exists", layer.Digest)
  810. continue
  811. }
  812. fn(api.ProgressResponse{
  813. Status: "starting upload",
  814. Digest: layer.Digest,
  815. Total: layer.Size,
  816. })
  817. location, err := startUpload(ctx, mp, layer, regOpts)
  818. if err != nil {
  819. log.Printf("couldn't start upload: %v", err)
  820. return err
  821. }
  822. if strings.HasPrefix(path.Base(location.Path), "sha256:") {
  823. layer.Digest = path.Base(location.Path)
  824. fn(api.ProgressResponse{
  825. Status: "using existing layer",
  826. Digest: layer.Digest,
  827. Total: layer.Size,
  828. Completed: layer.Size,
  829. })
  830. continue
  831. }
  832. if err := uploadBlobChunked(ctx, location, layer, regOpts, fn); err != nil {
  833. log.Printf("error uploading blob: %v", err)
  834. return err
  835. }
  836. }
  837. fn(api.ProgressResponse{Status: "pushing manifest"})
  838. requestURL := mp.BaseURL()
  839. requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
  840. manifestJSON, err := json.Marshal(manifest)
  841. if err != nil {
  842. return err
  843. }
  844. headers := make(http.Header)
  845. headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
  846. resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
  847. if err != nil {
  848. return err
  849. }
  850. defer resp.Body.Close()
  851. fn(api.ProgressResponse{Status: "success"})
  852. return nil
  853. }
  854. func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  855. mp := ParseModelPath(name)
  856. if mp.ProtocolScheme == "http" && !regOpts.Insecure {
  857. return fmt.Errorf("insecure protocol http")
  858. }
  859. fn(api.ProgressResponse{Status: "pulling manifest"})
  860. manifest, err := pullModelManifest(ctx, mp, regOpts)
  861. if err != nil {
  862. return fmt.Errorf("pull model manifest: %s", err)
  863. }
  864. var layers []*Layer
  865. layers = append(layers, manifest.Layers...)
  866. layers = append(layers, &manifest.Config)
  867. for _, layer := range layers {
  868. if err := downloadBlob(
  869. ctx,
  870. downloadOpts{
  871. mp: mp,
  872. digest: layer.Digest,
  873. regOpts: regOpts,
  874. fn: fn,
  875. }); err != nil {
  876. return err
  877. }
  878. }
  879. fn(api.ProgressResponse{Status: "verifying sha256 digest"})
  880. for _, layer := range layers {
  881. if err := verifyBlob(layer.Digest); err != nil {
  882. if errors.Is(err, errDigestMismatch) {
  883. // something went wrong, delete the blob
  884. fp, err := GetBlobsPath(layer.Digest)
  885. if err != nil {
  886. return err
  887. }
  888. if err := os.Remove(fp); err != nil {
  889. // log this, but return the original error
  890. log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
  891. }
  892. }
  893. return err
  894. }
  895. }
  896. fn(api.ProgressResponse{Status: "writing manifest"})
  897. manifestJSON, err := json.Marshal(manifest)
  898. if err != nil {
  899. return err
  900. }
  901. fp, err := mp.GetManifestPath(true)
  902. if err != nil {
  903. return err
  904. }
  905. err = os.WriteFile(fp, manifestJSON, 0o644)
  906. if err != nil {
  907. log.Printf("couldn't write to %s", fp)
  908. return err
  909. }
  910. fn(api.ProgressResponse{Status: "success"})
  911. return nil
  912. }
  913. func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
  914. requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
  915. headers := make(http.Header)
  916. headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
  917. resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
  918. if err != nil {
  919. log.Printf("couldn't get manifest: %v", err)
  920. return nil, err
  921. }
  922. defer resp.Body.Close()
  923. if resp.StatusCode >= http.StatusBadRequest {
  924. if resp.StatusCode == http.StatusNotFound {
  925. return nil, fmt.Errorf("model not found")
  926. }
  927. body, _ := io.ReadAll(resp.Body)
  928. return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
  929. }
  930. var m *ManifestV2
  931. if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
  932. return nil, err
  933. }
  934. return m, err
  935. }
  936. func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
  937. config.RootFS = RootFS{
  938. Type: "layers",
  939. DiffIDs: layers,
  940. }
  941. configJSON, err := json.Marshal(config)
  942. if err != nil {
  943. return nil, err
  944. }
  945. digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
  946. layer := &LayerReader{
  947. Layer: Layer{
  948. MediaType: "application/vnd.docker.container.image.v1+json",
  949. Digest: digest,
  950. Size: size,
  951. },
  952. Reader: bytes.NewBuffer(configJSON),
  953. }
  954. return layer, nil
  955. }
  // GetSHA256Digest reads r to the end and returns its SHA-256 digest formatted
  // as "sha256:<hex>" together with the number of bytes read. A read error
  // terminates the process via log.Fatal.
  func GetSHA256Digest(r io.Reader) (string, int) {
  	hasher := sha256.New()

  	written, err := io.Copy(hasher, r)
  	if err != nil {
  		log.Fatal(err)
  	}

  	sum := hasher.Sum(nil)
  	return fmt.Sprintf("sha256:%x", sum), int(written)
  }
  965. // Function to check if a blob already exists in the Docker registry
  966. func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
  967. requestURL := mp.BaseURL()
  968. requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
  969. resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
  970. if err != nil {
  971. log.Printf("couldn't check for blob: %v", err)
  972. return false, err
  973. }
  974. defer resp.Body.Close()
  975. // Check for success: If the blob exists, the Docker registry will respond with a 200 OK
  976. return resp.StatusCode < http.StatusBadRequest, nil
  977. }
  978. func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
  979. var status string
  980. for try := 0; try < MaxRetries; try++ {
  981. resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
  982. if err != nil {
  983. log.Printf("couldn't start upload: %v", err)
  984. return nil, err
  985. }
  986. status = resp.Status
  987. switch {
  988. case resp.StatusCode == http.StatusUnauthorized:
  989. auth := resp.Header.Get("www-authenticate")
  990. authRedir := ParseAuthRedirectString(auth)
  991. token, err := getAuthToken(ctx, authRedir, regOpts)
  992. if err != nil {
  993. return nil, err
  994. }
  995. regOpts.Token = token
  996. if body != nil {
  997. if _, err := body.Seek(0, io.SeekStart); err != nil {
  998. return nil, err
  999. }
  1000. }
  1001. continue
  1002. case resp.StatusCode >= http.StatusBadRequest:
  1003. body, _ := io.ReadAll(resp.Body)
  1004. return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
  1005. default:
  1006. return resp, nil
  1007. }
  1008. }
  1009. return nil, fmt.Errorf("max retry exceeded: %v", status)
  1010. }
  1011. func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
  1012. if requestURL.Scheme != "http" && regOpts.Insecure {
  1013. requestURL.Scheme = "http"
  1014. }
  1015. req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
  1016. if err != nil {
  1017. return nil, err
  1018. }
  1019. if headers != nil {
  1020. req.Header = headers
  1021. }
  1022. if regOpts.Token != "" {
  1023. req.Header.Set("Authorization", "Bearer "+regOpts.Token)
  1024. } else if regOpts.Username != "" && regOpts.Password != "" {
  1025. req.SetBasicAuth(regOpts.Username, regOpts.Password)
  1026. }
  1027. req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  1028. client := &http.Client{
  1029. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  1030. if len(via) >= 10 {
  1031. return fmt.Errorf("too many redirects")
  1032. }
  1033. log.Printf("redirected to: %s\n", req.URL)
  1034. return nil
  1035. },
  1036. }
  1037. resp, err := client.Do(req)
  1038. if err != nil {
  1039. return nil, err
  1040. }
  1041. return resp, nil
  1042. }
  // getValue extracts the quoted value of key from a comma-separated list of
  // key="value" pairs such as the parameters of a www-authenticate header.
  //
  // The value ends at the first double quote that is followed by a comma or by
  // the end of the string, so a quote embedded in the value is kept. An empty
  // string is returned when key is absent or when the header ends before a
  // value can start.
  func getValue(header, key string) string {
  	startIdx := strings.Index(header, key+"=")
  	if startIdx == -1 {
  		return ""
  	}

  	// Skip the key, the '=' and the opening quote.
  	startIdx += len(key) + 2
  	if startIdx > len(header) {
  		// Truncated input such as `realm=`: nothing follows the '=' and slicing
  		// below would be out of range.
  		return ""
  	}

  	endIdx := startIdx
  	for endIdx < len(header) {
  		if header[endIdx] == '"' {
  			if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue
  				endIdx++
  				continue
  			}
  			break
  		}
  		endIdx++
  	}

  	return header[startIdx:endIdx]
  }
  1063. func ParseAuthRedirectString(authStr string) AuthRedirect {
  1064. authStr = strings.TrimPrefix(authStr, "Bearer ")
  1065. return AuthRedirect{
  1066. Realm: getValue(authStr, "realm"),
  1067. Service: getValue(authStr, "service"),
  1068. Scope: getValue(authStr, "scope"),
  1069. }
  1070. }
  // errDigestMismatch is the sentinel wrapped by verifyBlob when a blob on disk
  // does not hash to its expected digest; match it with errors.Is.
  var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
  1072. func verifyBlob(digest string) error {
  1073. fp, err := GetBlobsPath(digest)
  1074. if err != nil {
  1075. return err
  1076. }
  1077. f, err := os.Open(fp)
  1078. if err != nil {
  1079. return err
  1080. }
  1081. defer f.Close()
  1082. fileDigest, _ := GetSHA256Digest(f)
  1083. if digest != fileDigest {
  1084. return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
  1085. }
  1086. return nil
  1087. }