Browse Source

refactor layer pruning

Michael Yang 8 months ago
parent
commit
745706c765
5 changed files with 170 additions and 119 deletions
  1. 1 108
      server/images.go
  2. 41 1
      server/layer.go
  3. 37 2
      server/manifest.go
  4. 5 8
      server/routes.go
  5. 86 0
      server/routes_test.go

+ 1 - 108
server/images.go

@@ -501,7 +501,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 						return false
 					}
 
-					if err := layer.Remove(); err != nil {
+					if err := layer.Prune(); err != nil {
 						return false
 					}
 
@@ -689,113 +689,6 @@ func CopyModel(src, dst model.Name) error {
 	return err
 }
 
-func deleteUnusedLayers(deleteMap map[string]struct{}) error {
-	manifests, err := Manifests()
-	if err != nil {
-		return err
-	}
-
-	for _, manifest := range manifests {
-		for _, layer := range manifest.Layers {
-			delete(deleteMap, layer.Digest)
-		}
-
-		delete(deleteMap, manifest.Config.Digest)
-	}
-
-	// only delete the files which are still in the deleteMap
-	for k := range deleteMap {
-		fp, err := GetBlobsPath(k)
-		if err != nil {
-			slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
-			continue
-		}
-		if err := os.Remove(fp); err != nil {
-			slog.Info(fmt.Sprintf("couldn't remove file '%s': %v", fp, err))
-			continue
-		}
-	}
-
-	return nil
-}
-
-func PruneLayers() error {
-	deleteMap := make(map[string]struct{})
-	p, err := GetBlobsPath("")
-	if err != nil {
-		return err
-	}
-
-	blobs, err := os.ReadDir(p)
-	if err != nil {
-		slog.Info(fmt.Sprintf("couldn't read dir '%s': %v", p, err))
-		return err
-	}
-
-	for _, blob := range blobs {
-		name := blob.Name()
-		name = strings.ReplaceAll(name, "-", ":")
-
-		_, err := GetBlobsPath(name)
-		if err != nil {
-			if errors.Is(err, ErrInvalidDigestFormat) {
-				// remove invalid blobs (e.g. partial downloads)
-				if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
-					slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
-				}
-			}
-
-			continue
-		}
-
-		deleteMap[name] = struct{}{}
-	}
-
-	slog.Info(fmt.Sprintf("total blobs: %d", len(deleteMap)))
-
-	if err := deleteUnusedLayers(deleteMap); err != nil {
-		slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
-		return nil
-	}
-
-	slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
-
-	return nil
-}
-
-func PruneDirectory(path string) error {
-	info, err := os.Lstat(path)
-	if err != nil {
-		return err
-	}
-
-	if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
-		entries, err := os.ReadDir(path)
-		if err != nil {
-			return err
-		}
-
-		for _, entry := range entries {
-			if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
-				return err
-			}
-		}
-
-		entries, err = os.ReadDir(path)
-		if err != nil {
-			return err
-		}
-
-		if len(entries) > 0 {
-			return nil
-		}
-
-		return os.Remove(path)
-	}
-
-	return nil
-}
-
 func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
 	m, err := ParseNamedManifest(name)
 	if err != nil {

+ 41 - 1
server/layer.go

@@ -5,7 +5,10 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log/slog"
 	"os"
+	"path/filepath"
+	"strings"
 )
 
 type Layer struct {
@@ -101,7 +104,8 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
 	return os.Open(blob)
 }
 
