|
@@ -0,0 +1,475 @@
|
|
|
+//go:build goexperiment.synctest
|
|
|
+
|
|
|
+package ollama
|
|
|
+
|
|
|
+import (
|
|
|
+ "bufio"
|
|
|
+ "cmp"
|
|
|
+ "context"
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "io/fs"
|
|
|
+ "net"
|
|
|
+ "net/http"
|
|
|
+ "os"
|
|
|
+ "strings"
|
|
|
+ "sync/atomic"
|
|
|
+ "testing"
|
|
|
+ "testing/synctest"
|
|
|
+
|
|
|
+ "github.com/ollama/ollama/server/internal/cache/blob"
|
|
|
+)
|
|
|
+
|
|
|
+func newHTTPClient(cn net.Conn) *http.Client {
|
|
|
+ return &http.Client{
|
|
|
+ Transport: &http.Transport{
|
|
|
+ DialContext: func(context.Context, string, string) (net.Conn, error) {
|
|
|
+ return cn, nil
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+type clientTester struct {
|
|
|
+ t *testing.T
|
|
|
+ rc *Registry
|
|
|
+ sc net.Conn
|
|
|
+ br *bufio.Reader
|
|
|
+ inProgress atomic.Int64
|
|
|
+}
|
|
|
+
|
|
|
+// newClientTester creates a clientTester with a new pipe connection. If the
|
|
|
+// provided cache is nil, a new cache is created with t.TempDir().
|
|
|
+func newClientTester(t *testing.T, c *blob.DiskCache) *clientTester {
|
|
|
+ t.Helper()
|
|
|
+
|
|
|
+ cc, sc := net.Pipe()
|
|
|
+
|
|
|
+ if c == nil {
|
|
|
+ var err error
|
|
|
+ c, err = blob.Open(t.TempDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return &clientTester{
|
|
|
+ t: t,
|
|
|
+ rc: &Registry{
|
|
|
+ Cache: c,
|
|
|
+ ChunkingThreshold: 2, // set low for ease of testing
|
|
|
+ HTTPClient: newHTTPClient(cc),
|
|
|
+ },
|
|
|
+ sc: sc,
|
|
|
+ br: bufio.NewReader(sc),
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (ct *clientTester) setMaxStreams(n int) {
|
|
|
+ ct.rc.MaxStreams = n
|
|
|
+}
|
|
|
+
|
|
|
+func (ct *clientTester) close() {
|
|
|
+ if err := ct.sc.Close(); err != nil {
|
|
|
+ ct.t.Fatal("error closing conn:", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (ct *clientTester) running() bool {
|
|
|
+ return ct.inProgress.Load() > 0
|
|
|
+}
|
|
|
+
|
|
|
+// pull starts a pull.
|
|
|
+// It tracks the number of in-progress pulls for use with [running].
|
|
|
+// It prefixes the name with "http://example.com/".
|
|
|
+// It does not wait for the pull to complete.
|
|
|
+func (ct *clientTester) pull(ctx context.Context, name string) error {
|
|
|
+ ct.inProgress.Add(1)
|
|
|
+ defer ct.inProgress.Add(-1)
|
|
|
+ return ct.rc.Pull(ctx, fmt.Sprintf("http://example.com/%s", name))
|
|
|
+}
|
|
|
+
|
|
|
+// await reads the next request from the clientTester's bufio.Reader and returns
|
|
|
+// it.
|
|
|
+// If wantPath is not empty, it checks that the request's URL.Path matches
|
|
|
+// wantPath.
|
|
|
+func (ct *clientTester) await(wantPath string) *http.Request {
|
|
|
+ ct.t.Helper()
|
|
|
+ req, err := http.ReadRequest(ct.br)
|
|
|
+ if err != nil {
|
|
|
+ ct.t.Fatal("error reading request:", err)
|
|
|
+ }
|
|
|
+ if wantPath != "" && req.URL.Path != wantPath {
|
|
|
+ ct.t.Fatalf("request = %v; want %v", req.URL.Path, wantPath)
|
|
|
+ }
|
|
|
+ return req
|
|
|
+}
|
|
|
+
|
|
|
+func (ct *clientTester) respond(code int, body string) {
|
|
|
+ ct.t.Helper()
|
|
|
+ err := (&http.Response{
|
|
|
+ ProtoMajor: 1,
|
|
|
+ ProtoMinor: 1,
|
|
|
+
|
|
|
+ StatusCode: code,
|
|
|
+
|
|
|
+ ContentLength: int64(len(body)),
|
|
|
+ Body: io.NopCloser(strings.NewReader(body)),
|
|
|
+ }).Write(ct.sc)
|
|
|
+ if err != nil {
|
|
|
+ ct.t.Fatal("error writing response:", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+type stringReadCloser struct {
|
|
|
+ length int
|
|
|
+ io.Reader
|
|
|
+}
|
|
|
+
|
|
|
+func (r *stringReadCloser) Close() error { return nil }
|
|
|
+
|
|
|
+func stringBody(format string, args ...any) io.ReadCloser {
|
|
|
+ s := fmt.Sprintf(format, args...)
|
|
|
+ return &stringReadCloser{len(s), strings.NewReader(s)}
|
|
|
+}
|
|
|
+
|
|
|
+func (ct *clientTester) respondWith(res *http.Response) {
|
|
|
+ ct.t.Helper()
|
|
|
+ if b, ok := res.Body.(*stringReadCloser); ok {
|
|
|
+ res.ContentLength = int64(b.length)
|
|
|
+ }
|
|
|
+ if res.Body != nil && res.Body != http.NoBody && res.ContentLength == 0 {
|
|
|
+ panic("response with Body must have ContentLength")
|
|
|
+ }
|
|
|
+ res.ProtoMajor = cmp.Or(res.ProtoMajor, 1)
|
|
|
+ res.ProtoMinor = cmp.Or(res.ProtoMinor, 1)
|
|
|
+ res.StatusCode = cmp.Or(res.StatusCode, 200)
|
|
|
+ err := res.Write(ct.sc)
|
|
|
+ if err != nil {
|
|
|
+ ct.t.Fatal("error writing response:", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func checkBlob(t *testing.T, c *blob.DiskCache, d blob.Digest, content string) {
|
|
|
+ t.Helper()
|
|
|
+ info, err := c.Get(d)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Get(%v) = %v", d, err)
|
|
|
+ }
|
|
|
+ if int(info.Size) != len(content) {
|
|
|
+ t.Errorf("info.Size = %v; want 3", info.Size)
|
|
|
+ }
|
|
|
+ data, err := os.ReadFile(c.GetFile(d))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("ReadFile = %v", err)
|
|
|
+ }
|
|
|
+ if string(data) != content {
|
|
|
+ t.Errorf("data = %q; want abc", data)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestPull(t *testing.T) {
|
|
|
+ t.Run("single", func(t *testing.T) {
|
|
|
+ synctest.Run(func() {
|
|
|
+ ctx := context.Background()
|
|
|
+ ctx = WithTrace(ctx, &Trace{
|
|
|
+ Update: func(l *Layer, n int64, err error) {
|
|
|
+ if errors.Is(err, ErrCached) {
|
|
|
+ t.Errorf("unexpected ErrCached for %v", l.Digest)
|
|
|
+ }
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ c, err := blob.Open(t.TempDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ ct := newClientTester(t, c)
|
|
|
+ defer ct.close()
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ err := ct.pull(ctx, "library/abc")
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("pull = %v", err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ const content = "a" // below chunking threshold
|
|
|
+ sendManifest := func() {
|
|
|
+ ct.respond(200, fmt.Sprintf(`{"layers":[{"digest":%q,"size":%d}]}`, blob.DigestFromBytes(content), len(content)))
|
|
|
+ }
|
|
|
+
|
|
|
+ d := blob.DigestFromBytes(content)
|
|
|
+ ct.await("/v2/library/abc/manifests/latest")
|
|
|
+ sendManifest()
|
|
|
+ synctest.Wait()
|
|
|
+ if !ct.running() {
|
|
|
+ t.Error("pull is not running")
|
|
|
+ }
|
|
|
+
|
|
|
+ // cache should be empty
|
|
|
+ _, err = c.Get(d)
|
|
|
+ if !errors.Is(err, fs.ErrNotExist) {
|
|
|
+ t.Fatalf("Get(%v) = %v; want fs.ErrNotExist", d, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // blob request/response
|
|
|
+ ct.await("/v2/library/abc/blobs/" + d.String())
|
|
|
+ ct.respond(200, content)
|
|
|
+ synctest.Wait()
|
|
|
+ if ct.running() {
|
|
|
+ t.Error("pull is still running")
|
|
|
+ }
|
|
|
+ checkBlob(t, c, d, content)
|
|
|
+ _, err = c.Resolve("example.com/library/abc:latest")
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("expected manifest to be linked: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // repull should be cached
|
|
|
+ ctx = WithTrace(ctx, &Trace{
|
|
|
+ Update: func(l *Layer, n int64, err error) {
|
|
|
+ if n > 0 && !errors.Is(err, ErrCached) {
|
|
|
+ t.Errorf("unexpected error: %v", err)
|
|
|
+ }
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ err := ct.pull(ctx, "library/abc")
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("pull = %v", err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ ct.await("/v2/library/abc/manifests/latest")
|
|
|
+ sendManifest()
|
|
|
+ synctest.Wait()
|
|
|
+ if ct.running() {
|
|
|
+ t.Error("pull is still running")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ t.Run("chunked", func(t *testing.T) {
|
|
|
+ synctest.Run(func() {
|
|
|
+ c, err := blob.Open(t.TempDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ ct := newClientTester(t, c)
|
|
|
+ defer ct.close()
|
|
|
+
|
|
|
+ ctx := WithTrace(t.Context(), &Trace{
|
|
|
+ Update: func(l *Layer, n int64, err error) {
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unexpected error: %v", err)
|
|
|
+ }
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ ct.setMaxStreams(1)
|
|
|
+ err := ct.pull(ctx, "library/abc")
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("pull = %v", err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ const content = "abc" // above chunking threshold
|
|
|
+
|
|
|
+ d := blob.DigestFromBytes(content)
|
|
|
+ sendManifest := func() {
|
|
|
+ ct.respond(200, fmt.Sprintf(`{"layers":[{"digest":%q,"size":%d}]}`, d, len(content)))
|
|
|
+ }
|
|
|
+ ct.await("/v2/library/abc/manifests/latest")
|
|
|
+ sendManifest()
|
|
|
+ ct.await("/v2/library/abc/chunksums/" + d.String())
|
|
|
+
|
|
|
+ s0 := blob.DigestFromBytes("ab")
|
|
|
+ s1 := blob.DigestFromBytes("c")
|
|
|
+ ct.respondWith(&http.Response{
|
|
|
+ Header: http.Header{
|
|
|
+ "Content-Location": []string{"http://example.com/v2/library/abc/blobs/" + d.String()},
|
|
|
+ },
|
|
|
+ Body: stringBody(`
|
|
|
+ %s 0-1
|
|
|
+ %s 2-2
|
|
|
+ `, s0, s1),
|
|
|
+ })
|
|
|
+
|
|
|
+ for i := range 2 {
|
|
|
+ t.Logf("checking range request %d", i)
|
|
|
+
|
|
|
+ req := ct.await("/v2/library/abc/blobs/" + d.String())
|
|
|
+ switch rng := req.Header.Get("Range"); rng {
|
|
|
+ case "bytes=0-1":
|
|
|
+ ct.respond(200, "ab")
|
|
|
+ case "bytes=2-2":
|
|
|
+ ct.respond(200, "c")
|
|
|
+ default:
|
|
|
+ t.Errorf("unexpected range: %q", rng)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ synctest.Wait()
|
|
|
+ if ct.running() {
|
|
|
+ t.Error("pull is still running")
|
|
|
+ }
|
|
|
+ checkBlob(t, ct.rc.Cache, d, content)
|
|
|
+ _, err = c.Resolve("example.com/library/abc:latest")
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("expected manifest to be linked: %v", err)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ t.Run("errors", func(t *testing.T) {
|
|
|
+ synctest.Run(func() {
|
|
|
+ c, err := blob.Open(t.TempDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ ct := newClientTester(t, c)
|
|
|
+ defer ct.close()
|
|
|
+
|
|
|
+ type update struct {
|
|
|
+ l *Layer
|
|
|
+ n int64
|
|
|
+ err error
|
|
|
+ }
|
|
|
+
|
|
|
+ var got strings.Builder
|
|
|
+ ctx := WithTrace(t.Context(), &Trace{
|
|
|
+ Update: func(l *Layer, n int64, err error) {
|
|
|
+ fmt.Fprintf(&got, "%v %d %v\n", l.Digest.Short(), n, err)
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ ct.setMaxStreams(1)
|
|
|
+ err := ct.pull(ctx, "library/abc")
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("pull = %v", err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // makeManifest makes a single layer manifest using
|
|
|
+ // content and returns the digest of the content, and
|
|
|
+ // the content of the manifest.
|
|
|
+ makeManifest := func(content string) (blob.Digest, string) {
|
|
|
+ d := blob.DigestFromBytes(content)
|
|
|
+ return d, fmt.Sprintf(`{"layers":[{"digest":%q,"size":%d}]}`, d, len(content))
|
|
|
+ }
|
|
|
+
|
|
|
+ ct.await("/v2/library/abc/manifests/latest")
|
|
|
+ d, man := makeManifest("a")
|
|
|
+ ct.respond(200, man)
|
|
|
+ ct.await("/v2/library/abc/blobs/" + d.String())
|
|
|
+ ct.respond(200, "a")
|
|
|
+ synctest.Wait()
|
|
|
+ if ct.running() {
|
|
|
+ t.Error("pull is still running")
|
|
|
+ }
|
|
|
+ var want strings.Builder
|
|
|
+ want.WriteString(d.Short() + " 0 <nil>\n") // initial announcement
|
|
|
+ want.WriteString(d.Short() + " 1 <nil>\n") // final
|
|
|
+ if got.String() != want.String() {
|
|
|
+ t.Errorf("\ngot:\n%s\nwant:\n%s", got.String(), want.String())
|
|
|
+ }
|
|
|
+
|
|
|
+ // error on manifest fetch
|
|
|
+ got.Reset()
|
|
|
+ done := make(chan error)
|
|
|
+ go func() { done <- ct.pull(ctx, "library/abc") }()
|
|
|
+ ct.await("/v2/library/abc/manifests/latest")
|
|
|
+ ct.respond(400, `some error`)
|
|
|
+ synctest.Wait()
|
|
|
+ if ct.running() {
|
|
|
+ t.Error("pull is still running")
|
|
|
+ }
|
|
|
+ err = <-done
|
|
|
+ if err == nil || !strings.Contains(err.Error(), "some error") {
|
|
|
+ t.Errorf("err = %v; want some error", err)
|
|
|
+ }
|
|
|
+ if got.String() != "" {
|
|
|
+ t.Errorf("\nunexpected traces:\n%s", got.String())
|
|
|
+ }
|
|
|
+
|
|
|
+ // error on blob fetch
|
|
|
+ got.Reset()
|
|
|
+ go func() { done <- ct.pull(ctx, "library/abc") }()
|
|
|
+ ct.await("/v2/library/abc/manifests/latest")
|
|
|
+ d, man = makeManifest("b")
|
|
|
+ ct.respond(200, man)
|
|
|
+ ct.await("/v2/library/abc/blobs/" + d.String())
|
|
|
+ ct.respond(501, `blob store error`)
|
|
|
+ synctest.Wait()
|
|
|
+ if ct.running() {
|
|
|
+ t.Error("pull is still running")
|
|
|
+ }
|
|
|
+ err = <-done
|
|
|
+ if err == nil || !strings.Contains(err.Error(), "blob store error") {
|
|
|
+ t.Errorf("err = %v; want some error", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // check we get a trace error on blob fetch after some
|
|
|
+ // progress and that one chunksum error does not
|
|
|
+ // prevent the next chunksum request.
|
|
|
+ got.Reset()
|
|
|
+ go func() { done <- ct.pull(ctx, "library/abc") }()
|
|
|
+ ct.await("/v2/library/abc/manifests/latest")
|
|
|
+ d, man = makeManifest("ccc")
|
|
|
+ ct.respond(200, man)
|
|
|
+ ct.await("/v2/library/abc/chunksums/" + d.String())
|
|
|
+ ct.respondWith(&http.Response{
|
|
|
+ Header: http.Header{
|
|
|
+ "Content-Location": []string{"http://example.com/v2/library/abc/blobs/" + d.String()},
|
|
|
+ },
|
|
|
+ Body: stringBody(`
|
|
|
+ %[1]s 0-0
|
|
|
+ %[1]s 1-1
|
|
|
+ %[1]s 2-2
|
|
|
+ `, blob.DigestFromBytes("c")),
|
|
|
+ })
|
|
|
+ req := ct.await("/v2/library/abc/blobs/" + d.String())
|
|
|
+ if rng := req.Header.Get("Range"); rng != "bytes=0-0" {
|
|
|
+ t.Errorf("unexpected range: %q", rng)
|
|
|
+ }
|
|
|
+ ct.respond(200, "c")
|
|
|
+ req = ct.await("/v2/library/abc/blobs/" + d.String())
|
|
|
+ if rng := req.Header.Get("Range"); rng != "bytes=1-1" {
|
|
|
+ t.Errorf("unexpected range: %q", rng)
|
|
|
+ }
|
|
|
+ ct.respond(501, `blob store error`)
|
|
|
+ req = ct.await("/v2/library/abc/blobs/" + d.String())
|
|
|
+ if rng := req.Header.Get("Range"); rng != "bytes=2-2" {
|
|
|
+ t.Errorf("unexpected range: %q", rng)
|
|
|
+ }
|
|
|
+ ct.respond(501, `blob store error`)
|
|
|
+ synctest.Wait()
|
|
|
+ if ct.running() {
|
|
|
+ t.Error("pull is still running")
|
|
|
+ }
|
|
|
+ err = <-done
|
|
|
+ if err == nil || !strings.Contains(err.Error(), "blob store error") {
|
|
|
+ t.Errorf("err = %v; want some error", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var errorsSeen int
|
|
|
+ for line := range strings.Lines(got.String()) {
|
|
|
+ if strings.Contains(line, "blob store error") {
|
|
|
+ errorsSeen++
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if errorsSeen != 2 {
|
|
|
+ t.Errorf("errorsSeen = %d; want 2", errorsSeen)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+}
|