12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095 |
- // Package ollama provides a client for interacting with an Ollama registry
- // which pushes and pulls model manifests and layers as defined by the
- // [ollama.com/manifest].
- package ollama
- import (
- "bufio"
- "bytes"
- "cmp"
- "context"
- "crypto"
- "crypto/ed25519"
- "crypto/sha256"
- "crypto/tls"
- "encoding/base64"
- "encoding/hex"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "io/fs"
- "iter"
- "log/slog"
- "net/http"
- "os"
- "path/filepath"
- "runtime"
- "runtime/debug"
- "slices"
- "strconv"
- "strings"
- "sync"
- "sync/atomic"
- "time"
- "golang.org/x/crypto/ssh"
- "golang.org/x/sync/errgroup"
- "github.com/ollama/ollama/server/internal/cache/blob"
- "github.com/ollama/ollama/server/internal/internal/names"
- _ "embed"
- )
// Sentinel errors reported by this package. Compare with errors.Is.
var (
	// ErrModelNotFound is returned when a manifest is not found in the
	// cache or registry.
	ErrModelNotFound = errors.New("model not found")

	// ErrManifestInvalid is returned when a manifest found in a local or
	// remote cache is invalid.
	ErrManifestInvalid = errors.New("invalid manifest")

	// ErrNameInvalid is returned when a model name is missing or
	// invalid.
	ErrNameInvalid = errors.New("invalid or missing name")

	// ErrCached is passed to the trace update callback when a layer
	// already exists. It is a non-fatal error and is never returned by
	// [Registry.Push].
	ErrCached = errors.New("cached")

	// ErrIncomplete is returned by [Registry.Pull] when a model pull was
	// incomplete due to one or more layer download failures. Users that
	// want specific errors should use [WithTrace].
	ErrIncomplete = errors.New("incomplete")
)
// DefaultChunkingThreshold is the layer size, in bytes, used when
// [Registry.ChunkingThreshold] is zero. Layers smaller than the threshold
// are downloaded in a single request; larger ones are downloaded in chunks.
const DefaultChunkingThreshold = 64 * 1024 * 1024
- var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
- dir := os.Getenv("OLLAMA_MODELS")
- if dir == "" {
- home, _ := os.UserHomeDir()
- home = cmp.Or(home, ".")
- dir = filepath.Join(home, ".ollama", "models")
- }
- return blob.Open(dir)
- })
- // DefaultCache returns the default cache used by the registry. It is
- // configured from the OLLAMA_MODELS environment variable, or defaults to
- // $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
- // it uses the current working directory.
- func DefaultCache() (*blob.DiskCache, error) {
- return defaultCache()
- }
// Error is the standard error returned by Ollama APIs. It can be decoded
// from either a single-error or a multiple-error response body.
//
// A single-error response looks like:
//
//	{"code": "optional_code","error":"error message"}
//
// A multiple-error response looks like:
//
//	{"errors": [{"code": "optional_code","message":"error message"}]}
//
// Note the asymmetry: the text lives in the "error" field of a single-error
// response, but in the "message" field of each entry of a multiple-error
// response.
//
// The code field is optional in both forms and may be empty.
type Error struct {
	Status  int    `json:"-"` // TODO(bmizerany): remove this
	Code    string `json:"code"`
	Message string `json:"message"`
}

// Error implements the error interface. The status is always included; the
// code and message are appended only when non-empty.
func (e *Error) Error() string {
	var sb strings.Builder
	sb.WriteString("registry responded with status " + strconv.Itoa(e.Status))
	if e.Code != "" {
		sb.WriteString(": code " + e.Code)
	}
	if e.Message != "" {
		sb.WriteString(": " + e.Message)
	}
	return sb.String()
}

// LogValue implements slog.LogValuer, exposing status, code and message as a
// group of attributes.
func (e *Error) LogValue() slog.Value {
	attrs := []slog.Attr{
		slog.Int("status", e.Status),
		slog.String("code", e.Code),
		slog.String("message", e.Message),
	}
	return slog.GroupValue(attrs...)
}