-func (l *Layer) Remove() error {
+// Prune removes the layer from the filesystem if it is not referenced by any manifest.
+func (l *Layer) Prune() error {
 	if l.Digest == "" {
 		return nil
 	}
@@ -125,5 +129,41 @@ func (l *Layer) Remove() error {
 		return err
 	}
 
+	slog.Debug("pruning layer", "digest", l.Digest)
 	return os.Remove(blob)
 }
+
+func Layers() (map[string]Layer, error) {
+	blobs, err := GetBlobsPath("")
+	if err != nil {
+		return nil, err
+	}
+
+	// TODO(mxyng): use something less brittle
+	matches, err := filepath.Glob(filepath.Join(blobs, "*"))
+	if err != nil {
+		return nil, err
+	}
+
+	layers := make(map[string]Layer)
+	for _, match := range matches {
+		rel, err := filepath.Rel(blobs, match)
+		if err != nil {
+			slog.Warn("bad filepath", "path", match, "error", err)
+			continue
+		}
+
+		// TODO(mxyng): this should ideally use model.Digest but
+		// that's currently incompatible with the manifest digest
+		digest := strings.Replace(rel, "sha256-", "sha256:", 1)
+		layer, err := NewLayerFromLayer(digest, "", "")
+		if err != nil {
+			slog.Warn("bad blob", "digest", digest, "error", err)
+			layer = Layer{Digest: rel}
+		}
+
+		layers[digest] = layer
+	}
+
+	return layers, nil
+}

+ 37 - 2
server/manifest.go

@@ -43,13 +43,13 @@ func (m *Manifest) Remove() error {
 		return err
 	}
 
-	return PruneDirectory(manifests)
+	return pruneEmptyDirectory(manifests)
 }
 
 func (m *Manifest) RemoveLayers() error {
 	for _, layer := range append(m.Layers, m.Config) {
 		if layer.Digest != "" {
-			if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
+			if err := layer.Prune(); errors.Is(err, os.ErrNotExist) {
 				slog.Debug("layer does not exist", "digest", layer.Digest)
 			} else if err != nil {
 				return err
@@ -169,3 +169,38 @@ func Manifests() (map[model.Name]*Manifest, error) {
 
 	return ms, nil
 }
+
+func pruneEmptyDirectory(p string) error {
+	fi, err := os.Lstat(p)
+	if err != nil {
+		return err
+	}
+
+	if fi.Mode()&os.ModeSymlink == 0 {
+		entries, err := os.ReadDir(p)
+		if err != nil {
+			return err
+		}
+
+		for _, entry := range entries {
+			if entry.IsDir() {
+				if err := pruneEmptyDirectory(filepath.Join(p, entry.Name())); err != nil {
+					return err
+				}
+			}
+		}
+
+		entries, err = os.ReadDir(p)
+		if err != nil {
+			return err
+		}
+
+		if len(entries) == 0 {
+			if err := os.Remove(p); err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}

+ 5 - 8
server/routes.go

@@ -1131,18 +1131,15 @@ func Serve(ln net.Listener) error {
 	}
 
 	if !envconfig.NoPrune() {
-		// clean up unused layers and manifests
-		if err := PruneLayers(); err != nil {
-			return err
-		}
-
-		manifestsPath, err := GetManifestPath()
+		layers, err := Layers()
 		if err != nil {
 			return err
 		}
 
-		if err := PruneDirectory(manifestsPath); err != nil {
-			return err
+		for _, layer := range layers {
+			if err := layer.Prune(); err != nil {
+				return err
+			}
 		}
 	}
 

+ 86 - 0
server/routes_test.go

@@ -5,16 +5,21 @@ import (
 	"context"
 	"encoding/binary"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"math"
+	"net"
 	"net/http"
 	"net/http/httptest"
 	"os"
+	"path/filepath"
 	"sort"
 	"strings"
 	"testing"
+	"time"
 
+	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
@@ -452,3 +457,84 @@ func TestNormalize(t *testing.T) {
 		})
 	}
 }
+
+func TestServe(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	// seed some models
+	createRequest(t, s.CreateHandler, api.CreateRequest{
+		Name:      "test-model",
+		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
+	})
+
+	createRequest(t, s.CreateHandler, api.CreateRequest{
+		Name:      "test-model-2",
+		Modelfile: "FROM test-model\nSYSTEM You are a good robot.",
+	})
+
+	createRequest(t, s.CreateHandler, api.CreateRequest{
+		Name:      "test-model-3",
+		Modelfile: "FROM test-model\nSYSTEM You are a bad robot.",
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-1c515c46e60f849c6aeffa86e256508ac450464762a31ca08648e418f07c9819"),
+		filepath.Join(p, "blobs", "sha256-461fd034bb72312965d46160399b1b882c6a2f8c7305237ed7dd65f848fba10c"),
+		filepath.Join(p, "blobs", "sha256-66e9776a5bb7e5f6093681aa8ba01a7a6b6ae1dd697281f11fa714eaa948a6a4"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-b3a5b5b438604c5103ba403a5455af94ea98494b5bbc177f4665716a37b99c1e"),
+		filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
+	})
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ln.Close()
+
+	//nolint:errcheck
+	go Serve(ln)
+
+	// wait for server to be healthy (GET / => 200)
+	ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
+	defer cancel()
+
+	if err := func() error {
+		tick := time.NewTicker(20 * time.Millisecond)
+		defer tick.Stop()
+
+		for {
+			select {
+			case <-ctx.Done():
+				return errors.New("server did not become healthy")
+			case <-tick.C:
+				r, err := http.Get(fmt.Sprintf("http://%s", ln.Addr()))
+				if err != nil {
+					continue
+				}
+
+				if err := r.Body.Close(); err != nil {
+					return err
+				}
+
+				if r.StatusCode == http.StatusOK {
+					return nil
+				}
+			}
+		}
+	}(); err != nil {
+		t.Fatal(err)
+	}
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-1c515c46e60f849c6aeffa86e256508ac450464762a31ca08648e418f07c9819"),
+		filepath.Join(p, "blobs", "sha256-461fd034bb72312965d46160399b1b882c6a2f8c7305237ed7dd65f848fba10c"),
+		filepath.Join(p, "blobs", "sha256-66e9776a5bb7e5f6093681aa8ba01a7a6b6ae1dd697281f11fa714eaa948a6a4"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-b3a5b5b438604c5103ba403a5455af94ea98494b5bbc177f4665716a37b99c1e"),
+		filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
+	})
+}