|
@@ -0,0 +1,384 @@
|
|
|
+package registry
|
|
|
+
|
|
|
+import (
|
|
|
+ "cmp"
|
|
|
+ "context"
|
|
|
+ "encoding/json"
|
|
|
+ "encoding/xml"
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "iter"
|
|
|
+ "log/slog"
|
|
|
+ "net/http"
|
|
|
+ "net/url"
|
|
|
+ "os"
|
|
|
+ "sync"
|
|
|
+
|
|
|
+ "github.com/ollama/ollama/client/ollama"
|
|
|
+ "github.com/ollama/ollama/client/registry/apitype"
|
|
|
+ "golang.org/x/exp/constraints"
|
|
|
+ "golang.org/x/sync/errgroup"
|
|
|
+)
|
|
|
+
|
|
|
+// Errors
|
|
|
+var (
|
|
|
+ ErrLayerNotFound = errors.New("layer not found")
|
|
|
+)
|
|
|
+
|
|
|
+type Client struct {
|
|
|
+ BaseURL string
|
|
|
+
|
|
|
+ // TODO(bmizerany): remove NameFill (once we remove model dep here)
|
|
|
+ NameFill string
|
|
|
+
|
|
|
+ Logger *slog.Logger
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Client) logger() *slog.Logger {
|
|
|
+ return cmp.Or(c.Logger, slog.Default())
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Client) oclient() *ollama.Client {
|
|
|
+ return &ollama.Client{
|
|
|
+ BaseURL: c.BaseURL,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+type ReadAtSeekCloser interface {
|
|
|
+ io.ReaderAt
|
|
|
+ io.Seeker
|
|
|
+ io.Closer
|
|
|
+}
|
|
|
+
|
|
|
+type Cache interface {
|
|
|
+ // LayerFile returns the absolute file path to the layer file for
|
|
|
+ // the given model digest.
|
|
|
+ //
|
|
|
+ // If the digest is invalid, or the layer does not exist, the empty
|
|
|
+ // string is returned.
|
|
|
+ LayerFile(digest string) string
|
|
|
+
|
|
|
+ // OpenLayer opens the layer file for the given model digest and
|
|
|
+ // returns it, or an if any. The caller is responsible for closing
|
|
|
+ // the returned file.
|
|
|
+ OpenLayer(digest string) (ReadAtSeekCloser, error)
|
|
|
+
|
|
|
+ // PutLayerFile moves the layer file at fromPath to the cache for
|
|
|
+ // the given model digest. It is a hack intended to short circuit a
|
|
|
+ // file copy operation.
|
|
|
+ //
|
|
|
+ // TODO(bmizerany): remove this; find a better way. Once we move
|
|
|
+ // this into a build package, we should be able to get rid of this.
|
|
|
+ PutLayerFile(digest, fromPath string) error
|
|
|
+
|
|
|
+ // SetManifestData sets the provided manifest data for the given
|
|
|
+ // model name. If the manifest data is empty, the manifest is
|
|
|
+ // removed. If the manifeest exists, it is overwritten.
|
|
|
+ //
|
|
|
+ // It is an error to call SetManifestData with a name that is not
|
|
|
+ // complete.
|
|
|
+ SetManifestData(name string, data []byte) error
|
|
|
+
|
|
|
+ // ManifestData returns the manifest data for the given model name.
|
|
|
+ //
|
|
|
+ // If the name incomplete, or the manifest does not exist, the empty
|
|
|
+ // string is returned.
|
|
|
+ ManifestData(name string) []byte
|
|
|
+}
|
|
|
+
|
|
|
+// Pull pulls the manifest for name, and downloads any of its required
|
|
|
+// layers that are not already in the cache. It returns an error if any part
|
|
|
+// of the process fails, specifically:
|
|
|
+func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
|
|
|
+ log := c.logger().With("name", name)
|
|
|
+
|
|
|
+ pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("ollama: pull: %w: %s", err, name)
|
|
|
+ }
|
|
|
+
|
|
|
+ if pr.Manifest == nil || len(pr.Manifest.Layers) == 0 {
|
|
|
+ return fmt.Errorf("ollama: pull: invalid manifest: %s: no layers found", name)
|
|
|
+ }
|
|
|
+
|
|
|
+ // download required layers we do not already have
|
|
|
+ for _, l := range pr.Manifest.Layers {
|
|
|
+ if cache.LayerFile(l.Digest) != "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ err := func() error {
|
|
|
+ log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size)
|
|
|
+ log.Debug("starting download")
|
|
|
+
|
|
|
+ // TODO(bmizerany): stop using temp which might not
|
|
|
+ // be on same device as cache.... instead let cache
|
|
|
+ // give us a place to store parts...
|
|
|
+ tmpFile, err := os.CreateTemp("", "ollama-download-")
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ defer func() {
|
|
|
+ tmpFile.Close()
|
|
|
+ os.Remove(tmpFile.Name()) // in case we fail before committing
|
|
|
+ }()
|
|
|
+
|
|
|
+ g, ctx := errgroup.WithContext(ctx)
|
|
|
+ g.SetLimit(8) // TODO(bmizerany): make this configurable
|
|
|
+
|
|
|
+ // TODO(bmizerany): make chunk size configurable
|
|
|
+ const chunkSize = 50 * 1024 * 1024 // 50MB
|
|
|
+ chunks(l.Size, chunkSize)(func(_ int, rng chunkRange[int64]) bool {
|
|
|
+ g.Go(func() (err error) {
|
|
|
+ defer func() {
|
|
|
+ if err == nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ safeURL := redactAmzSignature(l.URL)
|
|
|
+ err = fmt.Errorf("%w: %s %s bytes=%s: %s", err, pr.Name, l.Digest, rng, safeURL)
|
|
|
+ }()
|
|
|
+
|
|
|
+ log.Debug("downloading", "range", rng)
|
|
|
+
|
|
|
+ // TODO(bmizerany): retry
|
|
|
+ // TODO(bmizerany): use real http client
|
|
|
+ // TODO(bmizerany): resumable
|
|
|
+ // TODO(bmizerany): multipart download
|
|
|
+ req, err := http.NewRequestWithContext(ctx, "GET", l.URL, nil)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ req.Header.Set("Range", "bytes="+rng.String())
|
|
|
+
|
|
|
+ res, err := http.DefaultClient.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ defer res.Body.Close()
|
|
|
+ if res.StatusCode/100 != 2 {
|
|
|
+ log.Debug("unexpected non-2XX status code", "status", res.StatusCode)
|
|
|
+ return fmt.Errorf("unexpected status code fetching layer: %d", res.StatusCode)
|
|
|
+ }
|
|
|
+ if res.ContentLength != rng.Size() {
|
|
|
+ return fmt.Errorf("unexpected content length: %d", res.ContentLength)
|
|
|
+ }
|
|
|
+ w := io.NewOffsetWriter(tmpFile, rng.Start)
|
|
|
+ _, err = io.Copy(w, res.Body)
|
|
|
+ return err
|
|
|
+ })
|
|
|
+ return true
|
|
|
+ })
|
|
|
+ if err := g.Wait(); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ tmpFile.Close() // release our hold on the file before moving it
|
|
|
+ return cache.PutLayerFile(l.Digest, tmpFile.Name())
|
|
|
+ }()
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("ollama: pull: %w", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // do not store the presigned URLs in the cache
|
|
|
+ for i := range pr.Manifest.Layers {
|
|
|
+ pr.Manifest.Layers[i].URL = ""
|
|
|
+ }
|
|
|
+ data, err := json.Marshal(pr.Manifest)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ // TODO(bmizerany): remove dep on model.Name
|
|
|
+ return cache.SetManifestData(name, data)
|
|
|
+}
|
|
|
+
|
|
|
+type nopSeeker struct {
|
|
|
+ io.Reader
|
|
|
+}
|
|
|
+
|
|
|
+func (nopSeeker) Seek(int64, int) (int64, error) {
|
|
|
+ return 0, nil
|
|
|
+}
|
|
|
+
|
|
|
+// Push pushes a manifest to the server and responds to the server's
|
|
|
+// requests for layer uploads, if any, and finally commits the manifest for
|
|
|
+// name. It returns an error if any part of the process fails, specifically:
|
|
|
+//
|
|
|
+// If the server requests layers not found in the cache, ErrLayerNotFound is
|
|
|
+// returned.
|
|
|
+func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
|
|
|
+ // TODO(bmizerany): remove dep on model.Name
|
|
|
+ manifest := cache.ManifestData(name)
|
|
|
+ if len(manifest) == 0 {
|
|
|
+ return fmt.Errorf("manifest not found: %s", name)
|
|
|
+ }
|
|
|
+
|
|
|
+ var mu sync.Mutex
|
|
|
+ var completed []*apitype.CompletePart
|
|
|
+ push := func() (*apitype.PushResponse, error) {
|
|
|
+ v, err := ollama.Do[*apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
|
|
|
+ Name: name,
|
|
|
+ Manifest: manifest,
|
|
|
+ CompleteParts: completed,
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("Do: %w", err)
|
|
|
+ }
|
|
|
+ return v, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ pr, err := push()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ var g errgroup.Group
|
|
|
+ for _, need := range pr.Needs {
|
|
|
+ g.Go(func() error {
|
|
|
+ f, err := cache.OpenLayer(need.Digest)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest)
|
|
|
+ }
|
|
|
+ defer f.Close()
|
|
|
+
|
|
|
+ cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("PushLayer: %w: %s", err, need.Digest)
|
|
|
+ }
|
|
|
+ mu.Lock()
|
|
|
+ completed = append(completed, cp)
|
|
|
+ mu.Unlock()
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+ }
|
|
|
+ if err := g.Wait(); err != nil {
|
|
|
+ return fmt.Errorf("Push: Required: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(completed) > 0 {
|
|
|
+ pr, err := push()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if len(pr.Needs) > 0 {
|
|
|
+ var errs []error
|
|
|
+ for _, r := range pr.Needs {
|
|
|
+ errs = append(errs, fmt.Errorf("Push: server failed to find part: %q", r.Digest))
|
|
|
+ }
|
|
|
+ return errors.Join(errs...)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return cache.SetManifestData(name, manifest)
|
|
|
+}
|
|
|
+
|
|
|
+func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) {
|
|
|
+ if start < 0 || end < start {
|
|
|
+ return nil, errors.New("start must satisfy 0 <= start <= end")
|
|
|
+ }
|
|
|
+
|
|
|
+ file := io.NewSectionReader(body, start, end-start+1)
|
|
|
+ req, err := http.NewRequest("PUT", url, file)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ req.ContentLength = end - start + 1
|
|
|
+
|
|
|
+ // TODO(bmizerany): take content type param
|
|
|
+ req.Header.Set("Content-Type", "text/plain")
|
|
|
+
|
|
|
+ if start != 0 || end != 0 {
|
|
|
+ req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", start, end))
|
|
|
+ }
|
|
|
+
|
|
|
+ res, err := http.DefaultClient.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ defer res.Body.Close()
|
|
|
+ if res.StatusCode != 200 {
|
|
|
+ e := parseS3Error(res)
|
|
|
+ return nil, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
|
|
|
+ }
|
|
|
+ cp := &apitype.CompletePart{
|
|
|
+ URL: url,
|
|
|
+ ETag: res.Header.Get("ETag"),
|
|
|
+ // TODO(bmizerany): checksum
|
|
|
+ }
|
|
|
+ return cp, nil
|
|
|
+}
|
|
|
+
|
|
|
+type s3Error struct {
|
|
|
+ XMLName xml.Name `xml:"Error"`
|
|
|
+ Code string `xml:"Code"`
|
|
|
+ Message string `xml:"Message"`
|
|
|
+ Resource string `xml:"Resource"`
|
|
|
+ RequestId string `xml:"RequestId"`
|
|
|
+}
|
|
|
+
|
|
|
+func (e *s3Error) Error() string {
|
|
|
+ return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message)
|
|
|
+}
|
|
|
+
|
|
|
+// parseS3Error parses an XML error response from S3.
|
|
|
+func parseS3Error(res *http.Response) error {
|
|
|
+ var se *s3Error
|
|
|
+ if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return se
|
|
|
+}
|
|
|
+
|
|
|
+// TODO: replace below by using upload pkg after we have rangefunc; until
|
|
|
+// then, we need to keep this free of rangefunc for now.
|
|
|
+type chunkRange[I constraints.Integer] struct {
|
|
|
+ // Start is the byte offset of the chunk.
|
|
|
+ Start I
|
|
|
+
|
|
|
+ // End is the byte offset of the last byte in the chunk.
|
|
|
+ End I
|
|
|
+}
|
|
|
+
|
|
|
+func (c chunkRange[I]) Size() I {
|
|
|
+ return c.End - c.Start + 1
|
|
|
+}
|
|
|
+
|
|
|
+func (c chunkRange[I]) String() string {
|
|
|
+ return fmt.Sprintf("%d-%d", c.Start, c.End)
|
|
|
+}
|
|
|
+
|
|
|
+func (c chunkRange[I]) LogValue() slog.Value {
|
|
|
+ return slog.StringValue(c.String())
|
|
|
+}
|
|
|
+
|
|
|
+// Chunks yields a sequence of a part number and a Chunk. The Chunk is the offset
|
|
|
+// and size of the chunk. The last chunk may be smaller than chunkSize if size is
|
|
|
+// not a multiple of chunkSize.
|
|
|
+//
|
|
|
+// The first part number is 1 and increases monotonically.
|
|
|
+func chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, chunkRange[I]] {
|
|
|
+ return func(yield func(int, chunkRange[I]) bool) {
|
|
|
+ var n int
|
|
|
+ for off := I(0); off < size; off += chunkSize {
|
|
|
+ n++
|
|
|
+ if !yield(n, chunkRange[I]{
|
|
|
+ Start: off,
|
|
|
+ End: off + min(chunkSize, size-off) - 1,
|
|
|
+ }) {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func redactAmzSignature(s string) string {
|
|
|
+ u, err := url.Parse(s)
|
|
|
+ if err != nil {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+ q := u.Query()
|
|
|
+ q.Set("X-Amz-Signature", "REDACTED")
|
|
|
+ u.RawQuery = q.Encode()
|
|
|
+ return u.String()
|
|
|
+}
|