浏览代码

... wip still broke

Blake Mizerany 1 年之前
父节点
当前提交
def4d902bf
共有 6 个文件被更改,包括 85 次插入20 次删除
  1. 2 2
      x/api/api.go
  2. 11 0
      x/build/blob/ref.go
  3. 11 0
      x/client/ollama/ollama.go
  4. 5 4
      x/oweb/oweb.go
  5. 20 11
      x/registry/server.go
  6. 36 3
      x/registry/server_test.go

+ 2 - 2
x/api/api.go

@@ -15,8 +15,8 @@ import (
 
 // Common API Errors
 var (
-	errUnqualifiedRef = oweb.Mistake("invalid", "name", "must be fully qualified")
-	errRefNotFound    = oweb.Mistake("not_found", "name", "no such model")
+	errUnqualifiedRef = oweb.Invalid("invalid", "name", "must be fully qualified")
+	errRefNotFound    = oweb.Invalid("not_found", "name", "no such model")
 )
 
 type Server struct {

+ 11 - 0
x/build/blob/ref.go

@@ -253,6 +253,17 @@ func ParseRef(s string) Ref {
 	return r
 }
 
+// Complete is the same as ParseRef(s).Complete().
+//
+// Future versions may be faster than calling ParseRef(s).Complete(), so if
+// need to know if a ref is complete and don't need the ref, use this
+// function.
+func Complete(s string) bool {
+	// TODO(bmizerany): fast-path this with a quick scan withput
+	// allocating strings
+	return ParseRef(s).Complete()
+}
+
 func (r Ref) Valid() bool {
 	// Name is required
 	if !isValidPart(r.name) {

+ 11 - 0
x/client/ollama/ollama.go

@@ -92,12 +92,23 @@ type Error struct {
 
 	// Field is the field in the request that caused the error, if any.
 	Field string `json:"field,omitempty"`
+
+	// Value is the value of the field that caused the error, if any.
+	Value string `json:"value,omitempty"`
 }
 
 func (e *Error) Error() string {
 	var b strings.Builder
 	b.WriteString("ollama: ")
 	b.WriteString(e.Code)
+	if e.Field != "" {
+		b.WriteString(" ")
+		b.WriteString(e.Field)
+	}
+	if e.Value != "" {
+		b.WriteString(": ")
+		b.WriteString(e.Value)
+	}
 	if e.Message != "" {
 		b.WriteString(": ")
 		b.WriteString(e.Message)

+ 5 - 4
x/oweb/oweb.go

@@ -21,12 +21,13 @@ func Missing(field string) error {
 	}
 }
 
-func Mistake(code, field, message string) error {
+func Invalid(field, value, format string, args ...any) error {
 	return &ollama.Error{
 		Status:  400,
-		Code:    code,
+		Code:    "invalid",
 		Field:   field,
-		Message: fmt.Sprintf("%s: %s", field, message),
+		Value:   value,
+		Message: fmt.Sprintf(format, args...),
 	}
 }
 
@@ -69,7 +70,7 @@ func DecodeUserJSON[T any](field string, r io.Reader) (*T, error) {
 	if errors.As(err, &se) {
 		msg = fmt.Sprintf("%s (%q) is not a %s", se.Field, se.Value, se.Type)
 	}
-	return nil, Mistake("invalid_json", field, msg)
+	return nil, Invalid("invalid_json", field, "", msg)
 }
 
 func DecodeJSON[T any](r io.Reader) (*T, error) {

+ 20 - 11
x/registry/server.go

@@ -84,7 +84,7 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
 
 	ref := blob.ParseRef(pr.Ref)
 	if !ref.Complete() {
-		return oweb.Mistake("invalid", "name", "must be complete")
+		return oweb.Invalid("name", pr.Ref, "must be complete")
 	}
 
 	m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest))
@@ -107,24 +107,30 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
 		if err != nil {
 			return err
 		}
+
 		q := u.Query()
+
+		// Check if this is a part upload, if not, skip
 		uploadID := q.Get("uploadId")
 		if uploadID == "" {
 			// not a part upload
 			continue
 		}
-		partNumber, err := strconv.Atoi(q.Get("partNumber"))
+
+		// PartNumber is required
+		queryPartNumber := q.Get("partNumber")
+		partNumber, err := strconv.Atoi(queryPartNumber)
 		if err != nil {
-			return oweb.Mistake("invalid", "url", "invalid or missing PartNumber")
+			return oweb.Invalid("partNumber", queryPartNumber, "invalid or missing PartNumber")
 		}
+
+		// ETag is required
 		if mcp.ETag == "" {
-			return oweb.Mistake("invalid", "etag", "missing")
-		}
-		cp, ok := completePartsByUploadID[uploadID]
-		if !ok {
-			cp = completeParts{key: u.Path}
-			completePartsByUploadID[uploadID] = cp
+			return oweb.Missing("etag")
 		}
+
+		cp := completePartsByUploadID[uploadID]
+		cp.key = u.Path
 		cp.parts = append(cp.parts, minio.CompletePart{
 			PartNumber: partNumber,
 			ETag:       mcp.ETag,
@@ -136,8 +142,11 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
 		var zeroOpts minio.PutObjectOptions
 		_, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, cp.key, uploadID, cp.parts, zeroOpts)
 		if err != nil {
-			// log and continue; put backpressure on the client
-			log.Printf("error completing upload: %v", err)
+			var e minio.ErrorResponse
+			if errors.As(err, &e) && e.Code == "NoSuchUpload" {
+				return oweb.Invalid("uploadId", uploadID, "unknown uploadId")
+			}
+			return err
 		}
 	}
 

+ 36 - 3
x/registry/server_test.go

@@ -28,6 +28,39 @@ import (
 	"kr.dev/diff"
 )
 
+// const ref = "registry.ollama.ai/x/y:latest+Z"
+// const manifest = `{
+// 	"layers": [
+// 		{"digest": "sha256-1", "size": 1},
+// 		{"digest": "sha256-2", "size": 2},
+// 		{"digest": "sha256-3", "size": 3}
+// 	]
+// }`
+
+// ts := newTestServer(t)
+// ts.pushNotOK(ref, `{}`, &ollama.Error{
+// 	Status:  400,
+// 	Code:    "invalid",
+// 	Message: "name must be fully qualified",
+// })
+
+// ts.push(ref, `{
+// 	"layers": [
+// 		{"digest": "sha256-1", "size": 1},
+// 		{"digest": "sha256-2", "size": 2},
+// 		{"digest": "sha256-3", "size": 3}
+// 	]
+// }`)
+
+type tWriter struct {
+	t *testing.T
+}
+
+func (w tWriter) Write(p []byte) (n int, err error) {
+	w.t.Logf("%s", p)
+	return len(p), nil
+}
+
 func TestPushBasic(t *testing.T) {
 	const MB = 1024 * 1024
 
@@ -41,6 +74,8 @@ func TestPushBasic(t *testing.T) {
 		}
 	}()
 
+	const ref = "registry.ollama.ai/x/y:latest+Z"
+
 	// Upload two small layers and one large layer that will
 	// trigger a multipart upload.
 	manifest := []byte(`{
@@ -49,9 +84,7 @@ func TestPushBasic(t *testing.T) {
 				{"digest": "sha256-2", "size": 2},
 				{"digest": "sha256-3", "size": 11000000}
 			]
-		}`)
-
-	const ref = "registry.ollama.ai/x/y:latest+Z"
+	}`)
 
 	hs := httptest.NewServer(&Server{
 		minioClient:     mc,