images.go 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609
  1. package server
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/sha256"
  7. "encoding/hex"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "io"
  12. "log"
  13. "net/http"
  14. "net/url"
  15. "os"
  16. "path/filepath"
  17. "reflect"
  18. "runtime"
  19. "strconv"
  20. "strings"
  21. "text/template"
  22. "golang.org/x/exp/slices"
  23. "github.com/jmorganca/ollama/api"
  24. "github.com/jmorganca/ollama/llm"
  25. "github.com/jmorganca/ollama/parser"
  26. "github.com/jmorganca/ollama/vector"
  27. "github.com/jmorganca/ollama/version"
  28. )
  29. const MaxRetries = 3
  30. type RegistryOptions struct {
  31. Insecure bool
  32. Username string
  33. Password string
  34. Token string
  35. }
  36. type Model struct {
  37. Name string `json:"name"`
  38. ShortName string
  39. ModelPath string
  40. OriginalModel string
  41. AdapterPaths []string
  42. Template string
  43. System string
  44. License []string
  45. Digest string
  46. ConfigDigest string
  47. Options map[string]interface{}
  48. Embeddings []vector.Embedding
  49. }
  50. func (m *Model) ChatPrompt(messages []api.Message) (string, error) {
  51. tmpl, err := template.New("").Parse(m.Template)
  52. if err != nil {
  53. return "", err
  54. }
  55. var vars struct {
  56. System string
  57. Prompt string
  58. First bool
  59. }
  60. vars.First = true
  61. var sb strings.Builder
  62. flush := func() {
  63. tmpl.Execute(&sb, vars)
  64. vars.System = ""
  65. vars.Prompt = ""
  66. }
  67. // build the chat history from messages
  68. for _, m := range messages {
  69. if m.Role == "system" {
  70. if vars.System != "" {
  71. flush()
  72. }
  73. vars.System = m.Content
  74. }
  75. if m.Role == "user" {
  76. if vars.Prompt != "" {
  77. flush()
  78. }
  79. vars.Prompt = m.Content
  80. }
  81. if m.Role == "assistant" {
  82. flush()
  83. sb.Write([]byte(m.Content))
  84. }
  85. }
  86. flush()
  87. return sb.String(), nil
  88. }
  89. func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
  90. t := m.Template
  91. if request.Template != "" {
  92. t = request.Template
  93. }
  94. tmpl, err := template.New("").Parse(t)
  95. if err != nil {
  96. return "", err
  97. }
  98. var vars struct {
  99. First bool
  100. System string
  101. Prompt string
  102. Embed string
  103. // deprecated: versions <= 0.0.7 used this to omit the system prompt
  104. Context []int
  105. }
  106. vars.First = len(request.Context) == 0
  107. vars.System = m.System
  108. vars.Prompt = request.Prompt
  109. vars.Context = request.Context
  110. vars.Embed = embedding
  111. if request.System != "" {
  112. vars.System = request.System
  113. }
  114. var sb strings.Builder
  115. if err := tmpl.Execute(&sb, vars); err != nil {
  116. return "", err
  117. }
  118. return sb.String(), nil
  119. }
