Sfoglia il codice sorgente

update delete handler to use model.Name

Michael Yang 1 anno fa
parent
commit
a2fc933fed
3 ha cambiato i file con 99 aggiunte e 35 eliminazioni
  1. 23 0
      server/layer.go
  2. 67 9
      server/manifest.go
  3. 9 26
      server/routes.go

+ 23 - 0
server/layer.go

@@ -88,3 +88,26 @@ func (l *Layer) Open() (io.ReadCloser, error) {
 
 
 	return os.Open(blob)
 	return os.Open(blob)
 }
 }
+
+func (l *Layer) Remove() error {
+	ms, err := Manifests()
+	if err != nil {
+		return err
+	}
+
+	for _, m := range ms {
+		for _, layer := range append(m.Layers, m.Config) {
+			if layer.Digest == l.Digest {
+				// something is using this layer
+				return nil
+			}
+		}
+	}
+
+	blob, err := GetBlobsPath(l.Digest)
+	if err != nil {
+		return err
+	}
+
+	return os.Remove(blob)
+}

+ 67 - 9
server/manifest.go

@@ -14,7 +14,9 @@ import (
 
 
 type Manifest struct {
 type Manifest struct {
 	ManifestV2
 	ManifestV2
-	Digest string `json:"-"`
+
+	filepath string
+	digest   string
 }
 }
 
 
 func (m *Manifest) Size() (size int64) {
 func (m *Manifest) Size() (size int64) {
@@ -25,9 +27,28 @@ func (m *Manifest) Size() (size int64) {
 	return
 	return
 }
 }
 
 
-func ParseNamedManifest(name model.Name) (*Manifest, error) {
-	if !name.IsFullyQualified() {
-		return nil, model.Unqualified(name)
+func (m *Manifest) Remove() error {
+	if err := os.Remove(m.filepath); err != nil {
+		return err
+	}
+
+	for _, layer := range append(m.Layers, m.Config) {
+		if err := layer.Remove(); err != nil {
+			return err
+		}
+	}
+
+	manifests, err := GetManifestPath()
+	if err != nil {
+		return err
+	}
+
+	return PruneDirectory(manifests)
+}
+
+func ParseNamedManifest(n model.Name) (*Manifest, error) {
+	if !n.IsFullyQualified() {
+		return nil, model.Unqualified(n)
 	}
 	}
 
 
 	manifests, err := GetManifestPath()
 	manifests, err := GetManifestPath()
@@ -35,20 +56,24 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	var manifest ManifestV2
-	manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
+	p := filepath.Join(manifests, n.Filepath())
+
+	var m ManifestV2
+	f, err := os.Open(p)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+	defer f.Close()
 
 
 	sha256sum := sha256.New()
 	sha256sum := sha256.New()
-	if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
+	if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
 	return &Manifest{
 	return &Manifest{
-		ManifestV2: manifest,
-		Digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
+		ManifestV2: m,
+		filepath:   p,
+		digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
 	}, nil
 	}, nil
 }
 }
 
 
@@ -77,3 +102,36 @@ func WriteManifest(name string, config *Layer, layers []*Layer) error {
 
 
 	return os.WriteFile(manifestPath, b.Bytes(), 0o644)
 	return os.WriteFile(manifestPath, b.Bytes(), 0o644)
 }
 }
+
+func Manifests() (map[model.Name]*Manifest, error) {
+	manifests, err := GetManifestPath()
+	if err != nil {
+		return nil, err
+	}
+
+	// TODO(mxyng): use something less brittle
+	matches, err := filepath.Glob(fmt.Sprintf("%s/*/*/*/*", manifests))
+	if err != nil {
+		return nil, err
+	}
+
+	ms := make(map[model.Name]*Manifest)
+	for _, match := range matches {
+		rel, err := filepath.Rel(manifests, match)
+		if err != nil {
+			return nil, err
+		}
+
+		n := model.ParseNameFromFilepath(rel)
+		if n.IsValid() {
+			m, err := ParseNamedManifest(n)
+			if err != nil {
+				return nil, err
+			}
+
+			ms[n] = m
+		}
+	}
+
+	return ms, nil
+}

+ 9 - 26
server/routes.go

@@ -574,48 +574,31 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 }
 }
 
 
 func (s *Server) DeleteModelHandler(c *gin.Context) {
 func (s *Server) DeleteModelHandler(c *gin.Context) {
-	var req api.DeleteRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	var r api.DeleteRequest
+	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 		return
 	}
 	}
 
 
-	var model string
-	if req.Model != "" {
-		model = req.Model
-	} else if req.Name != "" {
-		model = req.Name
-	} else {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	}
-
-	if err := DeleteModel(model); err != nil {
-		if os.IsNotExist(err) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
-		} else {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		}
+	n := model.ParseName(cmp.Or(r.Model, r.Name))
+	if !n.IsValid() {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
 		return
 		return
 	}
 	}
 
 
-	manifestsPath, err := GetManifestPath()
+	m, err := ParseNamedManifest(n)
 	if err != nil {
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 		return
 	}
 	}
 
 
-	if err := PruneDirectory(manifestsPath); err != nil {
+	if err := m.Remove(); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 		return
 	}
 	}
-
-	c.JSON(http.StatusOK, nil)
 }
 }
 
 
 func (s *Server) ShowModelHandler(c *gin.Context) {
 func (s *Server) ShowModelHandler(c *gin.Context) {
@@ -769,7 +752,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
 				Model:      n.DisplayShortest(),
 				Model:      n.DisplayShortest(),
 				Name:       n.DisplayShortest(),
 				Name:       n.DisplayShortest(),
 				Size:       m.Size(),
 				Size:       m.Size(),
-				Digest:     m.Digest,
+				Digest:     m.digest,
 				ModifiedAt: info.ModTime(),
 				ModifiedAt: info.ModTime(),
 				Details: api.ModelDetails{
 				Details: api.ModelDetails{
 					Format:            c.ModelFormat,
 					Format:            c.ModelFormat,