Browse Source

x/registry: fixing tests wip

Blake Mizerany 1 năm trước cách đây
mục cha
commit
f7cfe946dc
3 tập tin đã thay đổi với 67 bổ sung76 xóa
  1. 29 8
      x/registry/client.go
  2. 0 2
      x/registry/server.go
  3. 38 66
      x/registry/server_test.go

+ 29 - 8
x/registry/client.go

@@ -4,9 +4,11 @@ import (
 	"cmp"
 	"context"
 	"encoding/xml"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 
 	"bllamo.com/client/ollama"
 	"bllamo.com/registry/apitype"
@@ -40,23 +42,42 @@ func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushP
 	return v.Requirements, nil
 }
 
-func PushLayer(ctx context.Context, dstURL string, off, size int64, file io.ReaderAt) (etag string, err error) {
-	sr := io.NewSectionReader(file, off, size)
-	req, err := http.NewRequestWithContext(ctx, "PUT", dstURL, sr)
+func PushLayer(ctx context.Context, body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) {
+	var zero apitype.CompletePart
+	if off < 0 {
+		return zero, errors.New("off must be >0")
+	}
+
+	file := io.NewSectionReader(body, off, n)
+	req, err := http.NewRequest("PUT", url, file)
 	if err != nil {
-		return "", err
+		return zero, err
+	}
+	req.ContentLength = n
+
+	// TODO(bmizerany): take content type param
+	req.Header.Set("Content-Type", "text/plain")
+
+	if n >= 0 {
+		req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
 	}
-	req.ContentLength = size
 
 	res, err := http.DefaultClient.Do(req)
 	if err != nil {
-		return "", err
+		return zero, err
 	}
 	defer res.Body.Close()
 	if res.StatusCode != 200 {
-		return "", parseS3Error(res)
+		e := parseS3Error(res)
+		return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
+	}
+	etag := strings.Trim(res.Header.Get("ETag"), `"`)
+	cp := apitype.CompletePart{
+		URL:  url,
+		ETag: etag,
+		// TODO(bmizerany): checksum
 	}
-	return res.Header.Get("ETag"), nil
+	return cp, nil
 }
 
 type s3Error struct {

+ 0 - 2
x/registry/server.go

@@ -6,7 +6,6 @@ import (
 	"cmp"
 	"context"
 	"errors"
-	"fmt"
 	"log"
 	"net/http"
 	"net/url"
@@ -131,7 +130,6 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
 			PartNumber: partNumber,
 			ETag:       etag,
 		})
-		fmt.Println("uploadID", uploadID, "partNumber", partNumber, "etag", etag)
 		completePartsByUploadID[uploadID] = cp
 	}
 

+ 38 - 66
x/registry/server_test.go

@@ -11,13 +11,11 @@ import (
 	"fmt"
 	"io"
 	"net"
-	"net/http"
 	"net/http/httptest"
 	"net/url"
 	"os"
 	"os/exec"
 	"strconv"
-	"strings"
 	"syscall"
 	"testing"
 	"time"
@@ -30,8 +28,6 @@ import (
 	"kr.dev/diff"
 )
 
-const abc = "abcdefghijklmnopqrstuvwxyz"
-
 func testPush(t *testing.T, chunkSize int64) {
 	t.Run(fmt.Sprintf("chunkSize=%d", chunkSize), func(t *testing.T) {
 		mc := startMinio(t, true)
@@ -71,15 +67,11 @@ func testPush(t *testing.T, chunkSize int64) {
 		for i, r := range requirements {
 			t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)
 
-			body := strings.NewReader(abc)
-			etag, err := PushLayer(context.Background(), r.URL, r.Offset, r.Size, body)
+			cp, err := PushLayer(context.Background(), &abcReader{}, r.URL, r.Offset, r.Size)
 			if err != nil {
 				t.Fatal(err)
 			}
-			uploaded = append(uploaded, apitype.CompletePart{
-				URL:  r.URL,
-				ETag: etag,
-			})
+			uploaded = append(uploaded, cp)
 		}
 
 		requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
@@ -142,15 +134,8 @@ func testPush(t *testing.T, chunkSize int64) {
 			}
 			t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)
 
-			data, err := io.ReadAll(obj)
-			if err != nil {
-				t.Fatal(err)
-			}
-
-			got := string(data)
-			want := abc[:l.Size]
-			if got != want {
-				t.Errorf("[%d] got layer data = %q; want %q", i, got, want)
+			if msg := checkABCs(obj, int(l.Size)); msg != "" {
+				t.Errorf("[%d] %s", i, msg)
 			}
 		}
 	})
@@ -161,44 +146,6 @@ func TestPush(t *testing.T) {
 	testPush(t, 1)
 }
 
-func pushLayer(body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) {
-	var zero apitype.CompletePart
-	if off < 0 {
-		return zero, errors.New("off must be >0")
-	}
-
-	file := io.NewSectionReader(body, off, n)
-	req, err := http.NewRequest("PUT", url, file)
-	if err != nil {
-		return zero, err
-	}
-	req.ContentLength = n
-
-	// TODO(bmizerany): take content type param
-	req.Header.Set("Content-Type", "text/plain")
-
-	if n >= 0 {
-		req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
-	}
-
-	res, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return zero, err
-	}
-	defer res.Body.Close()
-	if res.StatusCode != 200 {
-		e := parseS3Error(res)
-		return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
-	}
-	etag := strings.Trim(res.Header.Get("ETag"), `"`)
-	cp := apitype.CompletePart{
-		URL:  url,
-		ETag: etag,
-		// TODO(bmizerany): checksum
-	}
-	return cp, nil
-}
-
 // TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of
 // presigning a multipart upload, uploading the parts, and completing the
 // upload. It is for future reference and should not be deleted. This flow
