123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401 |
- package registry
- import (
- "cmp"
- "context"
- "encoding/json"
- "encoding/xml"
- "errors"
- "fmt"
- "io"
- "iter"
- "log/slog"
- "net/http"
- "net/url"
- "os"
- "sync"
- "github.com/ollama/ollama/client/ollama"
- "github.com/ollama/ollama/client/registry/apitype"
- "github.com/ollama/ollama/types/model"
- "golang.org/x/exp/constraints"
- "golang.org/x/sync/errgroup"
- )
// ErrLayerNotFound is the sentinel error value for a layer that could not
// be found.
var ErrLayerNotFound = errors.New("layer not found")
- type Client struct {
- BaseURL string
- Logger *slog.Logger
- }
- func (c *Client) logger() *slog.Logger {
- return cmp.Or(c.Logger, slog.Default())
- }
- func (c *Client) oclient() *ollama.Client {
- return &ollama.Client{
- BaseURL: c.BaseURL,
- }
- }
// ReadAtSeekCloser groups the ReadAt, Seek, and Close methods. It is the
// type Cache.OpenLayer returns for an opened layer.
type ReadAtSeekCloser interface {
	io.Closer
	io.Seeker
	io.ReaderAt
}
- type Cache interface {
- // LayerFile returns the absolute file path to the layer file for
- // the given model digest.
- //
- // If the digest is invalid, or the layer does not exist, the empty
- // string is returned.
- LayerFile(model.Digest) string
- // OpenLayer opens the layer file for the given model digest and
- // returns it, or an if any. The caller is responsible for closing
- // the returned file.
- OpenLayer(model.Digest) (ReadAtSeekCloser, error)
- // PutLayerFile moves the layer file at fromPath to the cache for
- // the given model digest. It is a hack intended to short circuit a
- // file copy operation.
- //
- // The file returned is expected to exist for the lifetime of the
- // cache.
- //
- // TODO(bmizerany): remove this; find a better way. Once we move
- // this into a build package, we should be able to get rid of this.
- PutLayerFile(_ model.Digest, fromPath string) error
- // SetManifestData sets the provided manifest data for the given
- // model name. If the manifest data is empty, the manifest is
- // removed. If the manifeest exists, it is overwritten.
- //
- // It is an error to call SetManifestData with a name that is not
- // complete.
- SetManifestData(model.Name, []byte) error
- // ManifestData returns the manifest data for the given model name.
- //
- // If the name incomplete, or the manifest does not exist, the empty
- // string is returned.
- ManifestData(name model.Name) []byte
- }
- // Pull pulls the manifest for name, and downloads any of its required
- // layers that are not already in the cache. It returns an error if any part
- // of the process fails, specifically:
- func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
- mn := model.ParseName(name)
- if !mn.IsFullyQualified() {
- return fmt.Errorf("ollama: pull: invalid name: %s", name)
- }
- log := c.logger().With("name", name)
- pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
- if err != nil {
- return fmt.Errorf("ollama: pull: %w: %s", err, name)
- }
- if pr.Manifest == nil || len(pr.Manifest.Layers) == 0 {
- return fmt.Errorf("ollama: pull: invalid manifest: %s: no layers found", name)
- }
- // download required layers we do not already have
- for _, l := range pr.Manifest.Layers {
- d, err := model.ParseDigest(l.Digest)
- if err != nil {
- return fmt.Errorf("ollama: reading manifest: %w: %s", err, l.Digest)
- }
- if cache.LayerFile(d) != "" {
- continue
- }
- err = func() error {
- log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size)
- log.Debug("starting download")
- // TODO(bmizerany): stop using temp which might not
- // be on same device as cache.... instead let cache
- // give us a place to store parts...
- tmpFile, err := os.CreateTemp("", "ollama-download-")
- if err != nil {
- return err
- }
- defer func() {
- tmpFile.Close()
- os.Remove(tmpFile.Name()) // in case we fail before committing
- }()
- g, ctx := errgroup.WithContext(ctx)
- g.SetLimit(8) // TODO(bmizerany): make this configurable
- // TODO(bmizerany): make chunk size configurable
- const chunkSize = 50 * 1024 * 1024 // 50MB
- chunks(l.Size, chunkSize)(func(_ int, rng chunkRange[int64]) bool {
- g.Go(func() (err error) {
- defer func() {
- if err == nil {
- return
- }
- safeURL := redactAmzSignature(l.URL)
- err = fmt.Errorf("%w: %s %s bytes=%s: %s", err, pr.Name, l.Digest, rng, safeURL)
- }()
- log.Debug("downloading", "range", rng)
- // TODO(bmizerany): retry
- // TODO(bmizerany): use real http client
- // TODO(bmizerany): resumable
- // TODO(bmizerany): multipart download
- req, err := http.NewRequestWithContext(ctx, "GET", l.URL, nil)
- if err != nil {
- return err
- }
- req.Header.Set("Range", "bytes="+rng.String())
- res, err := http.DefaultClient.Do(req)
- if err != nil {
- return err
- }
- defer res.Body.Close()
- if res.StatusCode/100 != 2 {
- log.Debug("unexpected non-2XX status code", "status", res.StatusCode)
- return fmt.Errorf("unexpected status code fetching layer: %d", res.StatusCode)
- }
- if res.ContentLength != rng.Size() {
- return fmt.Errorf("unexpected content length: %d", res.ContentLength)
- }
- w := io.NewOffsetWriter(tmpFile, rng.Start)
- _, err = io.Copy(w, res.Body)
- return err
- })
- return true
- })
- if err := g.Wait(); err != nil {
- return err
- }
- tmpFile.Close() // release our hold on the file before moving it
- return cache.PutLayerFile(d, tmpFile.Name())
- }()
- if err != nil {
- return fmt.Errorf("ollama: pull: %w", err)
- }
- }
- // do not store the presigned URLs in the cache
- for i := range pr.Manifest.Layers {
- pr.Manifest.Layers[i].URL = ""
- }
- data, err := json.Marshal(pr.Manifest)
- if err != nil {
- return err
- }
- // TODO(bmizerany): remove dep on model.Name
- return cache.SetManifestData(mn, data)
- }
// nopSeeker embeds an io.Reader and adds a Seek method that does nothing.
type nopSeeker struct {
	io.Reader
}

