Browse Source

use checksum reference

Michael Yang 1 year ago
parent
commit
1901044b07
4 changed files with 20 additions and 12 deletions
  1. 6 7
      api/client.go
  2. 2 3
      cmd/cmd.go
  3. 9 0
      server/images.go
  4. 3 2
      server/routes.go

+ 6 - 7
api/client.go

@@ -297,18 +297,17 @@ func (c *Client) Heartbeat(ctx context.Context) error {
 	return nil
 }
 
-func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) (string, error) {
-	var response CreateBlobResponse
-	if err := c.do(ctx, http.MethodGet, fmt.Sprintf("/api/blobs/%s/path", digest), nil, &response); err != nil {
+func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
+	if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil {
 		var statusError StatusError
 		if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
-			return "", err
+			return err
 		}
 
-		if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, &response); err != nil {
-			return "", err
+		if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil); err != nil {
+			return err
 		}
 	}
 
-	return response.Path, nil
+	return nil
 }

+ 2 - 3
cmd/cmd.go

@@ -91,12 +91,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			bin.Seek(0, io.SeekStart)
 
 			digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
-			path, err = client.CreateBlob(cmd.Context(), digest, bin)
-			if err != nil {
+			if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
 				return err
 			}
 
-			modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte(path))
+			modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest))
 		}
 	}
 

+ 9 - 0
server/images.go

@@ -287,6 +287,15 @@ func CreateModel(ctx context.Context, name string, commands []parser.Command, fn
 
 		switch c.Name {
 		case "model":
+			if strings.HasPrefix(c.Args, "@") {
+				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
+				if err != nil {
+					return err
+				}
+
+				c.Args = blobPath
+			}
+
 			bin, err := os.Open(realpath(c.Args))
 			if err != nil {
 				// not a file on disk so must be a model reference

+ 3 - 2
server/routes.go

@@ -650,7 +650,7 @@ func CopyModelHandler(c *gin.Context) {
 	}
 }
 
-func GetBlobHandler(c *gin.Context) {
+func HeadBlobHandler(c *gin.Context) {
 	path, err := GetBlobsPath(c.Param("digest"))
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -771,9 +771,10 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 		})
 
 		r.Handle(method, "/api/tags", ListModelsHandler)
-		r.Handle(method, "/api/blobs/:digest/path", GetBlobHandler)
 	}
 
+	r.HEAD("/api/blobs/:digest", HeadBlobHandler)
+
 	log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
 	s := &http.Server{
 		Handler: r,