// ManifestV2 is a model manifest as stored on disk: one config blob plus an
// ordered list of content layers.
//
// CreateManifest serializes it with json.Marshal, so field order and JSON tags
// fix the exact bytes, and GetManifest reports the SHA-256 of those bytes as
// the model digest.
type ManifestV2 struct {
	// SchemaVersion is written as 2 by CreateManifest.
	SchemaVersion int `json:"schemaVersion"`
	// MediaType is the manifest's media type (CreateManifest writes the Docker
	// distribution manifest v2 type).
	MediaType string `json:"mediaType"`
	// Config points at the config blob, a serialized ConfigV2.
	Config Layer `json:"config"`
	// Layers lists the model's content layers in order.
	Layers []*Layer `json:"layers"`
}
// Layer describes one blob referenced by a manifest.
type Layer struct {
	// MediaType says what the blob holds, e.g.
	// "application/vnd.ollama.image.model" or "application/vnd.ollama.image.template".
	MediaType string `json:"mediaType"`
	// Digest names the blob; GetBlobsPath maps it to a file path.
	Digest string `json:"digest"`
	// Size is the size reported by GetSHA256Digest for the content
	// (presumably bytes — confirm there).
	Size int64 `json:"size"`
	// From is the namespace/repository a layer was inherited from when a model
	// is created from another model; omitted from JSON when empty.
	From string `json:"from,omitempty"`
}
// LayerReader pairs a layer's manifest metadata with a reader over its
// content; SaveLayers copies that content into the blob store.
type LayerReader struct {
	Layer
	io.Reader
}
// ConfigV2 is the config blob referenced by ManifestV2.Config.
type ConfigV2 struct {
	// Model metadata. CreateModel fills these from the GGML header of a local
	// model file, or copies them from the source model's config.
	ModelFormat string `json:"model_format"`
	ModelFamily string `json:"model_family"`
	ModelType   string `json:"model_type"`
	FileType    string `json:"file_type"`
	// RootFS lists the model's layers (built by createConfigLayer).
	RootFS RootFS `json:"rootfs"`

	// required by spec
	Architecture string `json:"architecture"`
	OS           string `json:"os"`
}
// RootFS is the layer listing embedded in ConfigV2.
type RootFS struct {
	// Type of the root filesystem description; its value is set outside this
	// section — confirm in createConfigLayer.
	Type string `json:"type"`
	// DiffIDs presumably holds the layer digests passed to createConfigLayer —
	// confirm there.
	DiffIDs []string `json:"diff_ids"`
}
  150. func (m *ManifestV2) GetTotalSize() (total int64) {
  151. for _, layer := range m.Layers {
  152. total += layer.Size
  153. }
  154. total += m.Config.Size
  155. return total
  156. }
  157. func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
  158. fp, err := mp.GetManifestPath(false)
  159. if err != nil {
  160. return nil, "", err
  161. }
  162. if _, err = os.Stat(fp); err != nil {
  163. return nil, "", err
  164. }
  165. var manifest *ManifestV2
  166. bts, err := os.ReadFile(fp)
  167. if err != nil {
  168. return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
  169. }
  170. shaSum := sha256.Sum256(bts)
  171. shaStr := hex.EncodeToString(shaSum[:])
  172. if err := json.Unmarshal(bts, &manifest); err != nil {
  173. return nil, "", err
  174. }
  175. return manifest, shaStr, nil
  176. }
  177. func GetModel(name string) (*Model, error) {
  178. mp := ParseModelPath(name)
  179. manifest, digest, err := GetManifest(mp)
  180. if err != nil {
  181. return nil, err
  182. }
  183. model := &Model{
  184. Name: mp.GetFullTagname(),
  185. ShortName: mp.GetShortTagname(),
  186. Digest: digest,
  187. ConfigDigest: manifest.Config.Digest,
  188. Template: "{{ .Prompt }}",
  189. License: []string{},
  190. }
  191. for _, layer := range manifest.Layers {
  192. filename, err := GetBlobsPath(layer.Digest)
  193. if err != nil {
  194. return nil, err
  195. }
  196. switch layer.MediaType {
  197. case "application/vnd.ollama.image.model":
  198. model.ModelPath = filename
  199. model.OriginalModel = layer.From
  200. case "application/vnd.ollama.image.embed":
  201. file, err := os.Open(filename)
  202. if err != nil {
  203. return nil, fmt.Errorf("failed to open file: %s", filename)
  204. }
  205. defer file.Close()
  206. if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
  207. return nil, err
  208. }
  209. case "application/vnd.ollama.image.adapter":
  210. model.AdapterPaths = append(model.AdapterPaths, filename)
  211. case "application/vnd.ollama.image.template":
  212. bts, err := os.ReadFile(filename)
  213. if err != nil {
  214. return nil, err
  215. }
  216. model.Template = string(bts)
  217. case "application/vnd.ollama.image.system":
  218. bts, err := os.ReadFile(filename)
  219. if err != nil {
  220. return nil, err
  221. }
  222. model.System = string(bts)
  223. case "application/vnd.ollama.image.prompt":
  224. bts, err := os.ReadFile(filename)
  225. if err != nil {
  226. return nil, err
  227. }
  228. model.Template = string(bts)
  229. case "application/vnd.ollama.image.params":
  230. params, err := os.Open(filename)
  231. if err != nil {
  232. return nil, err
  233. }
  234. defer params.Close()
  235. // parse model options parameters into a map so that we can see which fields have been specified explicitly
  236. if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
  237. return nil, err
  238. }
  239. case "application/vnd.ollama.image.license":
  240. bts, err := os.ReadFile(filename)
  241. if err != nil {
  242. return nil, err
  243. }
  244. model.License = append(model.License, string(bts))
  245. }
  246. }
  247. return model, nil
  248. }
  249. func filenameWithPath(path, f string) (string, error) {
  250. // if filePath starts with ~/, replace it with the user's home directory.
  251. if strings.HasPrefix(f, fmt.Sprintf("~%s", string(os.PathSeparator))) {
  252. parts := strings.Split(f, string(os.PathSeparator))
  253. home, err := os.UserHomeDir()
  254. if err != nil {
  255. return "", fmt.Errorf("failed to open file: %v", err)
  256. }
  257. f = filepath.Join(home, filepath.Join(parts[1:]...))
  258. }
  259. // if filePath is not an absolute path, make it relative to the modelfile path
  260. if !filepath.IsAbs(f) {
  261. f = filepath.Join(filepath.Dir(path), f)
  262. }
  263. return f, nil
  264. }
  265. func CreateModel(ctx context.Context, workDir, name string, path string, fn func(resp api.ProgressResponse)) error {
  266. mp := ParseModelPath(name)
  267. var manifest *ManifestV2
  268. var err error
  269. var noprune string
  270. // build deleteMap to prune unused layers
  271. deleteMap := make(map[string]bool)
  272. if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
  273. manifest, _, err = GetManifest(mp)
  274. if err != nil && !errors.Is(err, os.ErrNotExist) {
  275. return err
  276. }
  277. if manifest != nil {
  278. for _, l := range manifest.Layers {
  279. deleteMap[l.Digest] = true
  280. }
  281. deleteMap[manifest.Config.Digest] = true
  282. }
  283. }
  284. mf, err := os.Open(path)
  285. if err != nil {
  286. fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
  287. return fmt.Errorf("failed to open file: %w", err)
  288. }
  289. defer mf.Close()
  290. fn(api.ProgressResponse{Status: "parsing modelfile"})
  291. commands, err := parser.Parse(mf)
  292. if err != nil {
  293. return err
  294. }
  295. config := ConfigV2{
  296. Architecture: "amd64",
  297. OS: "linux",
  298. }
  299. var layers []*LayerReader
  300. params := make(map[string][]string)
  301. var sourceParams map[string]any
  302. embed := EmbeddingParams{fn: fn}
  303. for _, c := range commands {
  304. log.Printf("[%s] - %s\n", c.Name, c.Args)
  305. switch c.Name {
  306. case "model":
  307. fn(api.ProgressResponse{Status: "looking for model"})
  308. embed.model = c.Args
  309. mp := ParseModelPath(c.Args)
  310. mf, _, err := GetManifest(mp)
  311. if err != nil {
  312. modelFile, err := filenameWithPath(path, c.Args)
  313. if err != nil {
  314. return err
  315. }
  316. if _, err := os.Stat(modelFile); err != nil {
  317. // the model file does not exist, try pulling it
  318. if errors.Is(err, os.ErrNotExist) {
  319. fn(api.ProgressResponse{Status: "pulling model file"})
  320. if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
  321. return err
  322. }
  323. mf, _, err = GetManifest(mp)
  324. if err != nil {
  325. return fmt.Errorf("failed to open file after pull: %v", err)
  326. }
  327. } else {
  328. return err
  329. }
  330. } else {
  331. embed.model = modelFile
  332. // create a model from this specified file
  333. fn(api.ProgressResponse{Status: "creating model layer"})
  334. file, err := os.Open(modelFile)
  335. if err != nil {
  336. return fmt.Errorf("failed to open file: %v", err)
  337. }
  338. defer file.Close()
  339. ggml, err := llm.DecodeGGML(file)
  340. if err != nil {
  341. return err
  342. }
  343. config.ModelFormat = ggml.Name()
  344. config.ModelFamily = ggml.ModelFamily()
  345. config.ModelType = ggml.ModelType()
  346. config.FileType = ggml.FileType()
  347. // reset the file
  348. file.Seek(0, io.SeekStart)
  349. l, err := CreateLayer(file)
  350. if err != nil {
  351. return fmt.Errorf("failed to create layer: %v", err)
  352. }
  353. l.MediaType = "application/vnd.ollama.image.model"
  354. layers = append(layers, l)
  355. }
  356. }
  357. if mf != nil {
  358. sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
  359. if err != nil {
  360. return err
  361. }
  362. sourceBlob, err := os.Open(sourceBlobPath)
  363. if err != nil {
  364. return err
  365. }
  366. defer sourceBlob.Close()
  367. var source ConfigV2
  368. if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
  369. return err
  370. }
  371. // copy the model metadata
  372. config.ModelFamily = source.ModelFamily
  373. config.ModelType = source.ModelType
  374. config.ModelFormat = source.ModelFormat
  375. config.FileType = source.FileType
  376. for _, l := range mf.Layers {
  377. if l.MediaType == "application/vnd.ollama.image.params" {
  378. sourceParamsBlobPath, err := GetBlobsPath(l.Digest)
  379. if err != nil {
  380. return err
  381. }
  382. sourceParamsBlob, err := os.Open(sourceParamsBlobPath)
  383. if err != nil {
  384. return err
  385. }
  386. defer sourceParamsBlob.Close()
  387. if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil {
  388. return err
  389. }
  390. }
  391. newLayer, err := GetLayerWithBufferFromLayer(l)
  392. if err != nil {
  393. return err
  394. }
  395. newLayer.From = mp.GetNamespaceRepository()
  396. layers = append(layers, newLayer)
  397. }
  398. }
  399. case "embed":
  400. embedFilePath, err := filenameWithPath(path, c.Args)
  401. if err != nil {
  402. return err
  403. }
  404. embed.files = append(embed.files, embedFilePath)
  405. case "adapter":
  406. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  407. fp, err := filenameWithPath(path, c.Args)
  408. if err != nil {
  409. return err
  410. }
  411. // create a model from this specified file
  412. fn(api.ProgressResponse{Status: "creating model layer"})
  413. file, err := os.Open(fp)
  414. if err != nil {
  415. return fmt.Errorf("failed to open file: %v", err)
  416. }
  417. defer file.Close()
  418. l, err := CreateLayer(file)
  419. if err != nil {
  420. return fmt.Errorf("failed to create layer: %v", err)
  421. }
  422. l.MediaType = "application/vnd.ollama.image.adapter"
  423. layers = append(layers, l)
  424. case "license":
  425. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  426. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  427. layer, err := CreateLayer(strings.NewReader(c.Args))
  428. if err != nil {
  429. return err
  430. }
  431. if layer.Size > 0 {
  432. layer.MediaType = mediaType
  433. layers = append(layers, layer)
  434. }
  435. case "template", "system", "prompt":
  436. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  437. // remove the layer if one exists
  438. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  439. layers = removeLayerFromLayers(layers, mediaType)
  440. layer, err := CreateLayer(strings.NewReader(c.Args))
  441. if err != nil {
  442. return err
  443. }
  444. if layer.Size > 0 {
  445. layer.MediaType = mediaType
  446. layers = append(layers, layer)
  447. }
  448. default:
  449. // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
  450. params[c.Name] = append(params[c.Name], c.Args)
  451. }
  452. }
  453. // Create a single layer for the parameters
  454. if len(params) > 0 {
  455. fn(api.ProgressResponse{Status: "creating parameter layer"})
  456. layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
  457. formattedParams, err := formatParams(params)
  458. if err != nil {
  459. return fmt.Errorf("couldn't create params json: %v", err)
  460. }
  461. for k, v := range sourceParams {
  462. if _, ok := formattedParams[k]; !ok {
  463. formattedParams[k] = v
  464. }
  465. }
  466. if config.ModelType == "65B" {
  467. if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 {
  468. config.ModelType = "70B"
  469. }
  470. }
  471. bts, err := json.Marshal(formattedParams)
  472. if err != nil {
  473. return err
  474. }
  475. l, err := CreateLayer(bytes.NewReader(bts))
  476. if err != nil {
  477. return fmt.Errorf("failed to create layer: %v", err)
  478. }
  479. l.MediaType = "application/vnd.ollama.image.params"
  480. layers = append(layers, l)
  481. // apply these parameters to the embedding options, in case embeddings need to be generated using this model
  482. embed.opts = formattedParams
  483. }
  484. // generate the embedding layers
  485. embeddingLayers, err := embeddingLayers(workDir, embed)
  486. if err != nil {
  487. return err
  488. }
  489. layers = append(layers, embeddingLayers...)
  490. digests, err := getLayerDigests(layers)
  491. if err != nil {
  492. return err
  493. }
  494. var manifestLayers []*Layer
  495. for _, l := range layers {
  496. manifestLayers = append(manifestLayers, &l.Layer)
  497. delete(deleteMap, l.Layer.Digest)
  498. }
  499. // Create a layer for the config object
  500. fn(api.ProgressResponse{Status: "creating config layer"})
  501. cfg, err := createConfigLayer(config, digests)
  502. if err != nil {
  503. return err
  504. }
  505. layers = append(layers, cfg)
  506. delete(deleteMap, cfg.Layer.Digest)
  507. if err := SaveLayers(layers, fn, false); err != nil {
  508. return err
  509. }
  510. // Create the manifest
  511. fn(api.ProgressResponse{Status: "writing manifest"})
  512. err = CreateManifest(name, cfg, manifestLayers)
  513. if err != nil {
  514. return err
  515. }
  516. if noprune == "" {
  517. fn(api.ProgressResponse{Status: "removing any unused layers"})
  518. err = deleteUnusedLayers(nil, deleteMap, false)
  519. if err != nil {
  520. return err
  521. }
  522. }
  523. fn(api.ProgressResponse{Status: "success"})
  524. return nil
  525. }
  526. type EmbeddingParams struct {
  527. model string
  528. opts map[string]interface{}
  529. files []string // paths to files to embed
  530. fn func(resp api.ProgressResponse)
  531. }
  532. // embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
  533. func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error) {
  534. layers := []*LayerReader{}
  535. if len(e.files) > 0 {
  536. // check if the model is a file path or a model name
  537. model, err := GetModel(e.model)
  538. if err != nil {
  539. if !strings.Contains(err.Error(), "couldn't open file") {
  540. return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
  541. }
  542. // the model may be a file path, create a model from this file
  543. model = &Model{ModelPath: e.model}
  544. }
  545. if err := load(context.Background(), workDir, model, e.opts, defaultSessionDuration); err != nil {
  546. return nil, fmt.Errorf("load model to generate embeddings: %v", err)
  547. }
  548. // this will be used to check if we already have embeddings for a file
  549. modelInfo, err := os.Stat(model.ModelPath)
  550. if err != nil {
  551. return nil, fmt.Errorf("failed to get model file info: %v", err)
  552. }
  553. addedFiles := make(map[string]bool) // keep track of files that have already been added
  554. for _, filePattern := range e.files {
  555. matchingFiles, err := filepath.Glob(filePattern)
  556. if err != nil {
  557. return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
  558. }
  559. for _, filePath := range matchingFiles {
  560. if addedFiles[filePath] {
  561. continue
  562. }
  563. addedFiles[filePath] = true
  564. // check if we already have embeddings for this file path
  565. layerIdentifier := fmt.Sprintf("%s:%s:%s:%d", filePath, e.model, modelInfo.ModTime().Format("2006-01-02 15:04:05"), modelInfo.Size())
  566. digest, _ := GetSHA256Digest(strings.NewReader(layerIdentifier))
  567. existing, err := existingFileEmbeddings(digest)
  568. if err != nil {
  569. return nil, fmt.Errorf("failed to check existing embeddings for file %s: %v", filePath, err)
  570. }
  571. // TODO: check file type
  572. f, err := os.Open(filePath)
  573. if err != nil {
  574. return nil, fmt.Errorf("could not open embed file: %w", err)
  575. }
  576. scanner := bufio.NewScanner(f)
  577. scanner.Split(bufio.ScanLines)
  578. data := []string{}
  579. for scanner.Scan() {
  580. data = append(data, scanner.Text())
  581. }
  582. f.Close()
  583. // the digest of the file is set here so that the client knows a new operation is in progress
  584. fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
  585. embeddings := []vector.Embedding{}
  586. for i, d := range data {
  587. if strings.TrimSpace(d) == "" {
  588. continue
  589. }
  590. e.fn(api.ProgressResponse{
  591. Status: fmt.Sprintf("creating embeddings for file %s", filePath),
  592. Digest: fileDigest,
  593. Total: int64(len(data) - 1),
  594. Completed: int64(i),
  595. })
  596. if len(existing[d]) > 0 {
  597. // already have an embedding for this line
  598. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
  599. continue
  600. }
  601. embed, err := loaded.llm.Embedding(context.Background(), d)
  602. if err != nil {
  603. log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
  604. continue
  605. }
  606. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
  607. }
  608. b, err := json.Marshal(embeddings)
  609. if err != nil {
  610. return nil, fmt.Errorf("failed to encode embeddings: %w", err)
  611. }
  612. r := bytes.NewReader(b)
  613. layer := &LayerReader{
  614. Layer: Layer{
  615. MediaType: "application/vnd.ollama.image.embed",
  616. Digest: digest,
  617. Size: r.Size(),
  618. },
  619. Reader: r,
  620. }
  621. layers = append(layers, layer)
  622. }
  623. }
  624. }
  625. return layers, nil
  626. }
  627. // existingFileEmbeddings checks if we already have embeddings for a file and loads them into a look-up map
  628. func existingFileEmbeddings(digest string) (map[string][]float64, error) {
  629. path, err := GetBlobsPath(digest)
  630. if err != nil {
  631. return nil, fmt.Errorf("embeddings blobs path: %w", err)
  632. }
  633. existingFileEmbeddings := make(map[string][]float64)
  634. if _, err := os.Stat(path); err == nil {
  635. // already have some embeddings for this file, load embeddings previously generated
  636. file, err := os.Open(path)
  637. if err != nil {
  638. return nil, fmt.Errorf("failed to open existing embedding file: %s", err)
  639. }
  640. defer file.Close()
  641. existing := []vector.Embedding{}
  642. if err = json.NewDecoder(file).Decode(&existing); err != nil {
  643. return nil, err
  644. }
  645. for _, e := range existing {
  646. existingFileEmbeddings[e.Data] = e.Vector
  647. }
  648. }
  649. return existingFileEmbeddings, nil
  650. }
  651. func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
  652. return slices.DeleteFunc(layers, func(layer *LayerReader) bool {
  653. return layer.MediaType == mediaType
  654. })
  655. }
  656. func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
  657. // Write each of the layers to disk
  658. for _, layer := range layers {
  659. fp, err := GetBlobsPath(layer.Digest)
  660. if err != nil {
  661. return err
  662. }
  663. _, err = os.Stat(fp)
  664. // note: embed layers are always written since their digest doesnt indicate anything about the contents
  665. if os.IsNotExist(err) || force || layer.MediaType == "application/vnd.ollama.image.embed" {
  666. fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
  667. out, err := os.Create(fp)
  668. if err != nil {
  669. log.Printf("couldn't create %s", fp)
  670. return err
  671. }
  672. defer out.Close()
  673. if _, err = io.Copy(out, layer.Reader); err != nil {
  674. return err
  675. }
  676. } else {
  677. fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
  678. }
  679. }
  680. return nil
  681. }
  682. func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
  683. mp := ParseModelPath(name)
  684. manifest := ManifestV2{
  685. SchemaVersion: 2,
  686. MediaType: "application/vnd.docker.distribution.manifest.v2+json",
  687. Config: Layer{
  688. MediaType: cfg.MediaType,
  689. Size: cfg.Size,
  690. Digest: cfg.Digest,
  691. },
  692. Layers: layers,
  693. }
  694. manifestJSON, err := json.Marshal(manifest)
  695. if err != nil {
  696. return err
  697. }
  698. fp, err := mp.GetManifestPath(true)
  699. if err != nil {
  700. return err
  701. }
  702. return os.WriteFile(fp, manifestJSON, 0o644)
  703. }
  704. func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
  705. fp, err := GetBlobsPath(layer.Digest)
  706. if err != nil {
  707. return nil, err
  708. }
  709. file, err := os.Open(fp)
  710. if err != nil {
  711. return nil, fmt.Errorf("could not open blob: %w", err)
  712. }
  713. defer file.Close()
  714. newLayer, err := CreateLayer(file)
  715. if err != nil {
  716. return nil, err
  717. }
  718. newLayer.MediaType = layer.MediaType
  719. return newLayer, nil
  720. }
  721. // formatParams converts specified parameter options to their correct types
  722. func formatParams(params map[string][]string) (map[string]interface{}, error) {
  723. opts := api.Options{}
  724. valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
  725. typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
  726. // build map of json struct tags to their types
  727. jsonOpts := make(map[string]reflect.StructField)
  728. for _, field := range reflect.VisibleFields(typeOpts) {
  729. jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
  730. if jsonTag != "" {
  731. jsonOpts[jsonTag] = field
  732. }
  733. }
  734. out := make(map[string]interface{})
  735. // iterate params and set values based on json struct tags
  736. for key, vals := range params {
  737. if opt, ok := jsonOpts[key]; ok {
  738. field := valueOpts.FieldByName(opt.Name)
  739. if field.IsValid() && field.CanSet() {
  740. switch field.Kind() {
  741. case reflect.Float32:
  742. floatVal, err := strconv.ParseFloat(vals[0], 32)
  743. if err != nil {
  744. return nil, fmt.Errorf("invalid float value %s", vals)
  745. }
  746. out[key] = float32(floatVal)
  747. case reflect.Int:
  748. intVal, err := strconv.ParseInt(vals[0], 10, 64)
  749. if err != nil {
  750. return nil, fmt.Errorf("invalid int value %s", vals)
  751. }
  752. out[key] = int(intVal)
  753. case reflect.Bool:
  754. boolVal, err := strconv.ParseBool(vals[0])
  755. if err != nil {
  756. return nil, fmt.Errorf("invalid bool value %s", vals)
  757. }
  758. out[key] = boolVal
  759. case reflect.String:
  760. out[key] = vals[0]
  761. case reflect.Slice:
  762. // TODO: only string slices are supported right now
  763. out[key] = vals
  764. default:
  765. return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
  766. }
  767. }
  768. }
  769. }
  770. return out, nil
  771. }
  772. func getLayerDigests(layers []*LayerReader) ([]string, error) {
  773. var digests []string
  774. for _, l := range layers {
  775. if l.Digest == "" {
  776. return nil, fmt.Errorf("layer is missing a digest")
  777. }
  778. digests = append(digests, l.Digest)
  779. }
  780. return digests, nil
  781. }
  782. // CreateLayer creates a Layer object from a given file
  783. func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
  784. digest, size := GetSHA256Digest(f)
  785. f.Seek(0, io.SeekStart)
  786. layer := &LayerReader{
  787. Layer: Layer{
  788. MediaType: "application/vnd.docker.image.rootfs.diff.tar",
  789. Digest: digest,
  790. Size: size,
  791. },
  792. Reader: f,
  793. }
  794. return layer, nil
  795. }
  796. func CopyModel(src, dest string) error {
  797. srcModelPath := ParseModelPath(src)
  798. srcPath, err := srcModelPath.GetManifestPath(false)
  799. if err != nil {
  800. return err
  801. }
  802. destModelPath := ParseModelPath(dest)
  803. destPath, err := destModelPath.GetManifestPath(true)
  804. if err != nil {
  805. return err
  806. }
  807. // copy the file
  808. input, err := os.ReadFile(srcPath)
  809. if err != nil {
  810. fmt.Println("Error reading file:", err)
  811. return err
  812. }
  813. err = os.WriteFile(destPath, input, 0o644)
  814. if err != nil {
  815. fmt.Println("Error reading file:", err)
  816. return err
  817. }
  818. return nil
  819. }
  820. func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dryRun bool) error {
  821. fp, err := GetManifestPath()
  822. if err != nil {
  823. return err
  824. }
  825. walkFunc := func(path string, info os.FileInfo, _ error) error {
  826. if info.IsDir() {
  827. return nil
  828. }
  829. dir, file := filepath.Split(path)
  830. dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
  831. tag := strings.Join([]string{dir, file}, ":")
  832. fmp := ParseModelPath(tag)
  833. // skip the manifest we're trying to delete
  834. if skipModelPath != nil && skipModelPath.GetFullTagname() == fmp.GetFullTagname() {
  835. return nil
  836. }
  837. // save (i.e. delete from the deleteMap) any files used in other manifests
  838. manifest, _, err := GetManifest(fmp)
  839. if err != nil {
  840. return nil
  841. }
  842. for _, layer := range manifest.Layers {
  843. delete(deleteMap, layer.Digest)
  844. }
  845. delete(deleteMap, manifest.Config.Digest)
  846. return nil
  847. }
  848. if err := filepath.Walk(fp, walkFunc); err != nil {
  849. return err
  850. }
  851. // only delete the files which are still in the deleteMap
  852. for k, v := range deleteMap {
  853. if v {
  854. fp, err := GetBlobsPath(k)
  855. if err != nil {
  856. log.Printf("couldn't get file path for '%s': %v", k, err)
  857. continue
  858. }
  859. if !dryRun {
  860. if err := os.Remove(fp); err != nil {
  861. log.Printf("couldn't remove file '%s': %v", fp, err)
  862. continue
  863. }
  864. } else {
  865. log.Printf("wanted to remove: %s", fp)
  866. }
  867. }
  868. }
  869. return nil
  870. }
  871. func PruneLayers() error {
  872. deleteMap := make(map[string]bool)
  873. p, err := GetBlobsPath("")
  874. if err != nil {
  875. return err
  876. }
  877. blobs, err := os.ReadDir(p)
  878. if err != nil {
  879. log.Printf("couldn't read dir '%s': %v", p, err)
  880. return err
  881. }
  882. for _, blob := range blobs {
  883. name := blob.Name()
  884. if runtime.GOOS == "windows" {
  885. name = strings.ReplaceAll(name, "-", ":")
  886. }
  887. deleteMap[name] = true
  888. }
  889. log.Printf("total blobs: %d", len(deleteMap))
  890. err = deleteUnusedLayers(nil, deleteMap, false)
  891. if err != nil {
  892. return err
  893. }
  894. log.Printf("total unused blobs removed: %d", len(deleteMap))
  895. return nil
  896. }
  // PruneDirectory removes path if it is a directory that is empty once its