// Seek ignores both arguments and always reports offset 0 and a nil
// error.
func (nopSeeker) Seek(_ int64, _ int) (int64, error) {
	var pos int64
	return pos, nil
}
- // Push pushes a manifest to the server and responds to the server's
- // requests for layer uploads, if any, and finally commits the manifest for
- // name. It returns an error if any part of the process fails, specifically:
- //
- // If the server requests layers not found in the cache, ErrLayerNotFound is
- // returned.
- func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
- mn := model.ParseName(name)
- if !mn.IsFullyQualified() {
- return fmt.Errorf("ollama: push: invalid name: %s", name)
- }
- manifest := cache.ManifestData(mn)
- if len(manifest) == 0 {
- return fmt.Errorf("manifest not found: %s", name)
- }
- var mu sync.Mutex
- var completed []*apitype.CompletePart
- push := func() (*apitype.PushResponse, error) {
- v, err := ollama.Do[*apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
- Name: name,
- Manifest: manifest,
- CompleteParts: completed,
- })
- if err != nil {
- return nil, fmt.Errorf("Do: %w", err)
- }
- return v, nil
- }
- pr, err := push()
- if err != nil {
- return err
- }
- var g errgroup.Group
- for _, need := range pr.Needs {
- g.Go(func() error {
- nd, err := model.ParseDigest(need.Digest)
- if err != nil {
- return fmt.Errorf("ParseDigest: %w: %s", err, need.Digest)
- }
- f, err := cache.OpenLayer(nd)
- if err != nil {
- return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest)
- }
- defer f.Close()
- cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End)
- if err != nil {
- return fmt.Errorf("PushLayer: %w: %s", err, need.Digest)
- }
- mu.Lock()
- completed = append(completed, cp)
- mu.Unlock()
- return nil
- })
- }
- if err := g.Wait(); err != nil {
- return fmt.Errorf("Push: Required: %w", err)
- }
- if len(completed) > 0 {
- pr, err := push()
- if err != nil {
- return err
- }
- if len(pr.Needs) > 0 {
- var errs []error
- for _, r := range pr.Needs {
- errs = append(errs, fmt.Errorf("Push: server failed to find part: %q", r.Digest))
- }
- return errors.Join(errs...)
- }
- }
- return cache.SetManifestData(mn, manifest)
- }
- func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) {
- if start < 0 || end < start {
- return nil, errors.New("start must satisfy 0 <= start <= end")
- }
- file := io.NewSectionReader(body, start, end-start+1)
- req, err := http.NewRequest("PUT", url, file)
- if err != nil {
- return nil, err
- }
- req.ContentLength = end - start + 1
- // TODO(bmizerany): take content type param
- req.Header.Set("Content-Type", "text/plain")
- if start != 0 || end != 0 {
- req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", start, end))
- }
- res, err := http.DefaultClient.Do(req)
- if err != nil {
- return nil, err
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- e := parseS3Error(res)
- return nil, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
- }
- cp := &apitype.CompletePart{
- URL: url,
- ETag: res.Header.Get("ETag"),
- // TODO(bmizerany): checksum
- }
- return cp, nil
- }
// s3Error holds the fields decoded from an XML document whose root
// element is <Error>, as parsed by parseS3Error.
type s3Error struct {
	XMLName   xml.Name `xml:"Error"`
	Code      string   `xml:"Code"`
	Message   string   `xml:"Message"`
	Resource  string   `xml:"Resource"`
	RequestId string   `xml:"RequestId"`
}