// UnmarshalJSON implements json.Unmarshaler.
//
// A non-empty "error" field selects the single-error form, which sets only
// Code and Message. Otherwise the first entry of "errors" replaces the whole
// value. A body with neither is reported as an error.
func (e *Error) UnmarshalJSON(b []byte) error {
	// plain has the fields and tags of Error but none of its methods, so
	// decoding into it does not recurse into this method.
	type plain Error

	var payload struct {
		// Single error
		Code  string
		Error string

		// Multiple errors
		Errors []plain
	}
	if err := json.Unmarshal(b, &payload); err != nil {
		return err
	}

	switch {
	case payload.Error != "":
		e.Code = payload.Code
		e.Message = payload.Error
	case len(payload.Errors) > 0:
		// Our registry only returns one error.
		*e = Error(payload.Errors[0])
	default:
		return fmt.Errorf("no messages in error response: %s", string(b))
	}
	return nil
}
- const DefaultMask = "registry.ollama.ai/library/_:latest"
- var defaultMask = func() names.Name {
- n := names.Parse(DefaultMask)
- if !n.IsFullyQualified() {
- panic("default mask is not fully qualified")
- }
- return n
- }()
- // CompleteName returns a fully qualified name by merging the given name with
- // the default mask. If the name is already fully qualified, it is returned
- // unchanged.
- func CompleteName(name string) string {
- return names.Merge(names.Parse(name), defaultMask).String()
- }
- // Registry is a client for performing push and pull operations against an
- // Ollama registry.
- type Registry struct {
- // Cache is the cache used to store models. If nil, [DefaultCache] is
- // used.
- Cache *blob.DiskCache
- // UserAgent is the User-Agent header to send with requests to the
- // registry. If empty, the User-Agent is determined by HTTPClient.
- UserAgent string
- // Key is the key used to authenticate with the registry.
- //
- // Currently, only Ed25519 keys are supported.
- Key crypto.PrivateKey
- // HTTPClient is the HTTP client used to make requests to the registry.
- //
- // If nil, [http.DefaultClient] is used.
- //
- // As a quick note: If a Registry function that makes a call to a URL
- // with the "https+insecure" scheme, the client will be cloned and the
- // transport will be set to skip TLS verification, unless the client's
- // Transport done not have a Clone method with the same signature as
- // [http.Transport.Clone], which case, the call will fail.
- HTTPClient *http.Client
- // MaxStreams is the maximum number of concurrent streams to use when
- // pushing or pulling models. If zero, the number of streams is
- // determined by [runtime.GOMAXPROCS].
- //
- // A negative value means no limit.
- MaxStreams int
- // ChunkingThreshold is the maximum size of a layer to download in a single
- // request. If zero, [DefaultChunkingThreshold] is used.
- ChunkingThreshold int64
- // Mask, if set, is the name used to convert non-fully qualified names
- // to fully qualified names. If empty, [DefaultMask] is used.
- Mask string
- }
- func (r *Registry) cache() (*blob.DiskCache, error) {
- if r.Cache != nil {
- return r.Cache, nil
- }
- return defaultCache()
- }
- func (r *Registry) parseName(name string) (names.Name, error) {
- mask := defaultMask
- if r.Mask != "" {
- mask = names.Parse(r.Mask)
- }
- n := names.Merge(names.Parse(name), mask)
- if !n.IsFullyQualified() {
- return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
- }
- return n, nil
- }
- // DefaultRegistry returns a new Registry configured from the environment. The
- // key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
- // value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
- // system's temporary directory.
- //
- // It returns an error if any configuration in the environment is invalid.
- func DefaultRegistry() (*Registry, error) {
- home, err := os.UserHomeDir()
- if err != nil {
- return nil, err
- }
- keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
- if err != nil && errors.Is(err, fs.ErrNotExist) {
- return nil, err
- }
- var rc Registry
- rc.UserAgent = UserAgent()
- rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
- if err != nil {
- return nil, err
- }
- maxStreams := os.Getenv("OLLAMA_REGISTRY_MAXSTREAMS")
- if maxStreams != "" {
- var err error
- rc.MaxStreams, err = strconv.Atoi(maxStreams)
- if err != nil {
- return nil, fmt.Errorf("invalid OLLAMA_REGISTRY_MAXSTREAMS: %w", err)
- }
- }
- return &rc, nil
- }
// UserAgent returns the User-Agent string for this build, in the form
//
//	ollama/<version> (<GOARCH> <GOOS>) Go/<go version>
//
// The version is the main module version from the embedded build
// information. It falls back to "v0.0.0" when that version is "(devel)" or
// when no build information is available.
func UserAgent() string {
	// When using `go run .` the version is "(devel)". This is seen as an
	// invalid version by ollama.com and so it defaults to "needs upgrade"
	// for some requests, such as pulls. These checks can be skipped by
	// using the special version "v0.0.0", so that is the fallback.
	version := "v0.0.0"

	// ReadBuildInfo returns a nil *BuildInfo and ok == false when the
	// binary carries no build information; guard against dereferencing it.
	if info, ok := debug.ReadBuildInfo(); ok && info.Main.Version != "(devel)" {
		version = info.Main.Version
	}

	return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
		version,
		runtime.GOARCH,
		runtime.GOOS,
		runtime.Version(),
	)
}
- func (r *Registry) maxStreams() int {
- return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
- }
- func (r *Registry) maxChunkingThreshold() int64 {
- return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
- }
// PushParams holds the optional parameters of [Registry.Push].
type PushParams struct {
	// From is an optional source name for the model. When set, the
	// manifest is resolved from the local cache under this name and pushed
	// to the name given to Push. If empty, the source name is the same as
	// the destination name.
	From string
}
- // Push pushes the model with the name in the cache to the remote registry.
- func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
- if p == nil {
- p = &PushParams{}
- }
- c, err := r.cache()
- if err != nil {
- return err
- }
- m, err := r.ResolveLocal(cmp.Or(p.From, name))
- if err != nil {
- return err
- }
- // Before much else happens, check layers at not null, the blobs exist,
- // and the sizes match. This prevents long uploads followed by
- // disappointment.
- for _, l := range m.Layers {
- if l == nil {
- return fmt.Errorf("%w: null layer", ErrManifestInvalid)
- }
- info, err := c.Get(l.Digest)
- if err != nil {
- return fmt.Errorf("error getting %s: %w", l.Digest.Short(), err)
- }
- if info.Size != l.Size {
- return fmt.Errorf("size mismatch for %s: %d != %d", l.Digest.Short(), info.Size, l.Size)
- }
- }
- t := traceFromContext(ctx)
- scheme, n, _, err := r.parseNameExtended(name)
- if err != nil {
- // This should never happen since ResolveLocal should have
- // already validated the name.
- panic(err)
- }
- ctx, cancel := context.WithCancel(ctx)
- defer cancel()
- var g errgroup.Group
- g.SetLimit(r.maxStreams())
- for _, l := range m.Layers {
- var progress atomic.Int64
- g.Go(func() (err error) {
- defer func() { t.update(l, progress.Load(), err) }()
- t.update(l, 0, nil)
- startURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/uploads/?digest=%s",
- scheme,
- n.Host(),
- n.Namespace(),
- n.Model(),
- l.Digest,
- )
- res, err := r.send(ctx, "POST", startURL, nil)
- if err != nil {
- return err
- }
- res.Body.Close()
- f, err := os.Open(c.GetFile(l.Digest))
- if err != nil {
- return err
- }
- defer f.Close()
- uploadURL := res.Header.Get("Location")
- if uploadURL == "" {
- t.update(l, l.Size, ErrCached)
- return nil
- }
- req, err := r.newRequest(ctx, "PUT", uploadURL, f)
- if err != nil {
- return fmt.Errorf("invalid upload URL returned from registry: %q: %w", uploadURL, err)
- }
- req.ContentLength = l.Size
- res, err = sendRequest(r.client(), req)
- if err == nil {
- res.Body.Close()
- }
- return err
- })
- }
- if err := g.Wait(); err != nil {
- return err
- }
- // Commit
- path := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s",
- scheme,
- n.Host(),
- n.Namespace(),
- n.Model(),
- n.Tag(),
- )
- res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data))
- if err == nil {
- res.Body.Close()
- }
- // TODO(bmizerany): add a "commit" trace event
- return err
- }
- func canRetry(err error) bool {
- var re *Error
- if !errors.As(err, &re) {
- return false
- }
- return re.Status >= 500
- }
- // trackingReader is an io.Reader that tracks the number of bytes read and
- // calls the update function with the layer, the number of bytes read.
- //
- // It always calls update with a nil error.
- type trackingReader struct {
- l *Layer
- r io.Reader
- update func(l *Layer, n int64, err error)
- }
- func (r *trackingReader) Read(p []byte) (n int, err error) {
- n, err = r.r.Read(p)
- r.update(r.l, int64(n), nil)
- return
- }
- // Pull pulls the model with the given name from the remote registry into the
- // cache.
- //
- // For layers larger then [Registry.MaxChunkSize], the layer is downloaded in
- // chunks of the specified size, and then reassembled and verified. This is
- // typically slower than splitting the model up across layers, and is mostly
- // utilized for layers of type equal to "application/vnd.ollama.image".
- func (r *Registry) Pull(ctx context.Context, name string) error {
- m, err := r.Resolve(ctx, name)
- if err != nil {
- return err
- }
- // TODO(bmizerany): decide if this should be considered valid. Maybe
- // server-side we special case '{}' to have some special meaning? Maybe
- // "archiving" a tag (which is how we reason about it in the registry
- // already, just with a different twist).
- if len(m.Layers) == 0 {
- return fmt.Errorf("%w: no layers", ErrManifestInvalid)
- }
- c, err := r.cache()
- if err != nil {
- return err
- }
- // TODO(bmizerany): work to remove the need to do this
- layers := m.Layers
- if m.Config != nil && m.Config.Digest.IsValid() {
- layers = append(layers, m.Config)
- }
- // Send initial layer trace events to allow clients to have an
- // understanding of work to be done before work starts.
- var expected int64
- t := traceFromContext(ctx)
- for _, l := range layers {
- t.update(l, 0, nil)
- expected += l.Size
- }
- var total atomic.Int64
- var g errgroup.Group
- g.SetLimit(r.maxStreams())
- for _, l := range layers {
- info, err := c.Get(l.Digest)
- if err == nil && info.Size == l.Size {
- total.Add(l.Size)
- t.update(l, l.Size, ErrCached)
- continue
- }
- chunked, err := c.Chunked(l.Digest, l.Size)
- if err != nil {
- t.update(l, 0, err)
- continue
- }
- // TODO(bmizerany): fix this unbounded use of defer
- defer chunked.Close()
- for cs, err := range r.chunksums(ctx, name, l) {
- if err != nil {
- // Chunksum stream was interrupted, so tell
- // trace about it, and let in-flight chunk
- // downloads finish. Once they finish, return
- // ErrIncomplete, which is triggered by the
- // fact that the total bytes received is less
- // than the expected bytes.
- t.update(l, 0, err)
- break
- }
- g.Go(func() (err error) {
- defer func() {
- if err == nil || errors.Is(err, ErrCached) {
- total.Add(cs.Chunk.Size())
- } else {
- err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
- }
- }()
- req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
- if err != nil {
- return err
- }
- req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
- res, err := sendRequest(r.client(), req)
- if err != nil {
- return err
- }
- defer res.Body.Close()
- // Count bytes towards progress, as they
- // arrive, so that our bytes piggyback other
- // chunk updates on completion.
- //
- // This tactic is enough to show "smooth"
- // progress given the current CLI client. In
- // the near future, the server should report
- // download rate since it knows better than a
- // client that is measuring rate based on
- // wall-clock time-since-last-update.
- body := &trackingReader{l: l, r: res.Body, update: t.update}
- return chunked.Put(cs.Chunk, cs.Digest, body)
- })
- }
- }
- if err := g.Wait(); err != nil {
- return err
- }
- if total.Load() != expected {
- return fmt.Errorf("%w: received %d/%d", ErrIncomplete, total.Load(), expected)
- }
- md := blob.DigestFromBytes(m.Data)
- if err := blob.PutBytes(c, md, m.Data); err != nil {
- return err
- }
- return c.Link(m.Name, md)
- }
- // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
- // before attempting to unlink the model.
- func (r *Registry) Unlink(name string) (ok bool, _ error) {
- n, err := r.parseName(name)
- if err != nil {
- return false, err
- }
- c, err := r.cache()
- if err != nil {
- return false, err
- }
- return c.Unlink(n.String())
- }
- // Manifest represents a [ollama.com/manifest].
- type Manifest struct {
- Name string `json:"-"` // the canonical name of the model
- Data []byte `json:"-"` // the raw data of the manifest
- Layers []*Layer `json:"layers"`
- // For legacy reasons, we still have to download the config layer.
- Config *Layer `json:"config"`
- }
- // Layer returns the layer with the given
- // digest, or nil if not found.
- func (m *Manifest) Layer(d blob.Digest) *Layer {
- for _, l := range m.Layers {
- if l.Digest == d {
- return l
- }
- }
- return nil
- }
- // MarshalJSON implements json.Marshaler.
- //
- // NOTE: It adds an empty config object to the manifest, which is required by
- // the registry, but not used by the client. In the future, the config object
- // will not be required by the registry and this will should be removed.
- func (m Manifest) MarshalJSON() ([]byte, error) {
- type M Manifest
- v := struct {
- M
- // This is ignored, mostly, by the registry But, if not
- // present, it will cause an error to be returned during the
- // last phase of the commit which expects it, but does nothing
- // with it. This will be fixed in a future release of
- // ollama.com.
- Config Layer `json:"config"`
- }{
- M: M(m),
- }
- return json.Marshal(v)
- }
- // unmarshalManifest unmarshals the data into a manifest, and sets the name
- // field to the string representation of the name.
- //
- // It panics if the name is not fully qualified. Callers should ensure the name
- // is fully qualified before calling this function.
- func unmarshalManifest(n names.Name, data []byte) (*Manifest, error) {
- if !n.IsFullyQualified() {
- panic(fmt.Sprintf("unmarshalManifest: name is not fully qualified: %s", n.String()))
- }
- var m Manifest
- if err := json.Unmarshal(data, &m); err != nil {
- return nil, err
- }
- m.Name = n.String()
- m.Data = data
- return &m, nil
- }
- // Layer is a layer in a model.
- type Layer struct {
- Digest blob.Digest `json:"digest"`
- MediaType string `json:"mediaType"`
- Size int64 `json:"size"`
- }
- // ResolveLocal resolves a name to a Manifest in the local cache.
- func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
- _, n, d, err := r.parseNameExtended(name)
- if err != nil {
- return nil, err
- }
- c, err := r.cache()
- if err != nil {
- return nil, err
- }
- if !d.IsValid() {
- // No digest, so resolve the manifest by name.
- d, err = c.Resolve(n.String())
- if err != nil {
- return nil, err
- }
- }
- data, err := os.ReadFile(c.GetFile(d))
- if err != nil {
- if errors.Is(err, fs.ErrNotExist) {
- return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name)
- }
- return nil, err
- }
- m, err := unmarshalManifest(n, data)
- if err != nil {
- return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
- }
- return m, nil
- }
- // Resolve resolves a name to a Manifest in the remote registry.
- func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
- scheme, n, d, err := r.parseNameExtended(name)
- if err != nil {
- return nil, err
- }
- manifestURL := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s", scheme, n.Host(), n.Namespace(), n.Model(), n.Tag())
- if d.IsValid() {
- manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
- }
- res, err := r.send(ctx, "GET", manifestURL, nil)
- if err != nil {
- return nil, err
- }
- defer res.Body.Close()
- data, err := io.ReadAll(res.Body)
- if err != nil {
- return nil, err
- }
- // TODO(bmizerany): return digest here
- m, err := unmarshalManifest(n, data)
- if err != nil {
- return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
- }
- return m, nil
- }
- type chunksum struct {
- URL string
- Chunk blob.Chunk
- Digest blob.Digest
- }
- // chunksums returns a sequence of chunksums for the given layer. If the layer is under the
- // chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
- // is over the chunking threshold, the chunksums are read from the chunksums endpoint.
- func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
- return func(yield func(chunksum, error) bool) {
- scheme, n, _, err := r.parseNameExtended(name)
- if err != nil {
- yield(chunksum{}, err)
- return
- }
- if l.Size < r.maxChunkingThreshold() {
- // any layer under the threshold should be downloaded
- // in one go.
- cs := chunksum{
- URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
- scheme,
- n.Host(),
- n.Namespace(),
- n.Model(),
- l.Digest,
- ),
- Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
- Digest: l.Digest,
- }
- yield(cs, nil)
- return
- }
- // A chunksums response is a sequence of chunksums in a
- // simple, easy to parse line-oriented format.
- //
- // Example:
- //
- // >> GET /v2/<namespace>/<model>/chunksums/<digest>
- //
- // << HTTP/1.1 200 OK
- // << Content-Location: <blobURL>
- // <<
- // << <digest> <start>-<end>
- // << ...
- //
- // The blobURL is the URL to download the chunks from.
- chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
- scheme,
- n.Host(),
- n.Namespace(),
- n.Model(),
- l.Digest,
- )
- req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
- if err != nil {
- yield(chunksum{}, err)
- return
- }
- res, err := sendRequest(r.client(), req)
- if err != nil {
- yield(chunksum{}, err)
- return
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
- yield(chunksum{}, err)
- return
- }
- blobURL := res.Header.Get("Content-Location")
- s := bufio.NewScanner(res.Body)
- s.Split(bufio.ScanWords)
- for {
- if !s.Scan() {
- if s.Err() != nil {
- yield(chunksum{}, s.Err())
- }
- return
- }
- d, err := blob.ParseDigest(s.Bytes())
- if err != nil {
- yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
- return
- }
- if !s.Scan() {
- err := s.Err()
- if err == nil {
- err = fmt.Errorf("missing chunk range for digest %s", d)
- }
- yield(chunksum{}, err)
- return
- }
- chunk, err := parseChunk(s.Bytes())
- if err != nil {
- yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
- return
- }
- cs := chunksum{
- URL: blobURL,
- Chunk: chunk,
- Digest: d,
- }
- if !yield(cs, nil) {
- return
- }
- }
- }
- }
- func (r *Registry) client() *http.Client {
- if r.HTTPClient != nil {
- return r.HTTPClient
- }
- return http.DefaultClient
- }
- // newRequest constructs a new request, ready to use, with the given method,
- // url, and body, pre-signed with client [Key] and [UserAgent].
- func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
- req, err := http.NewRequestWithContext(ctx, method, url, body)
- if err != nil {
- return nil, err
- }
- if r.UserAgent != "" {
- req.Header.Set("User-Agent", r.UserAgent)
- }
- if r.Key != nil {
- token, err := makeAuthToken(r.Key)
- if err != nil {
- return nil, err
- }
- req.Header.Set("Authorization", "Bearer "+token)
- }
- return req, nil
- }
- // sendRequest makes a request with the given client and request, and returns the
- // response if the status code is 200. If the status code is not 200, an Error
- // is parsed from the response body and returned. If any other error occurs, it
- // is returned.
- func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
- defer func() {
- if err != nil {
- err = fmt.Errorf("request error %s: %w", r.URL, err)
- }
- }()
- if r.URL.Scheme == "https+insecure" {
- // TODO(bmizerany): clone client.Transport, set
- // InsecureSkipVerify, etc.
- type cloner interface {
- Clone() *http.Transport
- }
- // Attempt to configure the transport to skip TLS verification
- // if we can clone it, otherwise fall through and let the http
- // client complain and the scheme being invalid.
- x, ok := cmp.Or(c.Transport, http.DefaultTransport).(cloner)
- if ok {
- tr := x.Clone()
- tr.TLSClientConfig = cmp.Or(tr.TLSClientConfig, &tls.Config{})
- tr.TLSClientConfig.InsecureSkipVerify = true
- cc := *c // shallow copy
- cc.Transport = tr
- c = &cc
- r = r.Clone(r.Context())
- r.URL.Scheme = "https"
- // fall through
- }
- }
- res, err := c.Do(r)
- if err != nil {
- return nil, err
- }
- if res.StatusCode/100 != 2 {
- out, err := io.ReadAll(res.Body)
- if err != nil {
- return nil, err
- }
- var re Error
- if err := json.Unmarshal(out, &re); err != nil {
- // Use the raw body if we can't parse it as an error object.
- re.Message = string(out)
- }
- // coerce MANIFEST_UNKNOWN to ErrManifestNotFound
- if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") {
- return nil, ErrModelNotFound
- }
- re.Status = res.StatusCode
- return nil, &re
- }
- return res, nil
- }
- // send is a convenience method for making a request with newRequest and
- // passing it to send with r.client().
- func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
- req, err := r.newRequest(ctx, method, path, body)
- if err != nil {
- return nil, err
- }
- return sendRequest(r.client(), req)
- }
- // makeAuthToken creates an Ollama auth token for the given private key.
- //
- // NOTE: This format is OLD, overly complex, and should be replaced. We're
- // inheriting it from the original Ollama client and ollama.com
- // implementations, so we need to support it for now.
- func makeAuthToken(key crypto.PrivateKey) (string, error) {
- privKey, _ := key.(*ed25519.PrivateKey)
- if privKey == nil {
- return "", fmt.Errorf("unsupported private key type: %T", key)
- }
- url := fmt.Sprintf("https://ollama.com?ts=%d", time.Now().Unix())
- // Part 1: the checkData (e.g. the URL with a timestamp)
- // Part 2: the public key
- pubKeyShort, err := func() ([]byte, error) {
- sshPubKey, err := ssh.NewPublicKey(privKey.Public())
- if err != nil {
- return nil, err
- }
- pubKeyParts := bytes.Fields(ssh.MarshalAuthorizedKey(sshPubKey))
- if len(pubKeyParts) < 2 {
- return nil, fmt.Errorf("malformed public key: %q", pubKeyParts)
- }
- pubKeyShort := pubKeyParts[1]
- return pubKeyShort, nil
- }()
- if err != nil {
- return "", err
- }
- // Part 3: the signature
- sig := ed25519.Sign(*privKey, []byte(checkData(url)))
- // Assemble the token: <checkData>:<pubKey>:<signature>
- var b strings.Builder
- io.WriteString(&b, base64.StdEncoding.EncodeToString([]byte(url)))
- b.WriteByte(':')
- b.Write(pubKeyShort)
- b.WriteByte(':')
- io.WriteString(&b, base64.StdEncoding.EncodeToString(sig))
- return b.String(), nil
- }
// zeroSum is the base64 encoding of the hex-encoded SHA-256 digest of empty
// input. The original spec for Ollama tokens includes it in the signed
// data, so it is still needed to produce signatures the server accepts.
var zeroSum = func() string {
	digest := sha256.Sum256(nil)
	hexDigest := hex.EncodeToString(digest[:])
	return base64.StdEncoding.EncodeToString([]byte(hexDigest))
}()

