images.go 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286
  1. package server
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/sha256"
  7. "encoding/hex"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "io"
  12. "log"
  13. "net/http"
  14. "net/url"
  15. "os"
  16. "path"
  17. "path/filepath"
  18. "reflect"
  19. "runtime"
  20. "strconv"
  21. "strings"
  22. "text/template"
  23. "github.com/jmorganca/ollama/api"
  24. "github.com/jmorganca/ollama/llm"
  25. "github.com/jmorganca/ollama/parser"
  26. "github.com/jmorganca/ollama/vector"
  27. "github.com/jmorganca/ollama/version"
  28. )
  // MaxRetries is a package-level retry limit. Its consumers are not visible in
  // this part of the file — presumably the registry request helpers (TODO confirm).
  const MaxRetries = 3
  // RegistryOptions carries the per-request registry settings that PushModel and
  // PullModel receive and forward to the request helpers.
  type RegistryOptions struct {
  	// Insecure allows a model path whose scheme is "http"; PushModel and
  	// PullModel reject that scheme unless it is set.
  	Insecure bool

  	// Credentials. How they are applied is decided by the request helpers,
  	// which are outside this part of the file.
  	Username string
  	Password string
  	Token    string
  }
  36. type Model struct {
  37. Name string `json:"name"`
  38. ModelPath string
  39. AdapterPaths []string
  40. Template string
  41. System string
  42. Digest string
  43. ConfigDigest string
  44. Options map[string]interface{}
  45. Embeddings []vector.Embedding
  46. }
  47. func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
  48. t := m.Template
  49. if request.Template != "" {
  50. t = request.Template
  51. }
  52. tmpl, err := template.New("").Parse(t)
  53. if err != nil {
  54. return "", err
  55. }
  56. var vars struct {
  57. First bool
  58. System string
  59. Prompt string
  60. Embed string
  61. Args map[string]any
  62. // deprecated: versions <= 0.0.7 used this to omit the system prompt
  63. Context []int
  64. }
  65. vars.First = len(request.Context) == 0
  66. vars.System = m.System
  67. vars.Prompt = request.Prompt
  68. vars.Context = request.Context
  69. vars.Args = request.Args
  70. vars.Embed = embedding
  71. if request.System != "" {
  72. vars.System = request.System
  73. }
  74. var sb strings.Builder
  75. if err := tmpl.Execute(&sb, vars); err != nil {
  76. return "", err
  77. }
  78. return sb.String(), nil
  79. }
  80. type ManifestV2 struct {
  81. SchemaVersion int `json:"schemaVersion"`
  82. MediaType string `json:"mediaType"`
  83. Config Layer `json:"config"`
  84. Layers []*Layer `json:"layers"`
  85. }
  // Layer is the manifest entry describing a single blob.
  type Layer struct {
  	// MediaType selects how the blob is interpreted (for example
  	// "application/vnd.ollama.image.model").
  	MediaType string `json:"mediaType"`

  	// Digest identifies the blob and is the key passed to GetBlobsPath.
  	Digest string `json:"digest"`

  	// Size is the blob size as reported when the layer was created.
  	Size int `json:"size"`

  	// From is set by CreateModel on layers copied from a base model, to that
  	// model's namespace/repository; it is omitted from JSON when empty.
  	From string `json:"from,omitempty"`
  }
  92. type LayerReader struct {
  93. Layer
  94. io.Reader
  95. }
  96. type ConfigV2 struct {
  97. ModelFamily llm.ModelFamily `json:"model_family"`
  98. ModelType string `json:"model_type"`
  99. FileType string `json:"file_type"`
  100. RootFS RootFS `json:"rootfs"`
  101. // required by spec
  102. Architecture string `json:"architecture"`
  103. OS string `json:"os"`
  104. }
  // RootFS is the "rootfs" section of ConfigV2.
  type RootFS struct {
  	Type string `json:"type"`

  	// DiffIDs is serialised as "diff_ids"; it is presumably filled from the
  	// layer digests handed to createConfigLayer (not visible here — confirm).
  	DiffIDs []string `json:"diff_ids"`
  }
  109. func (m *ManifestV2) GetTotalSize() int {
  110. var total int
  111. for _, layer := range m.Layers {
  112. total += layer.Size
  113. }
  114. total += m.Config.Size
  115. return total
  116. }
  117. func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
  118. fp, err := mp.GetManifestPath(false)
  119. if err != nil {
  120. return nil, "", err
  121. }
  122. if _, err = os.Stat(fp); err != nil {
  123. return nil, "", err
  124. }
  125. var manifest *ManifestV2
  126. bts, err := os.ReadFile(fp)
  127. if err != nil {
  128. return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
  129. }
  130. shaSum := sha256.Sum256(bts)
  131. shaStr := hex.EncodeToString(shaSum[:])
  132. if err := json.Unmarshal(bts, &manifest); err != nil {
  133. return nil, "", err
  134. }
  135. return manifest, shaStr, nil
  136. }
  137. func GetModel(name string) (*Model, error) {
  138. mp := ParseModelPath(name)
  139. manifest, digest, err := GetManifest(mp)
  140. if err != nil {
  141. return nil, err
  142. }
  143. model := &Model{
  144. Name: mp.GetFullTagname(),
  145. Digest: digest,
  146. ConfigDigest: manifest.Config.Digest,
  147. Template: "{{ .Prompt }}",
  148. }
  149. for _, layer := range manifest.Layers {
  150. filename, err := GetBlobsPath(layer.Digest)
  151. if err != nil {
  152. return nil, err
  153. }
  154. switch layer.MediaType {
  155. case "application/vnd.ollama.image.model":
  156. model.ModelPath = filename
  157. case "application/vnd.ollama.image.embed":
  158. file, err := os.Open(filename)
  159. if err != nil {
  160. return nil, fmt.Errorf("failed to open file: %s", filename)
  161. }
  162. defer file.Close()
  163. if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
  164. return nil, err
  165. }
  166. case "application/vnd.ollama.image.adapter":
  167. model.AdapterPaths = append(model.AdapterPaths, filename)
  168. case "application/vnd.ollama.image.template":
  169. bts, err := os.ReadFile(filename)
  170. if err != nil {
  171. return nil, err
  172. }
  173. model.Template = string(bts)
  174. case "application/vnd.ollama.image.system":
  175. bts, err := os.ReadFile(filename)
  176. if err != nil {
  177. return nil, err
  178. }
  179. model.System = string(bts)
  180. case "application/vnd.ollama.image.prompt":
  181. bts, err := os.ReadFile(filename)
  182. if err != nil {
  183. return nil, err
  184. }
  185. model.Template = string(bts)
  186. case "application/vnd.ollama.image.params":
  187. params, err := os.Open(filename)
  188. if err != nil {
  189. return nil, err
  190. }
  191. defer params.Close()
  192. // parse model options parameters into a map so that we can see which fields have been specified explicitly
  193. if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
  194. return nil, err
  195. }
  196. }
  197. }
  198. return model, nil
  199. }
  // filenameWithPath resolves f, a file reference found in the modelfile located
  // at path, to a usable file path.
  //
  // A leading "~" followed by the path separator is replaced by the current
  // user's home directory. A result that is still not absolute is joined onto
  // the directory containing the modelfile. The only error is a failure to
  // determine the home directory.
  func filenameWithPath(path, f string) (string, error) {
  	sep := string(os.PathSeparator)

  	if strings.HasPrefix(f, "~"+sep) {
  		home, err := os.UserHomeDir()
  		if err != nil {
  			return "", fmt.Errorf("failed to open file: %v", err)
  		}
  		// Drop the "~" segment and re-root the rest under the home directory.
  		segments := strings.Split(f, sep)[1:]
  		f = filepath.Join(home, filepath.Join(segments...))
  	}

  	if filepath.IsAbs(f) {
  		return f, nil
  	}
  	return filepath.Join(filepath.Dir(path), f), nil
  }
  216. func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
  217. mf, err := os.Open(path)
  218. if err != nil {
  219. fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
  220. return fmt.Errorf("failed to open file: %w", err)
  221. }
  222. defer mf.Close()
  223. fn(api.ProgressResponse{Status: "parsing modelfile"})
  224. commands, err := parser.Parse(mf)
  225. if err != nil {
  226. return err
  227. }
  228. config := ConfigV2{
  229. Architecture: "amd64",
  230. OS: "linux",
  231. }
  232. var layers []*LayerReader
  233. params := make(map[string][]string)
  234. embed := EmbeddingParams{fn: fn}
  235. for _, c := range commands {
  236. log.Printf("[%s] - %s\n", c.Name, c.Args)
  237. switch c.Name {
  238. case "model":
  239. fn(api.ProgressResponse{Status: "looking for model"})
  240. embed.model = c.Args
  241. mp := ParseModelPath(c.Args)
  242. mf, _, err := GetManifest(mp)
  243. if err != nil {
  244. modelFile, err := filenameWithPath(path, c.Args)
  245. if err != nil {
  246. return err
  247. }
  248. if _, err := os.Stat(modelFile); err != nil {
  249. // the model file does not exist, try pulling it
  250. if errors.Is(err, os.ErrNotExist) {
  251. fn(api.ProgressResponse{Status: "pulling model file"})
  252. if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
  253. return err
  254. }
  255. mf, _, err = GetManifest(mp)
  256. if err != nil {
  257. return fmt.Errorf("failed to open file after pull: %v", err)
  258. }
  259. } else {
  260. return err
  261. }
  262. } else {
  263. embed.model = modelFile
  264. // create a model from this specified file
  265. fn(api.ProgressResponse{Status: "creating model layer"})
  266. file, err := os.Open(modelFile)
  267. if err != nil {
  268. return fmt.Errorf("failed to open file: %v", err)
  269. }
  270. defer file.Close()
  271. ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
  272. if err != nil {
  273. return err
  274. }
  275. config.ModelFamily = ggml.ModelFamily()
  276. config.ModelType = ggml.ModelType().String()
  277. config.FileType = ggml.FileType().String()
  278. // reset the file
  279. file.Seek(0, io.SeekStart)
  280. l, err := CreateLayer(file)
  281. if err != nil {
  282. return fmt.Errorf("failed to create layer: %v", err)
  283. }
  284. l.MediaType = "application/vnd.ollama.image.model"
  285. layers = append(layers, l)
  286. }
  287. }
  288. if mf != nil {
  289. sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
  290. if err != nil {
  291. return err
  292. }
  293. sourceBlob, err := os.Open(sourceBlobPath)
  294. if err != nil {
  295. return err
  296. }
  297. defer sourceBlob.Close()
  298. var source ConfigV2
  299. if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
  300. return err
  301. }
  302. // copie the model metadata
  303. config.ModelFamily = source.ModelFamily
  304. config.ModelType = source.ModelType
  305. config.FileType = source.FileType
  306. for _, l := range mf.Layers {
  307. newLayer, err := GetLayerWithBufferFromLayer(l)
  308. if err != nil {
  309. return err
  310. }
  311. newLayer.From = mp.GetNamespaceRepository()
  312. layers = append(layers, newLayer)
  313. }
  314. }
  315. case "embed":
  316. embedFilePath, err := filenameWithPath(path, c.Args)
  317. if err != nil {
  318. return err
  319. }
  320. embed.files = append(embed.files, embedFilePath)
  321. case "adapter":
  322. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  323. fp, err := filenameWithPath(path, c.Args)
  324. if err != nil {
  325. return err
  326. }
  327. // create a model from this specified file
  328. fn(api.ProgressResponse{Status: "creating model layer"})
  329. file, err := os.Open(fp)
  330. if err != nil {
  331. return fmt.Errorf("failed to open file: %v", err)
  332. }
  333. defer file.Close()
  334. l, err := CreateLayer(file)
  335. if err != nil {
  336. return fmt.Errorf("failed to create layer: %v", err)
  337. }
  338. l.MediaType = "application/vnd.ollama.image.adapter"
  339. layers = append(layers, l)
  340. case "license":
  341. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  342. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  343. layer, err := CreateLayer(strings.NewReader(c.Args))
  344. if err != nil {
  345. return err
  346. }
  347. layer.MediaType = mediaType
  348. layers = append(layers, layer)
  349. case "template", "system", "prompt":
  350. fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
  351. // remove the layer if one exists
  352. mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
  353. layers = removeLayerFromLayers(layers, mediaType)
  354. layer, err := CreateLayer(strings.NewReader(c.Args))
  355. if err != nil {
  356. return err
  357. }
  358. layer.MediaType = mediaType
  359. layers = append(layers, layer)
  360. default:
  361. // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
  362. params[c.Name] = append(params[c.Name], c.Args)
  363. }
  364. }
  365. // Create a single layer for the parameters
  366. if len(params) > 0 {
  367. fn(api.ProgressResponse{Status: "creating parameter layer"})
  368. layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
  369. formattedParams, err := formatParams(params)
  370. if err != nil {
  371. return fmt.Errorf("couldn't create params json: %v", err)
  372. }
  373. bts, err := json.Marshal(formattedParams)
  374. if err != nil {
  375. return err
  376. }
  377. l, err := CreateLayer(bytes.NewReader(bts))
  378. if err != nil {
  379. return fmt.Errorf("failed to create layer: %v", err)
  380. }
  381. l.MediaType = "application/vnd.ollama.image.params"
  382. layers = append(layers, l)
  383. // apply these parameters to the embedding options, in case embeddings need to be generated using this model
  384. embed.opts = formattedParams
  385. }
  386. // generate the embedding layers
  387. embeddingLayers, err := embeddingLayers(embed)
  388. if err != nil {
  389. return err
  390. }
  391. layers = append(layers, embeddingLayers...)
  392. digests, err := getLayerDigests(layers)
  393. if err != nil {
  394. return err
  395. }
  396. var manifestLayers []*Layer
  397. for _, l := range layers {
  398. manifestLayers = append(manifestLayers, &l.Layer)
  399. }
  400. // Create a layer for the config object
  401. fn(api.ProgressResponse{Status: "creating config layer"})
  402. cfg, err := createConfigLayer(config, digests)
  403. if err != nil {
  404. return err
  405. }
  406. layers = append(layers, cfg)
  407. if err := SaveLayers(layers, fn, false); err != nil {
  408. return err
  409. }
  410. // Create the manifest
  411. fn(api.ProgressResponse{Status: "writing manifest"})
  412. err = CreateManifest(name, cfg, manifestLayers)
  413. if err != nil {
  414. return err
  415. }
  416. fn(api.ProgressResponse{Status: "success"})
  417. return nil
  418. }
  419. type EmbeddingParams struct {
  420. model string
  421. opts map[string]interface{}
  422. files []string // paths to files to embed
  423. fn func(resp api.ProgressResponse)
  424. }
  425. // embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
  426. func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
  427. layers := []*LayerReader{}
  428. if len(e.files) > 0 {
  429. // check if the model is a file path or a model name
  430. model, err := GetModel(e.model)
  431. if err != nil {
  432. if !strings.Contains(err.Error(), "couldn't open file") {
  433. return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
  434. }
  435. // the model may be a file path, create a model from this file
  436. model = &Model{ModelPath: e.model}
  437. }
  438. if err := load(context.Background(), model, e.opts, defaultSessionDuration); err != nil {
  439. return nil, fmt.Errorf("load model to generate embeddings: %v", err)
  440. }
  441. // this will be used to check if we already have embeddings for a file
  442. modelInfo, err := os.Stat(model.ModelPath)
  443. if err != nil {
  444. return nil, fmt.Errorf("failed to get model file info: %v", err)
  445. }
  446. addedFiles := make(map[string]bool) // keep track of files that have already been added
  447. for _, filePattern := range e.files {
  448. matchingFiles, err := filepath.Glob(filePattern)
  449. if err != nil {
  450. return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
  451. }
  452. for _, filePath := range matchingFiles {
  453. if addedFiles[filePath] {
  454. continue
  455. }
  456. addedFiles[filePath] = true
  457. // check if we already have embeddings for this file path
  458. layerIdentifier := fmt.Sprintf("%s:%s:%s:%d", filePath, e.model, modelInfo.ModTime().Format("2006-01-02 15:04:05"), modelInfo.Size())
  459. digest, _ := GetSHA256Digest(strings.NewReader(layerIdentifier))
  460. existing, err := existingFileEmbeddings(digest)
  461. if err != nil {
  462. return nil, fmt.Errorf("failed to check existing embeddings for file %s: %v", filePath, err)
  463. }
  464. // TODO: check file type
  465. f, err := os.Open(filePath)
  466. if err != nil {
  467. return nil, fmt.Errorf("could not open embed file: %w", err)
  468. }
  469. scanner := bufio.NewScanner(f)
  470. scanner.Split(bufio.ScanLines)
  471. data := []string{}
  472. for scanner.Scan() {
  473. data = append(data, scanner.Text())
  474. }
  475. f.Close()
  476. // the digest of the file is set here so that the client knows a new operation is in progress
  477. fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
  478. embeddings := []vector.Embedding{}
  479. for i, d := range data {
  480. if strings.TrimSpace(d) == "" {
  481. continue
  482. }
  483. e.fn(api.ProgressResponse{
  484. Status: fmt.Sprintf("creating embeddings for file %s", filePath),
  485. Digest: fileDigest,
  486. Total: len(data) - 1,
  487. Completed: i,
  488. })
  489. if len(existing[d]) > 0 {
  490. // already have an embedding for this line
  491. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
  492. continue
  493. }
  494. embed, err := loaded.llm.Embedding(context.Background(), d)
  495. if err != nil {
  496. log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
  497. continue
  498. }
  499. embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
  500. }
  501. b, err := json.Marshal(embeddings)
  502. if err != nil {
  503. return nil, fmt.Errorf("failed to encode embeddings: %w", err)
  504. }
  505. r := bytes.NewReader(b)
  506. layer := &LayerReader{
  507. Layer: Layer{
  508. MediaType: "application/vnd.ollama.image.embed",
  509. Digest: digest,
  510. Size: r.Len(),
  511. },
  512. Reader: r,
  513. }
  514. layers = append(layers, layer)
  515. }
  516. }
  517. }
  518. return layers, nil
  519. }
  520. // existingFileEmbeddings checks if we already have embeddings for a file and loads them into a look-up map
  521. func existingFileEmbeddings(digest string) (map[string][]float64, error) {
  522. path, err := GetBlobsPath(digest)
  523. if err != nil {
  524. return nil, fmt.Errorf("embeddings blobs path: %w", err)
  525. }
  526. existingFileEmbeddings := make(map[string][]float64)
  527. if _, err := os.Stat(path); err == nil {
  528. // already have some embeddings for this file, load embeddings previously generated
  529. file, err := os.Open(path)
  530. if err != nil {
  531. return nil, fmt.Errorf("failed to open existing embedding file: %s", err)
  532. }
  533. defer file.Close()
  534. existing := []vector.Embedding{}
  535. if err = json.NewDecoder(file).Decode(&existing); err != nil {
  536. return nil, err
  537. }
  538. for _, e := range existing {
  539. existingFileEmbeddings[e.Data] = e.Vector
  540. }
  541. }
  542. return existingFileEmbeddings, nil
  543. }
  544. func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
  545. j := 0
  546. for _, l := range layers {
  547. if l.MediaType != mediaType {
  548. layers[j] = l
  549. j++
  550. }
  551. }
  552. return layers[:j]
  553. }
  554. func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
  555. // Write each of the layers to disk
  556. for _, layer := range layers {
  557. fp, err := GetBlobsPath(layer.Digest)
  558. if err != nil {
  559. return err
  560. }
  561. _, err = os.Stat(fp)
  562. // note: embed layers are always written since their digest doesnt indicate anything about the contents
  563. if os.IsNotExist(err) || force || layer.MediaType == "application/vnd.ollama.image.embed" {
  564. fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
  565. out, err := os.Create(fp)
  566. if err != nil {
  567. log.Printf("couldn't create %s", fp)
  568. return err
  569. }
  570. defer out.Close()
  571. if _, err = io.Copy(out, layer.Reader); err != nil {
  572. return err
  573. }
  574. } else {
  575. fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
  576. }
  577. }
  578. return nil
  579. }
  580. func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
  581. mp := ParseModelPath(name)
  582. manifest := ManifestV2{
  583. SchemaVersion: 2,
  584. MediaType: "application/vnd.docker.distribution.manifest.v2+json",
  585. Config: Layer{
  586. MediaType: cfg.MediaType,
  587. Size: cfg.Size,
  588. Digest: cfg.Digest,
  589. },
  590. Layers: layers,
  591. }
  592. manifestJSON, err := json.Marshal(manifest)
  593. if err != nil {
  594. return err
  595. }
  596. fp, err := mp.GetManifestPath(true)
  597. if err != nil {
  598. return err
  599. }
  600. return os.WriteFile(fp, manifestJSON, 0o644)
  601. }
  602. func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
  603. fp, err := GetBlobsPath(layer.Digest)
  604. if err != nil {
  605. return nil, err
  606. }
  607. file, err := os.Open(fp)
  608. if err != nil {
  609. return nil, fmt.Errorf("could not open blob: %w", err)
  610. }
  611. defer file.Close()
  612. newLayer, err := CreateLayer(file)
  613. if err != nil {
  614. return nil, err
  615. }
  616. newLayer.MediaType = layer.MediaType
  617. return newLayer, nil
  618. }
  619. // formatParams converts specified parameter options to their correct types
  620. func formatParams(params map[string][]string) (map[string]interface{}, error) {
  621. opts := api.Options{}
  622. valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
  623. typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
  624. // build map of json struct tags to their types
  625. jsonOpts := make(map[string]reflect.StructField)
  626. for _, field := range reflect.VisibleFields(typeOpts) {
  627. jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
  628. if jsonTag != "" {
  629. jsonOpts[jsonTag] = field
  630. }
  631. }
  632. out := make(map[string]interface{})
  633. // iterate params and set values based on json struct tags
  634. for key, vals := range params {
  635. if opt, ok := jsonOpts[key]; ok {
  636. field := valueOpts.FieldByName(opt.Name)
  637. if field.IsValid() && field.CanSet() {
  638. switch field.Kind() {
  639. case reflect.Float32:
  640. floatVal, err := strconv.ParseFloat(vals[0], 32)
  641. if err != nil {
  642. return nil, fmt.Errorf("invalid float value %s", vals)
  643. }
  644. out[key] = floatVal
  645. case reflect.Int:
  646. intVal, err := strconv.ParseInt(vals[0], 10, 0)
  647. if err != nil {
  648. return nil, fmt.Errorf("invalid int value %s", vals)
  649. }
  650. out[key] = intVal
  651. case reflect.Bool:
  652. boolVal, err := strconv.ParseBool(vals[0])
  653. if err != nil {
  654. return nil, fmt.Errorf("invalid bool value %s", vals)
  655. }
  656. out[key] = boolVal
  657. case reflect.String:
  658. out[key] = vals[0]
  659. case reflect.Slice:
  660. // TODO: only string slices are supported right now
  661. out[key] = vals
  662. default:
  663. return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
  664. }
  665. }
  666. }
  667. }
  668. return out, nil
  669. }
  670. func getLayerDigests(layers []*LayerReader) ([]string, error) {
  671. var digests []string
  672. for _, l := range layers {
  673. if l.Digest == "" {
  674. return nil, fmt.Errorf("layer is missing a digest")
  675. }
  676. digests = append(digests, l.Digest)
  677. }
  678. return digests, nil
  679. }
  680. // CreateLayer creates a Layer object from a given file
  681. func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
  682. digest, size := GetSHA256Digest(f)
  683. f.Seek(0, io.SeekStart)
  684. layer := &LayerReader{
  685. Layer: Layer{
  686. MediaType: "application/vnd.docker.image.rootfs.diff.tar",
  687. Digest: digest,
  688. Size: size,
  689. },
  690. Reader: f,
  691. }
  692. return layer, nil
  693. }
  694. func CopyModel(src, dest string) error {
  695. srcModelPath := ParseModelPath(src)
  696. srcPath, err := srcModelPath.GetManifestPath(false)
  697. if err != nil {
  698. return err
  699. }
  700. destModelPath := ParseModelPath(dest)
  701. destPath, err := destModelPath.GetManifestPath(true)
  702. if err != nil {
  703. return err
  704. }
  705. // copy the file
  706. input, err := os.ReadFile(srcPath)
  707. if err != nil {
  708. fmt.Println("Error reading file:", err)
  709. return err
  710. }
  711. err = os.WriteFile(destPath, input, 0o644)
  712. if err != nil {
  713. fmt.Println("Error reading file:", err)
  714. return err
  715. }
  716. return nil
  717. }
  718. func DeleteModel(name string) error {
  719. mp := ParseModelPath(name)
  720. manifest, _, err := GetManifest(mp)
  721. if err != nil {
  722. return err
  723. }
  724. deleteMap := make(map[string]bool)
  725. for _, layer := range manifest.Layers {
  726. deleteMap[layer.Digest] = true
  727. }
  728. deleteMap[manifest.Config.Digest] = true
  729. fp, err := GetManifestPath()
  730. if err != nil {
  731. return err
  732. }
  733. walkFunc := func(path string, info os.FileInfo, _ error) error {
  734. if info.IsDir() {
  735. return nil
  736. }
  737. dir, file := filepath.Split(path)
  738. dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
  739. tag := strings.Join([]string{dir, file}, ":")
  740. fmp := ParseModelPath(tag)
  741. // skip the manifest we're trying to delete
  742. if mp.GetFullTagname() == fmp.GetFullTagname() {
  743. return nil
  744. }
  745. // save (i.e. delete from the deleteMap) any files used in other manifests
  746. manifest, _, err := GetManifest(fmp)
  747. if err != nil {
  748. log.Printf("skipping file: %s", fp)
  749. return nil
  750. }
  751. for _, layer := range manifest.Layers {
  752. delete(deleteMap, layer.Digest)
  753. }
  754. delete(deleteMap, manifest.Config.Digest)
  755. return nil
  756. }
  757. if err := filepath.Walk(fp, walkFunc); err != nil {
  758. return err
  759. }
  760. // only delete the files which are still in the deleteMap
  761. for k, v := range deleteMap {
  762. if v {
  763. fp, err := GetBlobsPath(k)
  764. if err != nil {
  765. log.Printf("couldn't get file path for '%s': %v", k, err)
  766. continue
  767. }
  768. if err := os.Remove(fp); err != nil {
  769. log.Printf("couldn't remove file '%s': %v", fp, err)
  770. continue
  771. }
  772. }
  773. }
  774. fp, err = mp.GetManifestPath(false)
  775. if err != nil {
  776. return err
  777. }
  778. err = os.Remove(fp)
  779. if err != nil {
  780. log.Printf("couldn't remove manifest file '%s': %v", fp, err)
  781. return err
  782. }
  783. return nil
  784. }
  785. func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  786. mp := ParseModelPath(name)
  787. fn(api.ProgressResponse{Status: "retrieving manifest"})
  788. if mp.ProtocolScheme == "http" && !regOpts.Insecure {
  789. return fmt.Errorf("insecure protocol http")
  790. }
  791. manifest, _, err := GetManifest(mp)
  792. if err != nil {
  793. fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
  794. return err
  795. }
  796. var layers []*Layer
  797. layers = append(layers, manifest.Layers...)
  798. layers = append(layers, &manifest.Config)
  799. for _, layer := range layers {
  800. exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
  801. if err != nil {
  802. return err
  803. }
  804. if exists {
  805. fn(api.ProgressResponse{
  806. Status: "using existing layer",
  807. Digest: layer.Digest,
  808. Total: layer.Size,
  809. Completed: layer.Size,
  810. })
  811. log.Printf("Layer %s already exists", layer.Digest)
  812. continue
  813. }
  814. fn(api.ProgressResponse{
  815. Status: "starting upload",
  816. Digest: layer.Digest,
  817. Total: layer.Size,
  818. })
  819. location, err := startUpload(ctx, mp, layer, regOpts)
  820. if err != nil {
  821. log.Printf("couldn't start upload: %v", err)
  822. return err
  823. }
  824. if strings.HasPrefix(path.Base(location.Path), "sha256:") {
  825. layer.Digest = path.Base(location.Path)
  826. fn(api.ProgressResponse{
  827. Status: "using existing layer",
  828. Digest: layer.Digest,
  829. Total: layer.Size,
  830. Completed: layer.Size,
  831. })
  832. continue
  833. }
  834. if err := uploadBlobChunked(ctx, location, layer, regOpts, fn); err != nil {
  835. log.Printf("error uploading blob: %v", err)
  836. return err
  837. }
  838. }
  839. fn(api.ProgressResponse{Status: "pushing manifest"})
  840. requestURL := mp.BaseURL()
  841. requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
  842. manifestJSON, err := json.Marshal(manifest)
  843. if err != nil {
  844. return err
  845. }
  846. headers := make(http.Header)
  847. headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
  848. resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
  849. if err != nil {
  850. return err
  851. }
  852. defer resp.Body.Close()
  853. fn(api.ProgressResponse{Status: "success"})
  854. return nil
  855. }
  856. func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
  857. mp := ParseModelPath(name)
  858. if mp.ProtocolScheme == "http" && !regOpts.Insecure {
  859. return fmt.Errorf("insecure protocol http")
  860. }
  861. fn(api.ProgressResponse{Status: "pulling manifest"})
  862. manifest, err := pullModelManifest(ctx, mp, regOpts)
  863. if err != nil {
  864. return fmt.Errorf("pull model manifest: %s", err)
  865. }
  866. var layers []*Layer
  867. layers = append(layers, manifest.Layers...)
  868. layers = append(layers, &manifest.Config)
  869. for _, layer := range layers {
  870. if err := downloadBlob(
  871. ctx,
  872. downloadOpts{
  873. mp: mp,
  874. digest: layer.Digest,
  875. regOpts: regOpts,
  876. fn: fn,
  877. }); err != nil {
  878. return err
  879. }
  880. }
  881. fn(api.ProgressResponse{Status: "verifying sha256 digest"})
  882. for _, layer := range layers {
  883. if err := verifyBlob(layer.Digest); err != nil {
  884. if errors.Is(err, errDigestMismatch) {
  885. // something went wrong, delete the blob
  886. fp, err := GetBlobsPath(layer.Digest)
  887. if err != nil {
  888. return err
  889. }
  890. if err := os.Remove(fp); err != nil {
  891. // log this, but return the original error
  892. log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
  893. }
  894. }
  895. return err
  896. }
  897. }
  898. fn(api.ProgressResponse{Status: "writing manifest"})
  899. manifestJSON, err := json.Marshal(manifest)
  900. if err != nil {
  901. return err
  902. }
  903. fp, err := mp.GetManifestPath(true)
  904. if err != nil {
  905. return err
  906. }
  907. err = os.WriteFile(fp, manifestJSON, 0o644)
  908. if err != nil {
  909. log.Printf("couldn't write to %s", fp)
  910. return err
  911. }
  912. fn(api.ProgressResponse{Status: "success"})
  913. return nil
  914. }
  915. func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
  916. requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
  917. headers := make(http.Header)
  918. headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
  919. resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
  920. if err != nil {
  921. log.Printf("couldn't get manifest: %v", err)
  922. return nil, err
  923. }
  924. defer resp.Body.Close()
  925. if resp.StatusCode >= http.StatusBadRequest {
  926. if resp.StatusCode == http.StatusNotFound {
  927. return nil, fmt.Errorf("model not found")
  928. }
  929. body, _ := io.ReadAll(resp.Body)
  930. return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
  931. }
  932. var m *ManifestV2
  933. if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
  934. return nil, err
  935. }
  936. return m, err
  937. }
  938. func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
  939. config.RootFS = RootFS{
  940. Type: "layers",
  941. DiffIDs: layers,
  942. }
  943. configJSON, err := json.Marshal(config)
  944. if err != nil {
  945. return nil, err
  946. }
  947. digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
  948. layer := &LayerReader{
  949. Layer: Layer{
  950. MediaType: "application/vnd.docker.container.image.v1+json",
  951. Digest: digest,
  952. Size: size,
  953. },
  954. Reader: bytes.NewBuffer(configJSON),
  955. }
  956. return layer, nil
  957. }
  958. // GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
  959. func GetSHA256Digest(r io.Reader) (string, int) {
  960. h := sha256.New()
  961. n, err := io.Copy(h, r)
  962. if err != nil {
  963. log.Fatal(err)
  964. }
  965. return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
  966. }
  967. // Function to check if a blob already exists in the Docker registry
  968. func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
  969. requestURL := mp.BaseURL()
  970. requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
  971. resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
  972. if err != nil {
  973. log.Printf("couldn't check for blob: %v", err)
  974. return false, err
  975. }
  976. defer resp.Body.Close()
  977. // Check for success: If the blob exists, the Docker registry will respond with a 200 OK
  978. return resp.StatusCode < http.StatusBadRequest, nil
  979. }
  980. func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
  981. var status string
  982. for try := 0; try < MaxRetries; try++ {
  983. resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
  984. if err != nil {
  985. log.Printf("couldn't start upload: %v", err)
  986. return nil, err
  987. }
  988. status = resp.Status
  989. switch {
  990. case resp.StatusCode == http.StatusUnauthorized:
  991. auth := resp.Header.Get("www-authenticate")
  992. authRedir := ParseAuthRedirectString(auth)
  993. token, err := getAuthToken(ctx, authRedir, regOpts)
  994. if err != nil {
  995. return nil, err
  996. }
  997. regOpts.Token = token
  998. if body != nil {
  999. if _, err := body.Seek(0, io.SeekStart); err != nil {
  1000. return nil, err
  1001. }
  1002. }
  1003. continue
  1004. case resp.StatusCode >= http.StatusBadRequest:
  1005. body, _ := io.ReadAll(resp.Body)
  1006. return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
  1007. default:
  1008. return resp, nil
  1009. }
  1010. }
  1011. return nil, fmt.Errorf("max retry exceeded: %v", status)
  1012. }
  1013. func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
  1014. if requestURL.Scheme != "http" && regOpts.Insecure {
  1015. requestURL.Scheme = "http"
  1016. }
  1017. req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
  1018. if err != nil {
  1019. return nil, err
  1020. }
  1021. if headers != nil {
  1022. req.Header = headers
  1023. }
  1024. if regOpts.Token != "" {
  1025. req.Header.Set("Authorization", "Bearer "+regOpts.Token)
  1026. } else if regOpts.Username != "" && regOpts.Password != "" {
  1027. req.SetBasicAuth(regOpts.Username, regOpts.Password)
  1028. }
  1029. req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  1030. client := &http.Client{
  1031. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  1032. if len(via) >= 10 {
  1033. return fmt.Errorf("too many redirects")
  1034. }
  1035. log.Printf("redirected to: %s\n", req.URL)
  1036. return nil
  1037. },
  1038. }
  1039. resp, err := client.Do(req)
  1040. if err != nil {
  1041. return nil, err
  1042. }
  1043. return resp, nil
  1044. }
  1045. func getValue(header, key string) string {
  1046. startIdx := strings.Index(header, key+"=")
  1047. if startIdx == -1 {
  1048. return ""
  1049. }
  1050. // Move the index to the starting quote after the key.
  1051. startIdx += len(key) + 2
  1052. endIdx := startIdx
  1053. for endIdx < len(header) {
  1054. if header[endIdx] == '"' {
  1055. if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue
  1056. endIdx++
  1057. continue
  1058. }
  1059. break
  1060. }
  1061. endIdx++
  1062. }
  1063. return header[startIdx:endIdx]
  1064. }
  1065. func ParseAuthRedirectString(authStr string) AuthRedirect {
  1066. authStr = strings.TrimPrefix(authStr, "Bearer ")
  1067. return AuthRedirect{
  1068. Realm: getValue(authStr, "realm"),
  1069. Service: getValue(authStr, "service"),
  1070. Scope: getValue(authStr, "scope"),
  1071. }
  1072. }
  1073. var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
  1074. func verifyBlob(digest string) error {
  1075. fp, err := GetBlobsPath(digest)
  1076. if err != nil {
  1077. return err
  1078. }
  1079. f, err := os.Open(fp)
  1080. if err != nil {
  1081. return err
  1082. }
  1083. defer f.Close()
  1084. fileDigest, _ := GetSHA256Digest(f)
  1085. if digest != fileDigest {
  1086. return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
  1087. }
  1088. return nil
  1089. }