// Error renders the error as "S3 (<RequestId>): <Resource>: <Code>:
// <Message>".
func (se *s3Error) Error() string {
	// All operands are strings, so Sprint inserts no spaces between
	// them.
	return fmt.Sprint("S3 (", se.RequestId, "): ", se.Resource, ": ", se.Code, ": ", se.Message)
}
- // parseS3Error parses an XML error response from S3.
- func parseS3Error(res *http.Response) error {
- var se *s3Error
- if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
- return err
- }
- return se
- }
- // TODO: replace below by using upload pkg after we have rangefunc; until
- // then, we need to keep this free of rangefunc for now.
- type chunkRange[I constraints.Integer] struct {
- // Start is the byte offset of the chunk.
- Start I
- // End is the byte offset of the last byte in the chunk.
- End I
- }
- func (c chunkRange[I]) Size() I {
- return c.End - c.Start + 1
- }
- func (c chunkRange[I]) String() string {
- return fmt.Sprintf("%d-%d", c.Start, c.End)
- }
- func (c chunkRange[I]) LogValue() slog.Value {
- return slog.StringValue(c.String())
- }
- // Chunks yields a sequence of a part number and a Chunk. The Chunk is the offset
- // and size of the chunk. The last chunk may be smaller than chunkSize if size is
- // not a multiple of chunkSize.
- //
- // The first part number is 1 and increases monotonically.
- func chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, chunkRange[I]] {
- return func(yield func(int, chunkRange[I]) bool) {
- var n int
- for off := I(0); off < size; off += chunkSize {
- n++
- if !yield(n, chunkRange[I]{
- Start: off,
- End: off + min(chunkSize, size-off) - 1,
- }) {
- return
- }
- }
- }
- }
// redactAmzSignature returns s with its X-Amz-Signature query parameter
// set to "REDACTED". The parameter is added if s does not carry one, and
// the query string is re-encoded in the process. If s cannot be parsed as
// a URL, the empty string is returned.
func redactAmzSignature(s string) string {
	parsed, parseErr := url.Parse(s)
	if parseErr != nil {
		return ""
	}
	params := parsed.Query()
	params["X-Amz-Signature"] = []string{"REDACTED"}
	parsed.RawQuery = params.Encode()
	return parsed.String()
}
|