Blake Mizerany 1 year ago
parent
commit
fdef9a0eb2
2 changed files with 479 additions and 0 deletions
  1. 95 0
      client/registry/apitype/apitype.go
  2. 384 0
      client/registry/registry.go

+ 95 - 0
client/registry/apitype/apitype.go

@@ -0,0 +1,95 @@
+package apitype
+
+import (
+	"cmp"
+	"encoding/json"
+	"log/slog"
+	"net/url"
+	"slices"
+	"strconv"
+)
+
+// Manifest is the JSON form of a model manifest: the list of layers that
+// make up the model.
+type Manifest struct {
+	Layers []*Layer `json:"layers"`
+}
+
+// CompletePart describes one uploaded part: the URL the part was sent to
+// and the ETag recorded for it.
+type CompletePart struct {
+	URL  string `json:"url"` // contains partNumber and uploadId from server
+	ETag string `json:"etag"`
+}
+
+// queryFromString parses s as a URL and returns its query parameters. It
+// returns nil if s cannot be parsed; Get on a nil url.Values returns "".
+func queryFromString(s string) url.Values {
+	u, err := url.Parse(s)
+	if err != nil {
+		return nil
+	}
+	return u.Query()
+}
+
+// comparePartNumbers orders two partNumber query values. Values that parse
+// as integers are ordered numerically, so that "2" sorts before "10".
+// Values that do not parse (including the empty string) sort before all
+// numeric values and are ordered lexically among themselves, which keeps
+// the ordering consistent for sorting.
+func comparePartNumbers(a, b string) int {
+	na, errA := strconv.Atoi(a)
+	nb, errB := strconv.Atoi(b)
+	switch {
+	case errA == nil && errB == nil:
+		return cmp.Compare(na, nb)
+	case errA != nil && errB != nil:
+		return cmp.Compare(a, b)
+	case errA != nil:
+		return -1
+	default:
+		return 1
+	}
+}
+
+// Compare orders cp relative to o. Parts are ordered by the partNumber
+// query parameter of their URL (numerically), then by the uploadId query
+// parameter, then by ETag. It returns a negative number, zero, or a
+// positive number when cp sorts before, equal to, or after o.
+func (cp *CompletePart) Compare(o *CompletePart) int {
+	qa := queryFromString(cp.URL)
+	qb := queryFromString(o.URL)
+	return cmp.Or(
+		comparePartNumbers(qa.Get("partNumber"), qb.Get("partNumber")),
+		cmp.Compare(qa.Get("uploadId"), qb.Get("uploadId")),
+		cmp.Compare(cp.ETag, o.ETag),
+	)
+}
+
+// SortCompleteParts sorts a in place using (*CompletePart).Compare.
+func SortCompleteParts(a []*CompletePart) {
+	slices.SortFunc(a, (*CompletePart).Compare)
+}
+
+// Layer describes a single layer listed in a manifest.
+type Layer struct {
+	Digest    string `json:"digest"`
+	MediaType string `json:"mediaType"`
+	Size      int64  `json:"size"`
+
+	// URL, when set, is a remote location the layer can be fetched from.
+	URL string `json:"url,omitempty"`
+}
+
+// LogValue implements slog.LogValuer. The layer is rendered as a group
+// holding its digest, mediaType, size, and url attributes, in that order.
+func (l *Layer) LogValue() slog.Value {
+	attrs := []slog.Attr{
+		slog.String("digest", l.Digest),
+		slog.String("mediaType", l.MediaType),
+		slog.Int64("size", l.Size),
+		slog.String("url", l.URL),
+	}
+	return slog.GroupValue(attrs...)
+}
+
+// PushRequest is the JSON payload describing a push: the name being pushed
+// (encoded as "ref"), its manifest, and any parts already uploaded.
+type PushRequest struct {
+	Name     string          `json:"ref"`
+	Manifest json.RawMessage `json:"manifest,omitempty"`
+
+	// CompleteParts is a list of upload parts that the client uploaded in
+	// the previous push.
+	CompleteParts []*CompletePart `json:"part_uploads"`
+}
+
+// Need describes a byte range of a layer, identified by Digest, that is
+// listed in a PushResponse as something the client must upload.
+type Need struct {
+	Digest string `json:"digest"`
+
+	// Start and End bound the requested byte range of the layer.
+	// NOTE(review): End appears to be inclusive, judging by how the
+	// registry client sizes uploads (end-start+1) — confirm with server.
+	Start int64 `json:"start"`
+	End   int64 `json:"end"`
+
+	// URL is the url to PUT the layer to.
+	//
+	// Clients must include it as the URL, along with the ETag in the
+	// response headers from the PUT request, in a CompletePart in the
+	// CompleteParts field of the next push request.
+	URL string `json:"url"`
+}
+
+// PushResponse is the JSON response to a push request.
+type PushResponse struct {
+	// Needs is a list of digests that the client needs to push before
+	// repushing the manifest. It is encoded as "requirements" and omitted
+	// when empty.
+	Needs []*Need `json:"requirements,omitempty"`
+}
+
+// PullResponse is the JSON response to a pull request.
+type PullResponse struct {
+	// Name is the name of the model being pulled.
+	Name string `json:"name"`
+
+	// Manifest is the manifest of the model being pulled.
+	Manifest *Manifest `json:"manifest"`
+}

