registry.go 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095
  1. // Package ollama provides a client for interacting with an Ollama registry
  2. // which pushes and pulls model manifests and layers as defined by the
  3. // [ollama.com/manifest].
  4. package ollama
  5. import (
  6. "bufio"
  7. "bytes"
  8. "cmp"
  9. "context"
  10. "crypto"
  11. "crypto/ed25519"
  12. "crypto/sha256"
  13. "crypto/tls"
  14. "encoding/base64"
  15. "encoding/hex"
  16. "encoding/json"
  17. "errors"
  18. "fmt"
  19. "io"
  20. "io/fs"
  21. "iter"
  22. "log/slog"
  23. "net/http"
  24. "os"
  25. "path/filepath"
  26. "runtime"
  27. "runtime/debug"
  28. "slices"
  29. "strconv"
  30. "strings"
  31. "sync"
  32. "sync/atomic"
  33. "time"
  34. "golang.org/x/crypto/ssh"
  35. "golang.org/x/sync/errgroup"
  36. "github.com/ollama/ollama/server/internal/cache/blob"
  37. "github.com/ollama/ollama/server/internal/internal/names"
  38. _ "embed"
  39. )
  40. // Errors
  41. var (
  42. // ErrModelNotFound is returned when a manifest is not found in the
  43. // cache or registry.
  44. ErrModelNotFound = errors.New("model not found")
  45. // ErrManifestInvalid is returned when a manifest found in a local or
  46. // remote cache is invalid.
  47. ErrManifestInvalid = errors.New("invalid manifest")
  48. // ErrMissingModel is returned when the model part of a name is missing
  49. // or invalid.
  50. ErrNameInvalid = errors.New("invalid or missing name")
  51. // ErrCached is passed to [Trace.PushUpdate] when a layer already
  52. // exists. It is a non-fatal error and is never returned by [Registry.Push].
  53. ErrCached = errors.New("cached")
  54. // ErrIncomplete is returned by [Registry.Pull] when a model pull was
  55. // incomplete due to one or more layer download failures. Users that
  56. // want specific errors should use [WithTrace].
  57. ErrIncomplete = errors.New("incomplete")
  58. )
  59. // Defaults
  60. const (
  61. // DefaultChunkingThreshold is the threshold at which a layer should be
  62. // split up into chunks when downloading.
  63. DefaultChunkingThreshold = 64 << 20
  64. )
  65. var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
  66. dir := os.Getenv("OLLAMA_MODELS")
  67. if dir == "" {
  68. home, _ := os.UserHomeDir()
  69. home = cmp.Or(home, ".")
  70. dir = filepath.Join(home, ".ollama", "models")
  71. }
  72. return blob.Open(dir)
  73. })
  74. // DefaultCache returns the default cache used by the registry. It is
  75. // configured from the OLLAMA_MODELS environment variable, or defaults to
  76. // $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
  77. // it uses the current working directory.
  78. func DefaultCache() (*blob.DiskCache, error) {
  79. return defaultCache()
  80. }
  81. // Error is the standard error returned by Ollama APIs. It can represent a
  82. // single or multiple error response.
  83. //
  84. // Single error responses have the following format:
  85. //
  86. // {"code": "optional_code","error":"error message"}
  87. //
  88. // Multiple error responses have the following format:
  89. //
  90. // {"errors": [{"code": "optional_code","message":"error message"}]}
  91. //
  92. // Note, that the error field is used in single error responses, while the
  93. // message field is used in multiple error responses.
  94. //
  95. // In both cases, the code field is optional and may be empty.
  96. type Error struct {
  97. Status int `json:"-"` // TODO(bmizerany): remove this
  98. Code string `json:"code"`
  99. Message string `json:"message"`
  100. }
  101. func (e *Error) Error() string {
  102. var b strings.Builder
  103. b.WriteString("registry responded with status ")
  104. b.WriteString(strconv.Itoa(e.Status))
  105. if e.Code != "" {
  106. b.WriteString(": code ")
  107. b.WriteString(e.Code)
  108. }
  109. if e.Message != "" {
  110. b.WriteString(": ")
  111. b.WriteString(e.Message)
  112. }
  113. return b.String()
  114. }
  115. func (e *Error) LogValue() slog.Value {
  116. return slog.GroupValue(
  117. slog.Int("status", e.Status),
  118. slog.String("code", e.Code),
  119. slog.String("message", e.Message),
  120. )
  121. }
  122. // UnmarshalJSON implements json.Unmarshaler.
  123. func (e *Error) UnmarshalJSON(b []byte) error {
  124. type E Error
  125. var v struct {
  126. // Single error
  127. Code string
  128. Error string
  129. // Multiple errors
  130. Errors []E
  131. }
  132. if err := json.Unmarshal(b, &v); err != nil {
  133. return err
  134. }
  135. if v.Error != "" {
  136. // Single error case
  137. e.Code = v.Code
  138. e.Message = v.Error
  139. return nil
  140. }
  141. if len(v.Errors) == 0 {
  142. return fmt.Errorf("no messages in error response: %s", string(b))
  143. }
  144. *e = Error(v.Errors[0]) // our registry only returns one error.
  145. return nil
  146. }
  147. const DefaultMask = "registry.ollama.ai/library/_:latest"
  148. var defaultMask = func() names.Name {
  149. n := names.Parse(DefaultMask)
  150. if !n.IsFullyQualified() {
  151. panic("default mask is not fully qualified")
  152. }
  153. return n
  154. }()
  155. // CompleteName returns a fully qualified name by merging the given name with
  156. // the default mask. If the name is already fully qualified, it is returned
  157. // unchanged.
  158. func CompleteName(name string) string {
  159. return names.Merge(names.Parse(name), defaultMask).String()
  160. }
  161. // Registry is a client for performing push and pull operations against an
  162. // Ollama registry.
  163. type Registry struct {
  164. // Cache is the cache used to store models. If nil, [DefaultCache] is
  165. // used.
  166. Cache *blob.DiskCache
  167. // UserAgent is the User-Agent header to send with requests to the
  168. // registry. If empty, the User-Agent is determined by HTTPClient.
  169. UserAgent string
  170. // Key is the key used to authenticate with the registry.
  171. //
  172. // Currently, only Ed25519 keys are supported.
  173. Key crypto.PrivateKey
  174. // HTTPClient is the HTTP client used to make requests to the registry.
  175. //
  176. // If nil, [http.DefaultClient] is used.
  177. //
  178. // As a quick note: If a Registry function that makes a call to a URL
  179. // with the "https+insecure" scheme, the client will be cloned and the
  180. // transport will be set to skip TLS verification, unless the client's
  181. // Transport done not have a Clone method with the same signature as
  182. // [http.Transport.Clone], which case, the call will fail.
  183. HTTPClient *http.Client
  184. // MaxStreams is the maximum number of concurrent streams to use when
  185. // pushing or pulling models. If zero, the number of streams is
  186. // determined by [runtime.GOMAXPROCS].
  187. //
  188. // A negative value means no limit.
  189. MaxStreams int
  190. // ChunkingThreshold is the maximum size of a layer to download in a single
  191. // request. If zero, [DefaultChunkingThreshold] is used.
  192. ChunkingThreshold int64
  193. // Mask, if set, is the name used to convert non-fully qualified names
  194. // to fully qualified names. If empty, [DefaultMask] is used.
  195. Mask string
  196. }
  197. func (r *Registry) cache() (*blob.DiskCache, error) {
  198. if r.Cache != nil {
  199. return r.Cache, nil
  200. }
  201. return defaultCache()
  202. }
  203. func (r *Registry) parseName(name string) (names.Name, error) {
  204. mask := defaultMask
  205. if r.Mask != "" {
  206. mask = names.Parse(r.Mask)
  207. }
  208. n := names.Merge(names.Parse(name), mask)
  209. if !n.IsFullyQualified() {
  210. return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
  211. }
  212. return n, nil
  213. }
  214. // DefaultRegistry returns a new Registry configured from the environment. The
  215. // key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
  216. // value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
  217. // system's temporary directory.
  218. //
  219. // It returns an error if any configuration in the environment is invalid.
  220. func DefaultRegistry() (*Registry, error) {
  221. home, err := os.UserHomeDir()
  222. if err != nil {
  223. return nil, err
  224. }
  225. keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
  226. if err != nil && errors.Is(err, fs.ErrNotExist) {
  227. return nil, err
  228. }
  229. var rc Registry
  230. rc.UserAgent = UserAgent()
  231. rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
  232. if err != nil {
  233. return nil, err
  234. }
  235. maxStreams := os.Getenv("OLLAMA_REGISTRY_MAXSTREAMS")
  236. if maxStreams != "" {
  237. var err error
  238. rc.MaxStreams, err = strconv.Atoi(maxStreams)
  239. if err != nil {
  240. return nil, fmt.Errorf("invalid OLLAMA_REGISTRY_MAXSTREAMS: %w", err)
  241. }
  242. }
  243. return &rc, nil
  244. }
  245. func UserAgent() string {
  246. buildinfo, _ := debug.ReadBuildInfo()
  247. version := buildinfo.Main.Version
  248. if version == "(devel)" {
  249. // When using `go run .` the version is "(devel)". This is seen
  250. // as an invalid version by ollama.com and so it defaults to
  251. // "needs upgrade" for some requests, such as pulls. These
  252. // checks can be skipped by using the special version "v0.0.0",
  253. // so we set it to that here.
  254. version = "v0.0.0"
  255. }
  256. return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
  257. version,
  258. runtime.GOARCH,
  259. runtime.GOOS,
  260. runtime.Version(),
  261. )
  262. }
  263. func (r *Registry) maxStreams() int {
  264. return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
  265. }
  266. func (r *Registry) maxChunkingThreshold() int64 {
  267. return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
  268. }
  269. type PushParams struct {
  270. // From is an optional destination name for the model. If empty, the
  271. // destination name is the same as the source name.
  272. From string
  273. }
  274. // Push pushes the model with the name in the cache to the remote registry.
  275. func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
  276. if p == nil {
  277. p = &PushParams{}
  278. }
  279. c, err := r.cache()
  280. if err != nil {
  281. return err
  282. }
  283. m, err := r.ResolveLocal(cmp.Or(p.From, name))
  284. if err != nil {
  285. return err
  286. }
  287. // Before much else happens, check layers at not null, the blobs exist,
  288. // and the sizes match. This prevents long uploads followed by
  289. // disappointment.
  290. for _, l := range m.Layers {
  291. if l == nil {
  292. return fmt.Errorf("%w: null layer", ErrManifestInvalid)
  293. }
  294. info, err := c.Get(l.Digest)
  295. if err != nil {
  296. return fmt.Errorf("error getting %s: %w", l.Digest.Short(), err)
  297. }
  298. if info.Size != l.Size {
  299. return fmt.Errorf("size mismatch for %s: %d != %d", l.Digest.Short(), info.Size, l.Size)
  300. }
  301. }
  302. t := traceFromContext(ctx)
  303. scheme, n, _, err := r.parseNameExtended(name)
  304. if err != nil {
  305. // This should never happen since ResolveLocal should have
  306. // already validated the name.
  307. panic(err)
  308. }
  309. ctx, cancel := context.WithCancel(ctx)
  310. defer cancel()
  311. var g errgroup.Group
  312. g.SetLimit(r.maxStreams())
  313. for _, l := range m.Layers {
  314. var progress atomic.Int64
  315. g.Go(func() (err error) {
  316. defer func() { t.update(l, progress.Load(), err) }()
  317. t.update(l, 0, nil)
  318. startURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/uploads/?digest=%s",
  319. scheme,
  320. n.Host(),
  321. n.Namespace(),
  322. n.Model(),
  323. l.Digest,
  324. )
  325. res, err := r.send(ctx, "POST", startURL, nil)
  326. if err != nil {
  327. return err
  328. }
  329. res.Body.Close()
  330. f, err := os.Open(c.GetFile(l.Digest))
  331. if err != nil {
  332. return err
  333. }
  334. defer f.Close()
  335. uploadURL := res.Header.Get("Location")
  336. if uploadURL == "" {
  337. t.update(l, l.Size, ErrCached)
  338. return nil
  339. }
  340. req, err := r.newRequest(ctx, "PUT", uploadURL, f)
  341. if err != nil {
  342. return fmt.Errorf("invalid upload URL returned from registry: %q: %w", uploadURL, err)
  343. }
  344. req.ContentLength = l.Size
  345. res, err = sendRequest(r.client(), req)
  346. if err == nil {
  347. res.Body.Close()
  348. }
  349. return err
  350. })
  351. }
  352. if err := g.Wait(); err != nil {
  353. return err
  354. }
  355. // Commit
  356. path := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s",
  357. scheme,
  358. n.Host(),
  359. n.Namespace(),
  360. n.Model(),
  361. n.Tag(),
  362. )
  363. res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data))
  364. if err == nil {
  365. res.Body.Close()
  366. }
  367. // TODO(bmizerany): add a "commit" trace event
  368. return err
  369. }
  370. func canRetry(err error) bool {
  371. var re *Error
  372. if !errors.As(err, &re) {
  373. return false
  374. }
  375. return re.Status >= 500
  376. }
  377. // trackingReader is an io.Reader that tracks the number of bytes read and
  378. // calls the update function with the layer, the number of bytes read.
  379. //
  380. // It always calls update with a nil error.
  381. type trackingReader struct {
  382. l *Layer
  383. r io.Reader
  384. update func(l *Layer, n int64, err error)
  385. }
  386. func (r *trackingReader) Read(p []byte) (n int, err error) {
  387. n, err = r.r.Read(p)
  388. r.update(r.l, int64(n), nil)
  389. return
  390. }
  391. // Pull pulls the model with the given name from the remote registry into the
  392. // cache.
  393. //
  394. // For layers larger then [Registry.MaxChunkSize], the layer is downloaded in
  395. // chunks of the specified size, and then reassembled and verified. This is
  396. // typically slower than splitting the model up across layers, and is mostly
  397. // utilized for layers of type equal to "application/vnd.ollama.image".
  398. func (r *Registry) Pull(ctx context.Context, name string) error {
  399. m, err := r.Resolve(ctx, name)
  400. if err != nil {
  401. return err
  402. }
  403. // TODO(bmizerany): decide if this should be considered valid. Maybe
  404. // server-side we special case '{}' to have some special meaning? Maybe
  405. // "archiving" a tag (which is how we reason about it in the registry
  406. // already, just with a different twist).
  407. if len(m.Layers) == 0 {
  408. return fmt.Errorf("%w: no layers", ErrManifestInvalid)
  409. }
  410. c, err := r.cache()
  411. if err != nil {
  412. return err
  413. }
  414. // TODO(bmizerany): work to remove the need to do this
  415. layers := m.Layers
  416. if m.Config != nil && m.Config.Digest.IsValid() {
  417. layers = append(layers, m.Config)
  418. }
  419. // Send initial layer trace events to allow clients to have an
  420. // understanding of work to be done before work starts.
  421. var expected int64
  422. t := traceFromContext(ctx)
  423. for _, l := range layers {
  424. t.update(l, 0, nil)
  425. expected += l.Size
  426. }
  427. var total atomic.Int64
  428. var g errgroup.Group
  429. g.SetLimit(r.maxStreams())
  430. for _, l := range layers {
  431. info, err := c.Get(l.Digest)
  432. if err == nil && info.Size == l.Size {
  433. total.Add(l.Size)
  434. t.update(l, l.Size, ErrCached)
  435. continue
  436. }
  437. chunked, err := c.Chunked(l.Digest, l.Size)
  438. if err != nil {
  439. t.update(l, 0, err)
  440. continue
  441. }
  442. // TODO(bmizerany): fix this unbounded use of defer
  443. defer chunked.Close()
  444. for cs, err := range r.chunksums(ctx, name, l) {
  445. if err != nil {
  446. // Chunksum stream was interrupted, so tell
  447. // trace about it, and let in-flight chunk
  448. // downloads finish. Once they finish, return
  449. // ErrIncomplete, which is triggered by the
  450. // fact that the total bytes received is less
  451. // than the expected bytes.
  452. t.update(l, 0, err)
  453. break
  454. }
  455. g.Go(func() (err error) {
  456. defer func() {
  457. if err == nil || errors.Is(err, ErrCached) {
  458. total.Add(cs.Chunk.Size())
  459. } else {
  460. err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
  461. }
  462. }()
  463. req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
  464. if err != nil {
  465. return err
  466. }
  467. req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
  468. res, err := sendRequest(r.client(), req)
  469. if err != nil {
  470. return err
  471. }
  472. defer res.Body.Close()
  473. // Count bytes towards progress, as they
  474. // arrive, so that our bytes piggyback other
  475. // chunk updates on completion.
  476. //
  477. // This tactic is enough to show "smooth"
  478. // progress given the current CLI client. In
  479. // the near future, the server should report
  480. // download rate since it knows better than a
  481. // client that is measuring rate based on
  482. // wall-clock time-since-last-update.
  483. body := &trackingReader{l: l, r: res.Body, update: t.update}
  484. return chunked.Put(cs.Chunk, cs.Digest, body)
  485. })
  486. }
  487. }
  488. if err := g.Wait(); err != nil {
  489. return err
  490. }
  491. if total.Load() != expected {
  492. return fmt.Errorf("%w: received %d/%d", ErrIncomplete, total.Load(), expected)
  493. }
  494. md := blob.DigestFromBytes(m.Data)
  495. if err := blob.PutBytes(c, md, m.Data); err != nil {
  496. return err
  497. }
  498. return c.Link(m.Name, md)
  499. }
  500. // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
  501. // before attempting to unlink the model.
  502. func (r *Registry) Unlink(name string) (ok bool, _ error) {
  503. n, err := r.parseName(name)
  504. if err != nil {
  505. return false, err
  506. }
  507. c, err := r.cache()
  508. if err != nil {
  509. return false, err
  510. }
  511. return c.Unlink(n.String())
  512. }
  513. // Manifest represents a [ollama.com/manifest].
  514. type Manifest struct {
  515. Name string `json:"-"` // the canonical name of the model
  516. Data []byte `json:"-"` // the raw data of the manifest
  517. Layers []*Layer `json:"layers"`
  518. // For legacy reasons, we still have to download the config layer.
  519. Config *Layer `json:"config"`
  520. }
  521. // Layer returns the layer with the given
  522. // digest, or nil if not found.
  523. func (m *Manifest) Layer(d blob.Digest) *Layer {
  524. for _, l := range m.Layers {
  525. if l.Digest == d {
  526. return l
  527. }
  528. }
  529. return nil
  530. }
  531. // MarshalJSON implements json.Marshaler.
  532. //
  533. // NOTE: It adds an empty config object to the manifest, which is required by
  534. // the registry, but not used by the client. In the future, the config object
  535. // will not be required by the registry and this will should be removed.
  536. func (m Manifest) MarshalJSON() ([]byte, error) {
  537. type M Manifest
  538. v := struct {
  539. M
  540. // This is ignored, mostly, by the registry But, if not
  541. // present, it will cause an error to be returned during the
  542. // last phase of the commit which expects it, but does nothing
  543. // with it. This will be fixed in a future release of
  544. // ollama.com.
  545. Config Layer `json:"config"`
  546. }{
  547. M: M(m),
  548. }
  549. return json.Marshal(v)
  550. }
  551. // unmarshalManifest unmarshals the data into a manifest, and sets the name
  552. // field to the string representation of the name.
  553. //
  554. // It panics if the name is not fully qualified. Callers should ensure the name
  555. // is fully qualified before calling this function.
  556. func unmarshalManifest(n names.Name, data []byte) (*Manifest, error) {
  557. if !n.IsFullyQualified() {
  558. panic(fmt.Sprintf("unmarshalManifest: name is not fully qualified: %s", n.String()))
  559. }
  560. var m Manifest
  561. if err := json.Unmarshal(data, &m); err != nil {
  562. return nil, err
  563. }
  564. m.Name = n.String()
  565. m.Data = data
  566. return &m, nil
  567. }
  568. // Layer is a layer in a model.
  569. type Layer struct {
  570. Digest blob.Digest `json:"digest"`
  571. MediaType string `json:"mediaType"`
  572. Size int64 `json:"size"`
  573. }
  574. // ResolveLocal resolves a name to a Manifest in the local cache.
  575. func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
  576. _, n, d, err := r.parseNameExtended(name)
  577. if err != nil {
  578. return nil, err
  579. }
  580. c, err := r.cache()
  581. if err != nil {
  582. return nil, err
  583. }
  584. if !d.IsValid() {
  585. // No digest, so resolve the manifest by name.
  586. d, err = c.Resolve(n.String())
  587. if err != nil {
  588. return nil, err
  589. }
  590. }
  591. data, err := os.ReadFile(c.GetFile(d))
  592. if err != nil {
  593. if errors.Is(err, fs.ErrNotExist) {
  594. return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name)
  595. }
  596. return nil, err
  597. }
  598. m, err := unmarshalManifest(n, data)
  599. if err != nil {
  600. return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
  601. }
  602. return m, nil
  603. }
  604. // Resolve resolves a name to a Manifest in the remote registry.
  605. func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
  606. scheme, n, d, err := r.parseNameExtended(name)
  607. if err != nil {
  608. return nil, err
  609. }
  610. manifestURL := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s", scheme, n.Host(), n.Namespace(), n.Model(), n.Tag())
  611. if d.IsValid() {
  612. manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
  613. }
  614. res, err := r.send(ctx, "GET", manifestURL, nil)
  615. if err != nil {
  616. return nil, err
  617. }
  618. defer res.Body.Close()
  619. data, err := io.ReadAll(res.Body)
  620. if err != nil {
  621. return nil, err
  622. }
  623. // TODO(bmizerany): return digest here
  624. m, err := unmarshalManifest(n, data)
  625. if err != nil {
  626. return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
  627. }
  628. return m, nil
  629. }
  630. type chunksum struct {
  631. URL string
  632. Chunk blob.Chunk
  633. Digest blob.Digest
  634. }
  635. // chunksums returns a sequence of chunksums for the given layer. If the layer is under the
  636. // chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
  637. // is over the chunking threshold, the chunksums are read from the chunksums endpoint.
  638. func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
  639. return func(yield func(chunksum, error) bool) {
  640. scheme, n, _, err := r.parseNameExtended(name)
  641. if err != nil {
  642. yield(chunksum{}, err)
  643. return
  644. }
  645. if l.Size < r.maxChunkingThreshold() {
  646. // any layer under the threshold should be downloaded
  647. // in one go.
  648. cs := chunksum{
  649. URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
  650. scheme,
  651. n.Host(),
  652. n.Namespace(),
  653. n.Model(),
  654. l.Digest,
  655. ),
  656. Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
  657. Digest: l.Digest,
  658. }
  659. yield(cs, nil)
  660. return
  661. }
  662. // A chunksums response is a sequence of chunksums in a
  663. // simple, easy to parse line-oriented format.
  664. //
  665. // Example:
  666. //
  667. // >> GET /v2/<namespace>/<model>/chunksums/<digest>
  668. //
  669. // << HTTP/1.1 200 OK
  670. // << Content-Location: <blobURL>
  671. // <<
  672. // << <digest> <start>-<end>
  673. // << ...
  674. //
  675. // The blobURL is the URL to download the chunks from.
  676. chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
  677. scheme,
  678. n.Host(),
  679. n.Namespace(),
  680. n.Model(),
  681. l.Digest,
  682. )
  683. req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
  684. if err != nil {
  685. yield(chunksum{}, err)
  686. return
  687. }
  688. res, err := sendRequest(r.client(), req)
  689. if err != nil {
  690. yield(chunksum{}, err)
  691. return
  692. }
  693. defer res.Body.Close()
  694. if res.StatusCode != 200 {
  695. err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
  696. yield(chunksum{}, err)
  697. return
  698. }
  699. blobURL := res.Header.Get("Content-Location")
  700. s := bufio.NewScanner(res.Body)
  701. s.Split(bufio.ScanWords)
  702. for {
  703. if !s.Scan() {
  704. if s.Err() != nil {
  705. yield(chunksum{}, s.Err())
  706. }
  707. return
  708. }
  709. d, err := blob.ParseDigest(s.Bytes())
  710. if err != nil {
  711. yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
  712. return
  713. }
  714. if !s.Scan() {
  715. err := s.Err()
  716. if err == nil {
  717. err = fmt.Errorf("missing chunk range for digest %s", d)
  718. }
  719. yield(chunksum{}, err)
  720. return
  721. }
  722. chunk, err := parseChunk(s.Bytes())
  723. if err != nil {
  724. yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
  725. return
  726. }
  727. cs := chunksum{
  728. URL: blobURL,
  729. Chunk: chunk,
  730. Digest: d,
  731. }
  732. if !yield(cs, nil) {
  733. return
  734. }
  735. }
  736. }
  737. }
  738. func (r *Registry) client() *http.Client {
  739. if r.HTTPClient != nil {
  740. return r.HTTPClient
  741. }
  742. return http.DefaultClient
  743. }
  744. // newRequest constructs a new request, ready to use, with the given method,
  745. // url, and body, pre-signed with client [Key] and [UserAgent].
  746. func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
  747. req, err := http.NewRequestWithContext(ctx, method, url, body)
  748. if err != nil {
  749. return nil, err
  750. }
  751. if r.UserAgent != "" {
  752. req.Header.Set("User-Agent", r.UserAgent)
  753. }
  754. if r.Key != nil {
  755. token, err := makeAuthToken(r.Key)
  756. if err != nil {
  757. return nil, err
  758. }
  759. req.Header.Set("Authorization", "Bearer "+token)
  760. }
  761. return req, nil
  762. }
  763. // sendRequest makes a request with the given client and request, and returns the
  764. // response if the status code is 200. If the status code is not 200, an Error
  765. // is parsed from the response body and returned. If any other error occurs, it
  766. // is returned.
  767. func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
  768. defer func() {
  769. if err != nil {
  770. err = fmt.Errorf("request error %s: %w", r.URL, err)
  771. }
  772. }()
  773. if r.URL.Scheme == "https+insecure" {
  774. // TODO(bmizerany): clone client.Transport, set
  775. // InsecureSkipVerify, etc.
  776. type cloner interface {
  777. Clone() *http.Transport
  778. }
  779. // Attempt to configure the transport to skip TLS verification
  780. // if we can clone it, otherwise fall through and let the http
  781. // client complain and the scheme being invalid.
  782. x, ok := cmp.Or(c.Transport, http.DefaultTransport).(cloner)
  783. if ok {
  784. tr := x.Clone()
  785. tr.TLSClientConfig = cmp.Or(tr.TLSClientConfig, &tls.Config{})
  786. tr.TLSClientConfig.InsecureSkipVerify = true
  787. cc := *c // shallow copy
  788. cc.Transport = tr
  789. c = &cc
  790. r = r.Clone(r.Context())
  791. r.URL.Scheme = "https"
  792. // fall through
  793. }
  794. }
  795. res, err := c.Do(r)
  796. if err != nil {
  797. return nil, err
  798. }
  799. if res.StatusCode/100 != 2 {
  800. out, err := io.ReadAll(res.Body)
  801. if err != nil {
  802. return nil, err
  803. }
  804. var re Error
  805. if err := json.Unmarshal(out, &re); err != nil {
  806. // Use the raw body if we can't parse it as an error object.
  807. re.Message = string(out)
  808. }
  809. // coerce MANIFEST_UNKNOWN to ErrManifestNotFound
  810. if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") {
  811. return nil, ErrModelNotFound
  812. }
  813. re.Status = res.StatusCode
  814. return nil, &re
  815. }
  816. return res, nil
  817. }
  818. // send is a convenience method for making a request with newRequest and
  819. // passing it to send with r.client().
  820. func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
  821. req, err := r.newRequest(ctx, method, path, body)
  822. if err != nil {
  823. return nil, err
  824. }
  825. return sendRequest(r.client(), req)
  826. }
  827. // makeAuthToken creates an Ollama auth token for the given private key.
  828. //
  829. // NOTE: This format is OLD, overly complex, and should be replaced. We're
  830. // inheriting it from the original Ollama client and ollama.com
  831. // implementations, so we need to support it for now.
  832. func makeAuthToken(key crypto.PrivateKey) (string, error) {
  833. privKey, _ := key.(*ed25519.PrivateKey)
  834. if privKey == nil {
  835. return "", fmt.Errorf("unsupported private key type: %T", key)
  836. }
  837. url := fmt.Sprintf("https://ollama.com?ts=%d", time.Now().Unix())
  838. // Part 1: the checkData (e.g. the URL with a timestamp)
  839. // Part 2: the public key
  840. pubKeyShort, err := func() ([]byte, error) {
  841. sshPubKey, err := ssh.NewPublicKey(privKey.Public())
  842. if err != nil {
  843. return nil, err
  844. }
  845. pubKeyParts := bytes.Fields(ssh.MarshalAuthorizedKey(sshPubKey))
  846. if len(pubKeyParts) < 2 {
  847. return nil, fmt.Errorf("malformed public key: %q", pubKeyParts)
  848. }
  849. pubKeyShort := pubKeyParts[1]
  850. return pubKeyShort, nil
  851. }()
  852. if err != nil {
  853. return "", err
  854. }
  855. // Part 3: the signature
  856. sig := ed25519.Sign(*privKey, []byte(checkData(url)))
  857. // Assemble the token: <checkData>:<pubKey>:<signature>
  858. var b strings.Builder
  859. io.WriteString(&b, base64.StdEncoding.EncodeToString([]byte(url)))
  860. b.WriteByte(':')
  861. b.Write(pubKeyShort)
  862. b.WriteByte(':')
  863. io.WriteString(&b, base64.StdEncoding.EncodeToString(sig))
  864. return b.String(), nil
  865. }
  866. // The original spec for Ollama tokens was to use the SHA256 of the zero
  867. // string as part of the signature. I'm not sure why that was, but we still
  868. // need it to verify the signature.
  869. var zeroSum = func() string {
  870. sha256sum := sha256.Sum256(nil)
  871. x := base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))
  872. return x
  873. }()
  874. // checkData takes a URL and creates the original string format of the
  875. // data signature that is used by the ollama client to sign requests
  876. func checkData(url string) string {
  877. return fmt.Sprintf("GET,%s,%s", url, zeroSum)
  878. }
  879. type publicError struct {
  880. wrapped error
  881. message string
  882. }
  883. func withPublicMessagef(err error, message string, args ...any) error {
  884. return publicError{wrapped: err, message: fmt.Sprintf(message, args...)}
  885. }
  886. func (e publicError) Error() string { return e.message }
  887. func (e publicError) Unwrap() error { return e.wrapped }
  888. var supportedSchemes = []string{
  889. "http",
  890. "https",
  891. "https+insecure",
  892. }
  893. var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
  894. // parseNameExtended parses and validates an extended name, returning the scheme, name,
  895. // and digest.
  896. //
  897. // If the scheme is empty, scheme will be "https". If an unsupported scheme is
  898. // given, [ErrNameInvalid] wrapped with a display friendly message is returned.
  899. //
  900. // If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
  901. // message is returned.
  902. //
  903. // If the name is not, once merged with the mask, fully qualified,
  904. // [ErrNameInvalid] wrapped with a display friendly message is returned.
  905. func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
  906. scheme, name, digest := splitExtended(s)
  907. scheme = cmp.Or(scheme, "https")
  908. if !slices.Contains(supportedSchemes, scheme) {
  909. err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
  910. return "", names.Name{}, blob.Digest{}, err
  911. }
  912. var d blob.Digest
  913. if digest != "" {
  914. var err error
  915. d, err = blob.ParseDigest(digest)
  916. if err != nil {
  917. err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
  918. return "", names.Name{}, blob.Digest{}, err
  919. }
  920. if name == "" {
  921. // We have can resolve a manifest from a digest only,
  922. // so skip name validation and return the scheme and
  923. // digest.
  924. return scheme, names.Name{}, d, nil
  925. }
  926. }
  927. n, err := r.parseName(name)
  928. if err != nil {
  929. return "", names.Name{}, blob.Digest{}, err
  930. }
  931. return scheme, n, d, nil
  932. }
  933. // splitExtended splits an extended name string into its scheme, name, and digest
  934. // parts.
  935. //
  936. // Examples:
  937. //
  938. // http://ollama.com/bmizerany/smol:latest@digest
  939. // https://ollama.com/bmizerany/smol:latest
  940. // ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
  941. // model@digest
  942. // @digest
  943. func splitExtended(s string) (scheme, name, digest string) {
  944. i := strings.Index(s, "://")
  945. if i >= 0 {
  946. scheme = s[:i]
  947. s = s[i+3:]
  948. }
  949. i = strings.LastIndex(s, "@")
  950. if i >= 0 {
  951. digest = s[i+1:]
  952. s = s[:i]
  953. }
  954. return scheme, s, digest
  955. }
  956. // parseChunk parses a string in the form "start-end" and returns the Chunk.
  957. func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
  958. startPart, endPart, found := strings.Cut(string(s), "-")
  959. if !found {
  960. return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
  961. }
  962. start, err := strconv.ParseInt(startPart, 10, 64)
  963. if err != nil {
  964. return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
  965. }
  966. end, err := strconv.ParseInt(endPart, 10, 64)
  967. if err != nil {
  968. return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
  969. }
  970. if start > end {
  971. return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
  972. }
  973. return blob.Chunk{Start: start, End: end}, nil
  974. }