// checkData returns the string that the ollama client signs for url: the
// literal "GET", the URL, and zeroSum, joined by commas.
func checkData(url string) string {
	return fmt.Sprint("GET,", url, ",", zeroSum)
}
// publicError pairs an underlying error with a separate message. Error
// returns only the message, while Unwrap exposes the underlying error to
// errors.Is and errors.As.
type publicError struct {
	wrapped error
	message string
}

// withPublicMessagef wraps err in a publicError whose message is formatted
// from message and args.
func withPublicMessagef(err error, message string, args ...any) error {
	text := fmt.Sprintf(message, args...)
	return publicError{message: text, wrapped: err}
}

// Error returns the message, without the text of the wrapped error.
func (e publicError) Error() string {
	return e.message
}

// Unwrap returns the wrapped error.
func (e publicError) Unwrap() error {
	return e.wrapped
}
// supportedSchemes lists the URL schemes accepted in extended names.
var supportedSchemes = []string{"http", "https", "https+insecure"}

// supportedSchemesMessage is the hint appended to unsupported-scheme errors.
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %s", strings.Join(supportedSchemes, ", "))
- // parseNameExtended parses and validates an extended name, returning the scheme, name,
- // and digest.
- //
- // If the scheme is empty, scheme will be "https". If an unsupported scheme is
- // given, [ErrNameInvalid] wrapped with a display friendly message is returned.
- //
- // If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
- // message is returned.
- //
- // If the name is not, once merged with the mask, fully qualified,
- // [ErrNameInvalid] wrapped with a display friendly message is returned.
- func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
- scheme, name, digest := splitExtended(s)
- scheme = cmp.Or(scheme, "https")
- if !slices.Contains(supportedSchemes, scheme) {
- err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
- return "", names.Name{}, blob.Digest{}, err
- }
- var d blob.Digest
- if digest != "" {
- var err error
- d, err = blob.ParseDigest(digest)
- if err != nil {
- err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
- return "", names.Name{}, blob.Digest{}, err
- }
- if name == "" {
- // We have can resolve a manifest from a digest only,
- // so skip name validation and return the scheme and
- // digest.
- return scheme, names.Name{}, d, nil
- }
- }
- n, err := r.parseName(name)
- if err != nil {
- return "", names.Name{}, blob.Digest{}, err
- }
- return scheme, n, d, nil
- }
// splitExtended splits an extended name string into its scheme, name, and
// digest parts. The scheme is everything before the first "://" and the
// digest is everything after the last "@"; a part that is absent is
// returned as the empty string.
//
// Examples:
//
//	http://ollama.com/bmizerany/smol:latest@digest
//	https://ollama.com/bmizerany/smol:latest
//	ollama.com/bmizerany/smol:latest@digest // scheme is empty
//	model@digest
//	@digest
func splitExtended(s string) (scheme, name, digest string) {
	if before, after, ok := strings.Cut(s, "://"); ok {
		scheme, s = before, after
	}

	name = s
	if at := strings.LastIndex(s, "@"); at >= 0 {
		name, digest = s[:at], s[at+1:]
	}
	return scheme, name, digest
}
- // parseChunk parses a string in the form "start-end" and returns the Chunk.
- func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
- startPart, endPart, found := strings.Cut(string(s), "-")
- if !found {
- return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
- }
- start, err := strconv.ParseInt(startPart, 10, 64)
- if err != nil {
- return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
- }
- end, err := strconv.ParseInt(endPart, 10, 64)
- if err != nil {
- return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
- }
- if start > end {
- return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
- }
- return blob.Chunk{Start: start, End: end}, nil
- }
|