+ 384 - 0
client/registry/registry.go

@@ -0,0 +1,384 @@
+package registry
+
+import (
+	"cmp"
+	"context"
+	"encoding/json"
+	"encoding/xml"
+	"errors"
+	"fmt"
+	"io"
+	"iter"
+	"log/slog"
+	"net/http"
+	"net/url"
+	"os"
+	"sync"
+
+	"github.com/ollama/ollama/client/ollama"
+	"github.com/ollama/ollama/client/registry/apitype"
+	"golang.org/x/exp/constraints"
+	"golang.org/x/sync/errgroup"
+)
+
+// Errors
+var (
+	// ErrLayerNotFound is the sentinel error for a layer that is missing.
+	// NOTE(review): nothing visible in this file returns it directly;
+	// presumably Cache implementations return it from OpenLayer — confirm.
+	ErrLayerNotFound = errors.New("layer not found")
+)
+
+// Client talks to a registry server to pull and push models.
+type Client struct {
+	// BaseURL is passed through to the underlying ollama client.
+	BaseURL string
+
+	// TODO(bmizerany): remove NameFill (once we remove model dep here)
+	NameFill string
+
+	// Logger receives the client's log output. When nil, slog.Default()
+	// is used instead.
+	Logger *slog.Logger
+}
+
+// logger returns c.Logger, falling back to the default slog logger when
+// none has been set.
+func (c *Client) logger() *slog.Logger {
+	fallback := slog.Default()
+	chosen := cmp.Or(c.Logger, fallback)
+	return chosen
+}
+
+// oclient builds a fresh ollama.Client that targets c.BaseURL. A new value
+// is returned on every call.
+func (c *Client) oclient() *ollama.Client {
+	oc := new(ollama.Client)
+	oc.BaseURL = c.BaseURL
+	return oc
+}
+
+// ReadAtSeekCloser groups io.ReaderAt, io.Seeker, and io.Closer. It is the
+// type Cache.OpenLayer returns for an opened layer.
+type ReadAtSeekCloser interface {
+	io.ReaderAt
+	io.Seeker
+	io.Closer
+}
+
+// Cache is the local store of layers and manifests that Pull and Push read
+// from and write to.
+type Cache interface {
+	// LayerFile returns the absolute file path to the layer file for
+	// the given model digest.
+	//
+	// If the digest is invalid, or the layer does not exist, the empty
+	// string is returned.
+	LayerFile(digest string) string
+
+	// OpenLayer opens the layer file for the given model digest and
+	// returns it, or an error if any. The caller is responsible for
+	// closing the returned file.
+	OpenLayer(digest string) (ReadAtSeekCloser, error)
+
+	// PutLayerFile moves the layer file at fromPath to the cache for
+	// the given model digest. It is a hack intended to short circuit a
+	// file copy operation.
+	//
+	// TODO(bmizerany): remove this; find a better way. Once we move
+	// this into a build package, we should be able to get rid of this.
+	PutLayerFile(digest, fromPath string) error
+
+	// SetManifestData sets the provided manifest data for the given
+	// model name. If the manifest data is empty, the manifest is
+	// removed. If the manifest exists, it is overwritten.
+	//
+	// It is an error to call SetManifestData with a name that is not
+	// complete.
+	SetManifestData(name string, data []byte) error
+
+	// ManifestData returns the manifest data for the given model name.
+	//
+	// If the name is incomplete, or the manifest does not exist, no
+	// data (a zero-length slice) is returned.
+	ManifestData(name string) []byte
+}
+
+// Pull pulls the manifest for name, and downloads any of its required
+// layers that are not already in the cache. Each missing layer is fetched
+// in ranged chunks into a temporary file, which is then handed to the
+// cache. Layer URLs are cleared before the manifest is stored.
+//
+// It returns an error if any part of the process fails, specifically if:
+//   - the pull request to the server fails;
+//   - the returned manifest is missing or lists no layers;
+//   - a chunk download fails, returns a non-2XX status, or reports a
+//     content length that differs from the requested range;
+//   - the temporary file cannot be created or closed;
+//   - the layer or the manifest cannot be stored in the cache.
+func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
+	log := c.logger().With("name", name)
+
+	pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
+	if err != nil {
+		return fmt.Errorf("ollama: pull: %w: %s", err, name)
+	}
+
+	if pr.Manifest == nil || len(pr.Manifest.Layers) == 0 {
+		return fmt.Errorf("ollama: pull: invalid manifest: %s: no layers found", name)
+	}
+
+	// download required layers we do not already have
+	for _, l := range pr.Manifest.Layers {
+		if cache.LayerFile(l.Digest) != "" {
+			continue
+		}
+		// The per-layer work runs in a closure so its deferred cleanup
+		// fires at the end of each layer rather than when Pull returns.
+		err := func() error {
+			log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size)
+			log.Debug("starting download")
+
+			// TODO(bmizerany): stop using temp which might not
+			// be on same device as cache.... instead let cache
+			// give us a place to store parts...
+			tmpFile, err := os.CreateTemp("", "ollama-download-")
+			if err != nil {
+				return err
+			}
+			defer func() {
+				tmpFile.Close()
+				os.Remove(tmpFile.Name()) // in case we fail before committing
+			}()
+
+			g, ctx := errgroup.WithContext(ctx)
+			g.SetLimit(8) // TODO(bmizerany): make this configurable
+
+			// TODO(bmizerany): make chunk size configurable
+			const chunkSize = 50 * 1024 * 1024 // 50MB
+			chunks(l.Size, chunkSize)(func(_ int, rng chunkRange[int64]) bool {
+				g.Go(func() (err error) {
+					// Annotate any failure with the layer and range,
+					// using a URL whose signature has been redacted.
+					defer func() {
+						if err == nil {
+							return
+						}
+						safeURL := redactAmzSignature(l.URL)
+						err = fmt.Errorf("%w: %s %s bytes=%s: %s", err, pr.Name, l.Digest, rng, safeURL)
+					}()
+
+					log.Debug("downloading", "range", rng)
+
+					// TODO(bmizerany): retry
+					// TODO(bmizerany): use real http client
+					// TODO(bmizerany): resumable
+					// TODO(bmizerany): multipart download
+					req, err := http.NewRequestWithContext(ctx, "GET", l.URL, nil)
+					if err != nil {
+						return err
+					}
+					req.Header.Set("Range", "bytes="+rng.String())
+
+					res, err := http.DefaultClient.Do(req)
+					if err != nil {
+						return err
+					}
+					defer res.Body.Close()
+					if res.StatusCode/100 != 2 {
+						log.Debug("unexpected non-2XX status code", "status", res.StatusCode)
+						return fmt.Errorf("unexpected status code fetching layer: %d", res.StatusCode)
+					}
+					if res.ContentLength != rng.Size() {
+						return fmt.Errorf("unexpected content length: %d", res.ContentLength)
+					}
+					// Each chunk writes at its own offset, so chunks can
+					// land in the shared file concurrently.
+					w := io.NewOffsetWriter(tmpFile, rng.Start)
+					_, err = io.Copy(w, res.Body)
+					return err
+				})
+				return true
+			})
+			if err := g.Wait(); err != nil {
+				return err
+			}
+
+			// Release our hold on the file before moving it. A failed
+			// close is reported instead of committing the file to the
+			// cache; the deferred Close then becomes a harmless no-op.
+			if err := tmpFile.Close(); err != nil {
+				return err
+			}
+			return cache.PutLayerFile(l.Digest, tmpFile.Name())
+		}()
+		if err != nil {
+			return fmt.Errorf("ollama: pull: %w", err)
+		}
+	}
+
+	// do not store the presigned URLs in the cache
+	for i := range pr.Manifest.Layers {
+		pr.Manifest.Layers[i].URL = ""
+	}
+	data, err := json.Marshal(pr.Manifest)
+	if err != nil {
+		return err
+	}
+
+	// TODO(bmizerany): remove dep on model.Name
+	return cache.SetManifestData(name, data)
+}
+
+// nopSeeker wraps an io.Reader and adds a Seek method that does nothing.
+type nopSeeker struct {
+	io.Reader
+}
+
+// Seek ignores both arguments and always reports offset 0 with a nil error.
+func (nopSeeker) Seek(_ int64, _ int) (int64, error) {
+	const pos int64 = 0
+	return pos, nil
+}
+
+// Push pushes a manifest to the server and responds to the server's
+// requests for layer uploads, if any, and finally commits the manifest for
+// name. It returns an error if any part of the process fails, specifically:
+//
+// If the server requests layers not found in the cache, ErrLayerNotFound is
+// returned.
+//
+// NOTE(review): Push does not create ErrLayerNotFound itself; it wraps
+// whatever error cache.OpenLayer returns, so this depends on the Cache
+// implementation — confirm.
+//
+// It also fails when no manifest is cached for name, when a push request or
+// a layer upload fails, or when the server still lists needs after the
+// uploaded parts have been reported.
+func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
+	// TODO(bmizerany): remove dep on model.Name
+	manifest := cache.ManifestData(name)
+	if len(manifest) == 0 {
+		return fmt.Errorf("manifest not found: %s", name)
+	}
+
+	// mu guards completed, which the upload goroutines below append to.
+	var mu sync.Mutex
+	var completed []*apitype.CompletePart
+	// push sends the manifest together with whatever parts have been
+	// completed at the time of the call.
+	push := func() (*apitype.PushResponse, error) {
+		v, err := ollama.Do[*apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
+			Name:          name,
+			Manifest:      manifest,
+			CompleteParts: completed,
+		})
+		if err != nil {
+			return nil, fmt.Errorf("Do: %w", err)
+		}
+		return v, nil
+	}
+
+	// First push: the response lists the ranges the server needs.
+	pr, err := push()
+	if err != nil {
+		return err
+	}
+
+	// Upload every needed range concurrently, one goroutine per need.
+	var g errgroup.Group
+	for _, need := range pr.Needs {
+		g.Go(func() error {
+			f, err := cache.OpenLayer(need.Digest)
+			if err != nil {
+				return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest)
+			}
+			defer f.Close()
+
+			cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End)
+			if err != nil {
+				return fmt.Errorf("PushLayer: %w: %s", err, need.Digest)
+			}
+			mu.Lock()
+			completed = append(completed, cp)
+			mu.Unlock()
+			return nil
+		})
+	}
+	if err := g.Wait(); err != nil {
+		return fmt.Errorf("Push: Required: %w", err)
+	}
+
+	// Second push, only when something was uploaded: report the completed
+	// parts. Any need still listed afterwards is treated as a failure.
+	if len(completed) > 0 {
+		pr, err := push()
+		if err != nil {
+			return err
+		}
+		if len(pr.Needs) > 0 {
+			var errs []error
+			for _, r := range pr.Needs {
+				errs = append(errs, fmt.Errorf("Push: server failed to find part: %q", r.Digest))
+			}
+			return errors.Join(errs...)
+		}
+	}
+
+	return cache.SetManifestData(name, manifest)
+}
+
+// PushLayer uploads the bytes of body in the inclusive range [start, end]
+// to url with a PUT request, and returns a CompletePart holding url and the
+// ETag response header.
+//
+// The request is bound to ctx, so cancelling ctx or exceeding its deadline
+// aborts the upload. ctx must not be nil.
+//
+// It returns an error if start is negative or end is less than start, if
+// the request cannot be built or sent, or if the response status is not
+// 200; in the last case the error wraps the result of parsing the response
+// body as an S3 error.
+func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) {
+	if start < 0 || end < start {
+		return nil, errors.New("start must satisfy 0 <= start <= end")
+	}
+
+	// end is inclusive, hence the +1.
+	size := end - start + 1
+	file := io.NewSectionReader(body, start, size)
+	req, err := http.NewRequestWithContext(ctx, "PUT", url, file)
+	if err != nil {
+		return nil, err
+	}
+	// net/http cannot infer the length of a SectionReader body, so set it
+	// explicitly.
+	req.ContentLength = size
+
+	// TODO(bmizerany): take content type param
+	req.Header.Set("Content-Type", "text/plain")
+
+	if start != 0 || end != 0 {
+		req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", start, end))
+	}
+
+	res, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+	if res.StatusCode != 200 {
+		e := parseS3Error(res)
+		return nil, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
+	}
+	cp := &apitype.CompletePart{
+		URL:  url,
+		ETag: res.Header.Get("ETag"),
+		// TODO(bmizerany): checksum
+	}
+	return cp, nil
+}
+
+// s3Error is the decoded form of an XML <Error> document, as parsed by
+// parseS3Error.
+type s3Error struct {
+	XMLName   xml.Name `xml:"Error"`
+	Code      string   `xml:"Code"`
+	Message   string   `xml:"Message"`
+	Resource  string   `xml:"Resource"`
+	RequestId string   `xml:"RequestId"`
+}
+
+// Error implements the error interface, formatting the request id,
+// resource, code, and message into a single line.
+func (e *s3Error) Error() string {
+	const layout = "S3 (%s): %s: %s: %s"
+	return fmt.Sprintf(layout, e.RequestId, e.Resource, e.Code, e.Message)
+}
+
+// parseS3Error decodes the body of res as an XML error document. On
+// success the decoded *s3Error is returned as the error; if decoding fails,
+// the decoder's error is returned instead. The result is never nil.
+func parseS3Error(res *http.Response) error {
+	decoded := new(s3Error)
+	dec := xml.NewDecoder(res.Body)
+	if err := dec.Decode(decoded); err != nil {
+		return err
+	}
+	return decoded
+}
+
+// TODO: replace below by using upload pkg after we have rangefunc; until
+// then, we need to keep this free of rangefunc for now.
+
+// chunkRange is an inclusive byte range [Start, End].
+type chunkRange[I constraints.Integer] struct {
+	// Start is the byte offset of the chunk.
+	Start I
+
+	// End is the byte offset of the last byte in the chunk.
+	End I
+}
+
+// Size reports how many bytes the range covers; both ends are counted.
+func (r chunkRange[I]) Size() I {
+	span := r.End - r.Start
+	return span + 1
+}
+
+// String formats the range as "Start-End", the form used after "bytes=" in
+// a Range header.
+func (r chunkRange[I]) String() string {
+	const layout = "%d-%d"
+	return fmt.Sprintf(layout, r.Start, r.End)
+}
+
+// LogValue implements slog.LogValuer using the String form of the range.
+func (r chunkRange[I]) LogValue() slog.Value {
+	text := r.String()
+	return slog.StringValue(text)
+}
+
+// chunks yields a sequence of part numbers and chunkRanges that together
+// cover the bytes [0, size). Each range holds the offsets of the first and
+// last byte of the chunk. The last chunk may be smaller than chunkSize if
+// size is not a multiple of chunkSize.
+//
+// The first part number is 1 and increases monotonically.
+//
+// Nothing is yielded when size or chunkSize is not positive; a
+// non-positive chunkSize would otherwise never make progress.
+func chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, chunkRange[I]] {
+	return func(yield func(int, chunkRange[I]) bool) {
+		if chunkSize <= 0 {
+			return
+		}
+		var n int
+		for off := I(0); off < size; {
+			// Advance by the length actually emitted, so off never
+			// exceeds size and cannot overflow I on the final chunk.
+			length := min(chunkSize, size-off)
+			n++
+			if !yield(n, chunkRange[I]{
+				Start: off,
+				End:   off + length - 1,
+			}) {
+				return
+			}
+			off += length
+		}
+	}
+}
+
+// redactAmzSignature returns s with its X-Amz-Signature query parameter set
+// to "REDACTED", so the URL can appear in errors and logs. The parameter is
+// set even when s did not carry one, and the query string is re-encoded. An
+// unparsable s yields the empty string.
+func redactAmzSignature(s string) string {
+	parsed, err := url.Parse(s)
+	if err != nil {
+		return ""
+	}
+	params := parsed.Query()
+	params["X-Amz-Signature"] = []string{"REDACTED"}
+	parsed.RawQuery = params.Encode()
+	return parsed.String()
+}