@@ -59,6 +59,11 @@ var (
 	// ErrCached is passed to [Trace.PushUpdate] when a layer already
 	// exists. It is a non-fatal error and is never returned by [Registry.Push].
 	ErrCached = errors.New("cached")
+
+	// ErrIncomplete is returned by [Registry.Pull] when a model pull is
+	// incomplete due to one or more failed layer downloads. Callers that
+	// want the specific underlying errors should use [WithTrace].
+	ErrIncomplete = errors.New("incomplete")
 )
 
 // Defaults
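
For callers, ErrIncomplete is only a summary signal; per-layer detail flows through the trace. A minimal sketch of the intended usage (not part of the patch), assuming the package is imported as ollama, that WithTrace attaches a Trace to the context, and that Trace exposes an Update callback matching the update signature used later in this diff; the registry value and model name are illustrative:

	// Hypothetical caller: register a Trace to observe per-layer
	// progress and errors, then treat ErrIncomplete as a summary of
	// one or more failed layer downloads.
	ctx = ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			if errors.Is(err, ollama.ErrCached) {
				return // layer was already present locally
			}
			if err != nil {
				log.Printf("layer %s: %v", l.Digest.Short(), err)
			}
		},
	})
	if err := reg.Pull(ctx, "library/smollm"); errors.Is(err, ollama.ErrIncomplete) {
		// Some bytes never arrived; the specific errors were
		// already reported through Trace.Update above.
	}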

@@ -271,8 +276,19 @@ func DefaultRegistry() (*Registry, error) {
 
 func UserAgent() string {
 	buildinfo, _ := debug.ReadBuildInfo()
+
+	version := buildinfo.Main.Version
+	if version == "(devel)" {
+		// When using `go run .` the version is "(devel)"; ollama.com
+		// treats that as an invalid version and responds with "needs
+		// upgrade" to some requests, such as pulls. Those checks can
+		// be skipped by sending the special version "v0.0.0", so set
+		// it to that here.
+		version = "v0.0.0"
+	}
+
 	return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
-		buildinfo.Main.Version,
+		version,
 		runtime.GOARCH,
 		runtime.GOOS,
 		runtime.Version(),
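
To make the effect concrete, here is a self-contained sketch of the same normalization (not part of the patch; the guard against a missing build-info record and the sample output are illustrative, and actual values depend on the host and toolchain):

	package main

	import (
		"fmt"
		"runtime"
		"runtime/debug"
	)

	func main() {
		version := "v0.0.0" // fallback when no usable build info is present
		if bi, ok := debug.ReadBuildInfo(); ok && bi.Main.Version != "(devel)" {
			version = bi.Main.Version
		}
		// Same shape as the header built above; under `go run .` this
		// prints something like "ollama/v0.0.0 (arm64 darwin) Go/go1.24.1".
		fmt.Printf("ollama/%s (%s %s) Go/%s\n",
			version, runtime.GOARCH, runtime.GOOS, runtime.Version())
	}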

@@ -418,13 +434,14 @@ func canRetry(err error) bool {
 //
 // It always calls update with a nil error.
 type trackingReader struct {
-	r io.Reader
-	n *atomic.Int64
+	l      *Layer
+	r      io.Reader
+	update func(l *Layer, n int64, err error)
 }
 
 func (r *trackingReader) Read(p []byte) (n int, err error) {
 	n, err = r.r.Read(p)
-	r.n.Add(int64(n))
+	r.update(r.l, int64(n), nil)
 	return
 }
 
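Note the semantic shift here: trackingReader no longer accumulates into a shared counter; each Read forwards however many bytes it just moved to the callback, so consumers receive increments rather than running totals. A standalone sketch of that shape (illustration only; the Layer argument is dropped and the data is fake):

	package main

	import (
		"fmt"
		"io"
		"strings"
	)

	// tracking mirrors the patched trackingReader: every Read forwards
	// its byte count to an update callback with a nil error.
	type tracking struct {
		r      io.Reader
		update func(n int64, err error)
	}

	func (t *tracking) Read(p []byte) (n int, err error) {
		n, err = t.r.Read(p)
		t.update(int64(n), nil)
		return
	}

	func main() {
		var total int64
		tr := &tracking{
			r: strings.NewReader("some layer bytes"),
			update: func(n int64, err error) {
				total += n
				fmt.Printf("read %d (total %d)\n", n, total)
			},
		}
		io.Copy(io.Discard, tr) // drives Read until EOF
	}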

@@ -462,16 +479,20 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 
 	// Send initial layer trace events to allow clients to have an
 	// understanding of work to be done before work starts.
+	var expected int64
 	t := traceFromContext(ctx)
 	for _, l := range layers {
 		t.update(l, 0, nil)
+		expected += l.Size
 	}
 
+	var total atomic.Int64
 	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
 	for _, l := range layers {
 		info, err := c.Get(l.Digest)
 		if err == nil && info.Size == l.Size {
+			total.Add(l.Size)
 			t.update(l, l.Size, ErrCached)
 			continue
 		}
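
With expected computed up front and byte counts now arriving as increments, a trace client can track overall completion by summing what it sees. A sketch of such an Update callback (an assumption layered on this diff, not code from it: it treats zero-byte nil-error events as the initial per-layer announcements Pull sends above, and assumes updates may arrive concurrently):

	// expected grows on the announcement events; received grows on
	// increments and on ErrCached, which reports the full layer size.
	var expected, received atomic.Int64
	update := func(l *ollama.Layer, n int64, err error) {
		if err == nil && n == 0 {
			expected.Add(l.Size)
			return
		}
		if err == nil || errors.Is(err, ollama.ErrCached) {
			received.Add(n)
		}
		fmt.Printf("\r%d/%d bytes", received.Load(), expected.Load())
	}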

@@ -484,21 +505,25 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		// TODO(bmizerany): fix this unbounded use of defer
 		defer chunked.Close()
 
-		var progress atomic.Int64
 		for cs, err := range r.chunksums(ctx, name, l) {
 			if err != nil {
-				// Bad chunksums response, update tracing
-				// clients and then bail.
-				t.update(l, progress.Load(), err)
-				return err
+				// The chunksum stream was interrupted, so
+				// report it to the trace and let in-flight
+				// chunk downloads finish. Pull then returns
+				// ErrIncomplete, because the total bytes
+				// received will come up short of the
+				// expected total computed above.
+				t.update(l, 0, err)
+				break
 			}
 
 			g.Go(func() (err error) {
 				defer func() {
-					if err != nil {
+					if err == nil || errors.Is(err, ErrCached) {
+						total.Add(cs.Chunk.Size())
+					} else {
 						err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
 					}
-					t.update(l, progress.Load(), err)
 				}()
 
 				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)

@@ -522,7 +547,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 				// download rate since it knows better than a
 				// client that is measuring rate based on
 				// wall-clock time-since-last-update.
-				body := &trackingReader{r: res.Body, n: &progress}
+				body := &trackingReader{l: l, r: res.Body, update: t.update}
 
 				return chunked.Put(cs.Chunk, cs.Digest, body)
 			})
@@ -531,6 +556,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	if err := g.Wait(); err != nil {
 		return err
 	}
+	if total.Load() != expected {
+		return fmt.Errorf("%w: received %d/%d", ErrIncomplete, total.Load(), expected)
+	}
 
 	md := blob.DigestFromBytes(m.Data)
 	if err := blob.PutBytes(c, md, m.Data); err != nil {
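
The two hunks above close the accounting loop: each successful (or cached) chunk adds its size to total, and once the group drains, Pull compares the sum against expected. A self-contained sketch of that control flow under a simulated stream interruption (not the patch's code; the chunk sizes, the stand-in for a download, and the errgroup dependency are illustrative):

	package main

	import (
		"errors"
		"fmt"
		"sync/atomic"

		"golang.org/x/sync/errgroup"
	)

	var errIncomplete = errors.New("incomplete")

	func main() {
		sizes := []int64{100, 200, 300}
		var expected int64
		for _, s := range sizes {
			expected += s
		}

		var total atomic.Int64
		var g errgroup.Group
		g.SetLimit(2)
		for i, size := range sizes {
			if i == 1 {
				break // simulate the chunksum stream dying mid-way
			}
			g.Go(func() error {
				total.Add(size) // pretend this chunk downloaded fine
				return nil
			})
		}
		if err := g.Wait(); err != nil {
			panic(err)
		}
		if got := total.Load(); got != expected {
			fmt.Println(fmt.Errorf("%w: received %d/%d", errIncomplete, got, expected))
			// Output: incomplete: received 100/600
		}
	}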

@@ -757,15 +785,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
 		}
 		blobURL := res.Header.Get("Content-Location")
 
-		var size int64
 		s := bufio.NewScanner(res.Body)
 		s.Split(bufio.ScanWords)
 		for {
 			if !s.Scan() {
 				if s.Err() != nil {
 					yield(chunksum{}, s.Err())
-				} else if size != l.Size {
-					yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", size, l.Size))
 				}
 				return
 			}
@@ -789,12 +814,6 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
 				return
 			}
 
-			size += chunk.Size()
-			if size > l.Size {
-				yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
-				return
-			}
-
 			cs := chunksum{
				URL:   blobURL,
 				Chunk: chunk,