// subdirectories have been pruned recursively.
//
// Regular files and symlinks (which are never followed) are left in place, so
// any directory that contains one survives. A missing path, or any stat, read
// or remove failure, is returned as an error and stops the walk.
func PruneDirectory(path string) error {
	info, err := os.Lstat(path)
	if err != nil {
		return err
	}

	// Only real directories are candidates; files and symlinks are untouched.
	if !info.IsDir() || info.Mode()&os.ModeSymlink != 0 {
		return nil
	}

	children, err := os.ReadDir(path)
	if err != nil {
		return err
	}
	for _, child := range children {
		if err := PruneDirectory(filepath.Join(path, child.Name())); err != nil {
			return err
		}
	}

	// Re-read: the recursion above may have emptied this directory.
	remaining, err := os.ReadDir(path)
	if err != nil {
		return err
	}
	if len(remaining) != 0 {
		return nil
	}
	return os.Remove(path)
}
  923. func DeleteModel(name string) error {
  924. mp := ParseModelPath(name)
  925. manifest, _, err := GetManifest(mp)
  926. if err != nil {
  927. return err
  928. }
  929. deleteMap := make(map[string]bool)
  930. for _, layer := range manifest.Layers {
  931. deleteMap[layer.Digest] = true
  932. }
  933. deleteMap[manifest.Config.Digest] = true
  934. err = deleteUnusedLayers(&mp, deleteMap, false)
  935. if err != nil {
  936. return err
  937. }
  938. fp, err := mp.GetManifestPath(false)
  939. if err != nil {
  940. return err
  941. }
  942. err = os.Remove(fp)
  943. if err != nil {
  944. log.Printf("couldn't remove manifest file '%s': %v", fp, err)
  945. return err
  946. }
  947. return nil
  948. }
  949. func ShowModelfile(model *Model) (string, error) {
  950. type modelTemplate struct {
  951. *Model
  952. From string
  953. Params string
  954. }
  955. var params []string
  956. for k, v := range model.Options {
  957. switch val := v.(type) {
  958. case string:
  959. params = append(params, fmt.Sprintf("PARAMETER %s %s", k, val))
  960. case int:
  961. params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.Itoa(val)))
  962. case float64:
  963. params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatFloat(val, 'f', 0, 64)))
  964. case bool:
  965. params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatBool(val)))
  966. case []interface{}:
  967. for _, nv := range val {
  968. switch nval := nv.(type) {
  969. case string:
  970. params = append(params, fmt.Sprintf("PARAMETER %s %s", k, nval))
  971. case int:
  972. params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.Itoa(nval)))
  973. case float64:
  974. params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatFloat(nval, 'f', 0, 64)))
  975. case bool:
  976. params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatBool(nval)))
  977. default:
  978. log.Printf("unknown type: %s", reflect.TypeOf(nv).String())
  979. }
  980. }
  981. default:
  982. log.Printf("unknown type: %s", reflect.TypeOf(v).String())
  983. }
  984. }
  985. mt := modelTemplate{
  986. Model: model,
  987. From: model.OriginalModel,
  988. Params: strings.Join(params, "\n"),
  989. }
  990. if mt.From == "" {
  991. mt.From = model.ModelPath
  992. }
  993. modelFile := `# Modelfile generated by "ollama show"
  994. # To build a new Modelfile based on this one, replace the FROM line with:
  995. # FROM {{ .ShortName }}
  996. FROM {{ .From }}
  997. TEMPLATE """{{ .Template }}"""
  998. SYSTEM """{{ .System }}"""
  999. {{ .Params }}
  1000. `
  1001. for _, l := range mt.Model.AdapterPaths {
  1002. modelFile += fmt.Sprintf("ADAPTER %s\n", l)
  1003. }
  1004. tmpl, err := template.New("").Parse(modelFile)
  1005. if err != nil {
  1006. log.Printf("error parsing template: %q", err)
  1007. return "", err
  1008. }
  1009. var buf bytes.Buffer
  1010. if err = tmpl.Execute(&buf, mt); err != nil {
  1011. log.Printf("error executing template: %q", err)
  1012. return "", err
  1013. }
  1014. return buf.String(), nil
  1015. }
  1016. func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  1017. mp := ParseModelPath(name)
  1018. fn(api.ProgressResponse{Status: "retrieving manifest"})
  1019. if mp.ProtocolScheme == "http" && !regOpts.Insecure {
  1020. return fmt.Errorf("insecure protocol http")
  1021. }
  1022. manifest, _, err := GetManifest(mp)
  1023. if err != nil {
  1024. fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
  1025. return err
  1026. }
  1027. var layers []*Layer
  1028. layers = append(layers, manifest.Layers...)
  1029. layers = append(layers, &manifest.Config)
  1030. for _, layer := range layers {
  1031. exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
  1032. if err != nil {
  1033. return err
  1034. }
  1035. if exists {
  1036. fn(api.ProgressResponse{
  1037. Status: "using existing layer",
  1038. Digest: layer.Digest,
  1039. Total: layer.Size,
  1040. Completed: layer.Size,
  1041. })
  1042. log.Printf("Layer %s already exists", layer.Digest)
  1043. continue
  1044. }
  1045. fn(api.ProgressResponse{
  1046. Status: "starting upload",
  1047. Digest: layer.Digest,
  1048. Total: layer.Size,
  1049. })
  1050. location, chunkSize, err := startUpload(ctx, mp, layer, regOpts)
  1051. if err != nil {
  1052. log.Printf("couldn't start upload: %v", err)
  1053. return err
  1054. }
  1055. if strings.HasPrefix(filepath.Base(location.Path), "sha256:") {
  1056. layer.Digest = filepath.Base(location.Path)
  1057. fn(api.ProgressResponse{
  1058. Status: "using existing layer",
  1059. Digest: layer.Digest,
  1060. Total: layer.Size,
  1061. Completed: layer.Size,
  1062. })
  1063. continue
  1064. }
  1065. if err := uploadBlob(ctx, location, layer, chunkSize, regOpts, fn); err != nil {
  1066. log.Printf("error uploading blob: %v", err)
  1067. return err
  1068. }
  1069. }
  1070. fn(api.ProgressResponse{Status: "pushing manifest"})
  1071. requestURL := mp.BaseURL()
  1072. requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
  1073. manifestJSON, err := json.Marshal(manifest)
  1074. if err != nil {
  1075. return err
  1076. }
  1077. headers := make(http.Header)
  1078. headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
  1079. resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
  1080. if err != nil {
  1081. return err
  1082. }
  1083. defer resp.Body.Close()
  1084. fn(api.ProgressResponse{Status: "success"})
  1085. return nil
  1086. }
  1087. func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  1088. mp := ParseModelPath(name)
  1089. var manifest *ManifestV2
  1090. var err error
  1091. var noprune string
  1092. // build deleteMap to prune unused layers
  1093. deleteMap := make(map[string]bool)
  1094. if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
  1095. manifest, _, err = GetManifest(mp)
  1096. if err != nil && !errors.Is(err, os.ErrNotExist) {
  1097. return err
  1098. }
  1099. if manifest != nil {
  1100. for _, l := range manifest.Layers {
  1101. deleteMap[l.Digest] = true
  1102. }
  1103. deleteMap[manifest.Config.Digest] = true
  1104. }
  1105. }
  1106. if mp.ProtocolScheme == "http" && !regOpts.Insecure {
  1107. return fmt.Errorf("insecure protocol http")
  1108. }
  1109. fn(api.ProgressResponse{Status: "pulling manifest"})
  1110. manifest, err = pullModelManifest(ctx, mp, regOpts)
  1111. if err != nil {
  1112. return fmt.Errorf("pull model manifest: %s", err)
  1113. }
  1114. var layers []*Layer
  1115. layers = append(layers, manifest.Layers...)
  1116. layers = append(layers, &manifest.Config)
  1117. for _, layer := range layers {
  1118. if err := downloadBlob(
  1119. ctx,
  1120. downloadOpts{
  1121. mp: mp,
  1122. digest: layer.Digest,
  1123. regOpts: regOpts,
  1124. fn: fn,
  1125. }); err != nil {
  1126. return err
  1127. }
  1128. delete(deleteMap, layer.Digest)
  1129. }
  1130. delete(deleteMap, manifest.Config.Digest)
  1131. fn(api.ProgressResponse{Status: "verifying sha256 digest"})
  1132. for _, layer := range layers {
  1133. if err := verifyBlob(layer.Digest); err != nil {
  1134. if errors.Is(err, errDigestMismatch) {
  1135. // something went wrong, delete the blob
  1136. fp, err := GetBlobsPath(layer.Digest)
  1137. if err != nil {
  1138. return err
  1139. }
  1140. if err := os.Remove(fp); err != nil {
  1141. // log this, but return the original error
  1142. log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
  1143. }
  1144. }
  1145. return err
  1146. }
  1147. }
  1148. fn(api.ProgressResponse{Status: "writing manifest"})
  1149. manifestJSON, err := json.Marshal(manifest)
  1150. if err != nil {
  1151. return err
  1152. }
  1153. fp, err := mp.GetManifestPath(true)
  1154. if err != nil {
  1155. return err
  1156. }
  1157. err = os.WriteFile(fp, manifestJSON, 0o644)
  1158. if err != nil {
  1159. log.Printf("couldn't write to %s", fp)
  1160. return err
  1161. }
  1162. if noprune == "" {
  1163. fn(api.ProgressResponse{Status: "removing any unused layers"})
  1164. err = deleteUnusedLayers(nil, deleteMap, false)
  1165. if err != nil {
  1166. return err
  1167. }
  1168. }
  1169. fn(api.ProgressResponse{Status: "success"})
  1170. return nil
  1171. }
  1172. func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
  1173. requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
  1174. headers := make(http.Header)
  1175. headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
  1176. resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
  1177. if err != nil {
  1178. log.Printf("couldn't get manifest: %v", err)
  1179. return nil, err
  1180. }
  1181. defer resp.Body.Close()
  1182. if resp.StatusCode >= http.StatusBadRequest {
  1183. if resp.StatusCode == http.StatusNotFound {
  1184. return nil, fmt.Errorf("model not found")
  1185. }
  1186. body, _ := io.ReadAll(resp.Body)
  1187. return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
  1188. }
  1189. var m *ManifestV2
  1190. if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
  1191. return nil, err
  1192. }
  1193. return m, err
  1194. }
  1195. func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
  1196. config.RootFS = RootFS{
  1197. Type: "layers",
  1198. DiffIDs: layers,
  1199. }
  1200. configJSON, err := json.Marshal(config)
  1201. if err != nil {
  1202. return nil, err
  1203. }
  1204. digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
  1205. layer := &LayerReader{
  1206. Layer: Layer{
  1207. MediaType: "application/vnd.docker.container.image.v1+json",
  1208. Digest: digest,
  1209. Size: size,
  1210. },
  1211. Reader: bytes.NewBuffer(configJSON),
  1212. }
  1213. return layer, nil
  1214. }
  // GetSHA256Digest reads r to EOF and returns its SHA-256 digest formatted as
