Преглед изворни кода

add `--upgrade-all` flag to refresh any stale models

Patrick Devine пре 1 година
родитељ
комит
021b1bdc4a
5 измењених фајлова са 99 додато и 14 уклоњено
  1. 7 5
      api/types.go
  2. 61 3
      cmd/cmd.go
  3. 4 0
      progress/progress.go
  4. 25 5
      server/images.go
  5. 2 1
      server/routes.go

+ 7 - 5
api/types.go

@@ -183,11 +183,12 @@ type CopyRequest struct {
 }
 
 type PullRequest struct {
-	Model    string `json:"model"`
-	Insecure bool   `json:"insecure,omitempty"`
-	Username string `json:"username"`
-	Password string `json:"password"`
-	Stream   *bool  `json:"stream,omitempty"`
+	Model         string `json:"model"`
+	Insecure      bool   `json:"insecure,omitempty"`
+	Username      string `json:"username"`
+	Password      string `json:"password"`
+	Stream        *bool  `json:"stream,omitempty"`
+	CurrentDigest string `json:"current_digest,omitempty"`
 
 	// Name is deprecated, see Model
 	Name string `json:"name"`
@@ -241,6 +242,7 @@ type GenerateResponse struct {
 
 type ModelDetails struct {
 	ParentModel       string   `json:"parent_model"`
+	Digest            string   `json:"digest"`
 	Format            string   `json:"format"`
 	Family            string   `json:"family"`
 	Families          []string `json:"families"`

+ 61 - 3
cmd/cmd.go

@@ -11,6 +11,7 @@ import (
 	"fmt"
 	"io"
 	"log"
+	"log/slog"
 	"net"
 	"net/http"
 	"os"
@@ -357,6 +358,62 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
 }
 
 func PullHandler(cmd *cobra.Command, args []string) error {
+	upgradeAll, err := cmd.Flags().GetBool("upgrade-all")
+	if err != nil {
+		return err
+	}
+
+	if !upgradeAll {
+		if len(args) == 0 {
+			return fmt.Errorf("no model specified to pull")
+		}
+		return pull(cmd, args[0], "")
+	}
+
+	fp, err := server.GetManifestPath()
+	if err != nil {
+		return err
+	}
+
+	type modelInfo struct {
+		Name   string
+		Digest string
+	}
+
+	var modelList []modelInfo
+
+	walkFunc := func(path string, info os.FileInfo, _ error) error {
+		if info.IsDir() {
+			return nil
+		}
+
+		dir, file := filepath.Split(path)
+		dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
+		tag := strings.Join([]string{dir, file}, ":")
+
+		model, err := server.GetModel(tag)
+		if err != nil {
+			return nil
+		}
+
+		modelList = append(modelList, modelInfo{tag, "sha256:" + model.Digest})
+		return nil
+	}
+
+	if err = filepath.Walk(fp, walkFunc); err != nil {
+		return err
+	}
+
+	for _, m := range modelList {
+		err = pull(cmd, m.Name, m.Digest)
+		if err != nil {
+			slog.Warn(fmt.Sprintf("couldn't pull model '%s'", m.Name))
+		}
+	}
+	return nil
+}
+
+func pull(cmd *cobra.Command, name string, currentDigest string) error {
 	insecure, err := cmd.Flags().GetBool("insecure")
 	if err != nil {
 		return err
@@ -368,7 +425,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 	}
 
 	p := progress.NewProgress(os.Stderr)
-	defer p.Stop()
+	defer p.StopWithoutClear()
 
 	bars := make(map[string]*progress.Bar)
 
@@ -402,7 +459,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 		return nil
 	}
 
-	request := api.PullRequest{Name: args[0], Insecure: insecure}
+	request := api.PullRequest{Name: name, Insecure: insecure, CurrentDigest: currentDigest}
 	if err := client.Pull(cmd.Context(), &request, fn); err != nil {
 		return err
 	}
@@ -884,12 +941,13 @@ func NewCLI() *cobra.Command {
 	pullCmd := &cobra.Command{
 		Use:     "pull MODEL",
 		Short:   "Pull a model from a registry",
-		Args:    cobra.ExactArgs(1),
+		Args:    cobra.RangeArgs(0, 1),
 		PreRunE: checkServerHeartbeat,
 		RunE:    PullHandler,
 	}
 
 	pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
+	pullCmd.Flags().Bool("upgrade-all", false, "Upgrade all models if they're out of date")
 
 	pushCmd := &cobra.Command{
 		Use:     "push MODEL",

+ 4 - 0
progress/progress.go

@@ -52,6 +52,10 @@ func (p *Progress) Stop() bool {
 	return stopped
 }
 
+func (p *Progress) StopWithoutClear() bool {
+	return p.stop()
+}
+
 func (p *Progress) StopAndClear() bool {
 	fmt.Fprint(p.w, "\033[?25l")
 	defer fmt.Fprint(p.w, "\033[?25h")

+ 25 - 5
server/images.go

@@ -471,7 +471,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 				switch {
 				case errors.Is(err, os.ErrNotExist):
 					fn(api.ProgressResponse{Status: "pulling model"})
-					if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
+					if err := PullModel(ctx, c.Args, "", &RegistryOptions{}, fn); err != nil {
 						return err
 					}
 
@@ -1041,7 +1041,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 	return nil
 }
 
-func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
+func PullModel(ctx context.Context, name, currentDigest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 
 	var manifest *ManifestV2
@@ -1069,13 +1069,23 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 		return fmt.Errorf("insecure protocol http")
 	}
 
-	fn(api.ProgressResponse{Status: "pulling manifest"})
+	if currentDigest == "" {
+		fn(api.ProgressResponse{Status: "pulling manifest"})
+	}
 
-	manifest, err = pullModelManifest(ctx, mp, regOpts)
+	manifest, err = pullModelManifest(ctx, mp, currentDigest, regOpts)
 	if err != nil {
 		return fmt.Errorf("pull model manifest: %s", err)
 	}
 
+	if currentDigest != "" {
+		if manifest == nil {
+			// we already have the model
+			return nil
+		}
+		fn(api.ProgressResponse{Status: "upgrading " + mp.GetShortTagname()})
+	}
+
 	var layers []*Layer
 	layers = append(layers, manifest.Layers...)
 	layers = append(layers, manifest.Config)
@@ -1147,17 +1157,27 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 	return nil
 }
 
-func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
+func pullModelManifest(ctx context.Context, mp ModelPath, currentDigest string, regOpts *RegistryOptions) (*ManifestV2, error) {
 	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
 
 	headers := make(http.Header)
 	headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
+
+	if currentDigest != "" {
+		headers.Set("If-None-Match", currentDigest)
+	}
+
 	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
 	if err != nil {
 		return nil, err
 	}
 	defer resp.Body.Close()
 
+	// todo we can potentially read the manifest locally and return it here
+	if resp.StatusCode == http.StatusNotModified {
+		return nil, nil
+	}
+
 	var m *ManifestV2
 	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
 		return nil, err

+ 2 - 1
server/routes.go

@@ -451,7 +451,7 @@ func PullModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := PullModel(ctx, model, regOpts, fn); err != nil {
+		if err := PullModel(ctx, model, req.CurrentDigest, regOpts, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -673,6 +673,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 
 	modelDetails := api.ModelDetails{
 		ParentModel:       model.ParentModel,
+		Digest:            "sha256:" + model.Digest,
 		Format:            model.Config.ModelFormat,
 		Family:            model.Config.ModelFamily,
 		Families:          model.Config.ModelFamilies,