@@ -230,7 +177,7 @@ func TestBasicPresignS3MultipartReferenceDoNotDelete(t *testing.T) {
 		t.Logf("[partNumber=%d]: %v", partNumber, u)
 
 		var body abcReader
-		cp, err := pushLayer(&body, u.String(), c.Offset, c.N)
+		cp, err := PushLayer(context.Background(), &body, u.String(), c.Offset, c.N)
 		if err != nil {
 			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
 		}
@@ -306,7 +253,7 @@ func startMinio(t *testing.T, trace bool) *minio.Client {
 	// explicitly setting trace to true.
 	trace = cmp.Or(trace, os.Getenv("OLLAMA_MINIO_TRACE") != "")
 
-	dir := t.TempDir() + "-keep" // prevent tempdir from auto delete
+	dir := t.TempDir()
 
 	t.Cleanup(func() {
 		// TODO(bmizerany): trim temp dir based on dates so that
@@ -317,19 +264,18 @@ func startMinio(t *testing.T, trace bool) *minio.Client {
 		if err := cmd.Wait(); err != nil {
 			var e *exec.ExitError
 			if errors.As(err, &e) {
-				if !e.Exited() {
-					// died due to our signal
+				if e.Exited() {
 					return
 				}
-				t.Errorf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
-				t.Errorf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode())
-				t.Errorf("startMinio: %s exited: %v", cmd.Path, e.Exited())
-				t.Errorf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
+				t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
+				t.Logf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode())
+				t.Logf("startMinio: %s exited: %v", cmd.Path, e.Exited())
+				t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
 			} else {
 				if errors.Is(err, context.Canceled) {
 					return
 				}
-				t.Errorf("startMinio: %s exit error: %v", cmd.Path, err)
+				t.Logf("startMinio: %s exit error: %v", cmd.Path, err)
 			}
 		}
 	}
@@ -343,6 +289,7 @@ func startMinio(t *testing.T, trace bool) *minio.Client {
 	}
 
 	t.Logf(">> minio: minio server %s", dir)
+
 	addr := availableAddr()
 	cmd := exec.CommandContext(ctx, "minio", "server", "--address", addr, dir)
 	cmd.Env = os.Environ()
@@ -463,3 +410,28 @@ func (r *abcReader) ReadAt(p []byte, off int64) (n int, err error) {
 	}
 	return len(p), nil
 }
+
+func checkABCs(r io.Reader, size int) (reason string) {
+	h := sha256.New()
+	n, err := io.CopyN(h, &abcReader{}, int64(size))
+	if err != nil {
+		return err.Error()
+	}
+	if n != int64(size) {
+		panic("short read; should not happen")
+	}
+	want := h.Sum(nil)
+	h = sha256.New()
+	n, err = io.Copy(h, r)
+	if err != nil {
+		return err.Error()
+	}
+	if n != int64(size) {
+		return fmt.Sprintf("got len(r) = %d; want %d", n, size)
+	}
+	got := h.Sum(nil)
+	if !bytes.Equal(got, want) {
+		return fmt.Sprintf("got sum = %x; want %x", got, want)
+	}
+	return ""
+}