// "sha256:<lowercase hex>", together with the number of bytes read.
// A read error is fatal: it is passed to log.Fatal, which exits the process.
func GetSHA256Digest(r io.Reader) (string, int64) {
	hasher := sha256.New()
	size, err := io.Copy(hasher, r)
	if err != nil {
		log.Fatal(err)
	}
	sum := hasher.Sum(nil)
	return "sha256:" + fmt.Sprintf("%x", sum), size
}
  1224. // Function to check if a blob already exists in the Docker registry
  1225. func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
  1226. requestURL := mp.BaseURL()
  1227. requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
  1228. resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
  1229. if err != nil {
  1230. log.Printf("couldn't check for blob: %v", err)
  1231. return false, err
  1232. }
  1233. defer resp.Body.Close()
  1234. // Check for success: If the blob exists, the Docker registry will respond with a 200 OK
  1235. return resp.StatusCode < http.StatusBadRequest, nil
  1236. }
  1237. func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
  1238. var status string
  1239. for try := 0; try < MaxRetries; try++ {
  1240. resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
  1241. if err != nil {
  1242. log.Printf("couldn't start upload: %v", err)
  1243. return nil, err
  1244. }
  1245. status = resp.Status
  1246. switch {
  1247. case resp.StatusCode == http.StatusUnauthorized:
  1248. auth := resp.Header.Get("www-authenticate")
  1249. authRedir := ParseAuthRedirectString(auth)
  1250. token, err := getAuthToken(ctx, authRedir)
  1251. if err != nil {
  1252. return nil, err
  1253. }
  1254. regOpts.Token = token
  1255. if body != nil {
  1256. if _, err := body.Seek(0, io.SeekStart); err != nil {
  1257. return nil, err
  1258. }
  1259. }
  1260. continue
  1261. case resp.StatusCode >= http.StatusBadRequest:
  1262. body, _ := io.ReadAll(resp.Body)
  1263. return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
  1264. default:
  1265. return resp, nil
  1266. }
  1267. }
  1268. return nil, fmt.Errorf("max retry exceeded: %v", status)
  1269. }
  1270. func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
  1271. if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
  1272. requestURL.Scheme = "http"
  1273. }
  1274. req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
  1275. if err != nil {
  1276. return nil, err
  1277. }
  1278. if headers != nil {
  1279. req.Header = headers
  1280. }
  1281. if regOpts != nil {
  1282. if regOpts.Token != "" {
  1283. req.Header.Set("Authorization", "Bearer "+regOpts.Token)
  1284. } else if regOpts.Username != "" && regOpts.Password != "" {
  1285. req.SetBasicAuth(regOpts.Username, regOpts.Password)
  1286. }
  1287. }
  1288. req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  1289. if s := req.Header.Get("Content-Length"); s != "" {
  1290. contentLength, err := strconv.ParseInt(s, 10, 64)
  1291. if err != nil {
  1292. return nil, err
  1293. }
  1294. req.ContentLength = contentLength
  1295. }
  1296. client := &http.Client{
  1297. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  1298. if len(via) >= 10 {
  1299. return fmt.Errorf("too many redirects")
  1300. }
  1301. log.Printf("redirected to: %s\n", req.URL)
  1302. return nil
  1303. },
  1304. }
  1305. resp, err := client.Do(req)
  1306. if err != nil {
  1307. return nil, err
  1308. }
  1309. return resp, nil
  1310. }
  // getValue returns the value of the auth-param key in header, or "" when
// key is absent. header is a comma-separated parameter list such as
// `realm="https://example.com/token",service="example.com"`.
//
// key only matches where a parameter starts: at the beginning of header, or
// after a space, tab or comma. So "realm" does not match inside "xrealm=".
// A quoted value ends at the first '"' that is followed by ',' or by the end
// of header; if there is no such quote it runs to the end of header. An
// unquoted value ends at the next ',' and is trimmed of surrounding spaces.
// Malformed input (e.g. a header ending right after "key=") never panics.
func getValue(header, key string) string {
	needle := key + "="

	start := -1
	for from := 0; start == -1; {
		i := strings.Index(header[from:], needle)
		if i == -1 {
			return ""
		}
		i += from
		if i == 0 || header[i-1] == ' ' || header[i-1] == '\t' || header[i-1] == ',' {
			start = i + len(needle)
		} else {
			// Matched inside another key; keep searching after it.
			from = i + len(needle)
		}
	}

	value := header[start:]
	if !strings.HasPrefix(value, `"`) {
		// Unquoted token: it runs up to the next parameter separator.
		if comma := strings.IndexByte(value, ','); comma != -1 {
			value = value[:comma]
		}
		return strings.TrimSpace(value)
	}

	// Quoted string: skip the opening quote and find the closing one. A '"'
	// that is not followed by ',' is treated as part of the value.
	value = value[1:]
	end := 0
	for ; end < len(value); end++ {
		if value[end] == '"' && (end+1 == len(value) || value[end+1] == ',') {
			break
		}
	}
	return value[:end]
}
  1331. func ParseAuthRedirectString(authStr string) AuthRedirect {
  1332. authStr = strings.TrimPrefix(authStr, "Bearer ")
  1333. return AuthRedirect{
  1334. Realm: getValue(authStr, "realm"),
  1335. Service: getValue(authStr, "service"),
  1336. Scope: getValue(authStr, "scope"),
  1337. }
  1338. }
  // errDigestMismatch is wrapped by verifyBlob when a stored blob's content hash
// differs from its expected digest. Callers match it with errors.Is.
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
  1340. func verifyBlob(digest string) error {
  1341. fp, err := GetBlobsPath(digest)
  1342. if err != nil {
  1343. return err
  1344. }
  1345. f, err := os.Open(fp)
  1346. if err != nil {
  1347. return err
  1348. }
  1349. defer f.Close()
  1350. fileDigest, _ := GetSHA256Digest(f)
  1351. if digest != fileDigest {
  1352. return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
  1353. }
  1354. return nil
  1355. }