Browse Source

download models when creating from modelfile

Bruce MacDonald 1 year ago
parent
commit
4c1caa3733
5 changed files with 73 additions and 41 deletions
  1. 2 2
      api/client.go
  2. 0 4
      api/types.go
  3. 24 7
      cmd/cmd.go
  4. 45 24
      server/images.go
  5. 2 4
      server/routes.go

+ 2 - 2
api/client.go

@@ -189,11 +189,11 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc
 	})
 }
 
-type CreateProgressFunc func(CreateProgress) error
+type CreateProgressFunc func(ProgressResponse) error
 
 func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
-		var resp CreateProgress
+		var resp ProgressResponse
 		if err := json.Unmarshal(bts, &resp); err != nil {
 			return err
 		}

+ 0 - 4
api/types.go

@@ -40,10 +40,6 @@ type CreateRequest struct {
 	Path string `json:"path"`
 }
 
-type CreateProgress struct {
-	Status string `json:"status"`
-}
-
 type DeleteRequest struct {
 	Name string `json:"name"`
 }

+ 24 - 7
cmd/cmd.go

@@ -36,15 +36,32 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 
 	var spinner *Spinner
 
-	request := api.CreateRequest{Name: args[0], Path: filename}
-	fn := func(resp api.CreateProgress) error {
-		if spinner != nil {
-			spinner.Stop()
-		}
+	var currentDigest string
+	var bar *progressbar.ProgressBar
 
-		spinner = NewSpinner(resp.Status)
-		go spinner.Spin(100 * time.Millisecond)
+	request := api.CreateRequest{Name: args[0], Path: filename}
+	fn := func(resp api.ProgressResponse) error {
+		if resp.Digest != currentDigest && resp.Digest != "" {
+			if spinner != nil {
+				spinner.Stop()
+			}
+			currentDigest = resp.Digest
+			bar = progressbar.DefaultBytes(
+				int64(resp.Total),
+				fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
+			)
 
+			bar.Set(resp.Completed)
+		} else if resp.Digest == currentDigest && resp.Digest != "" {
+			bar.Set(resp.Completed)
+		} else {
+			currentDigest = ""
+			if spinner != nil {
+				spinner.Stop()
+			}
+			spinner = NewSpinner(resp.Status)
+			go spinner.Spin(100 * time.Millisecond)
+		}
 		return nil
 	}
 

+ 45 - 24
server/images.go

@@ -187,15 +187,15 @@ func GetModel(name string) (*Model, error) {
 	return model, nil
 }
 
-func CreateModel(name string, path string, fn func(status string)) error {
+func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
 	mf, err := os.Open(path)
 	if err != nil {
-		fn(fmt.Sprintf("couldn't open modelfile '%s'", path))
+		fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
 		return fmt.Errorf("failed to open file: %w", err)
 	}
 	defer mf.Close()
 
-	fn("parsing modelfile")
+	fn(api.ProgressResponse{Status: "parsing modelfile"})
 	commands, err := parser.Parse(mf)
 	if err != nil {
 		return err
@@ -208,7 +208,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
 		log.Printf("[%s] - %s\n", c.Name, c.Args)
 		switch c.Name {
 		case "model":
-			fn("looking for model")
+			fn(api.ProgressResponse{Status: "looking for model"})
 			mf, err := GetManifest(ParseModelPath(c.Args))
 			if err != nil {
 				fp := c.Args
@@ -229,20 +229,40 @@ func CreateModel(name string, path string, fn func(status string)) error {
 					fp = filepath.Join(filepath.Dir(path), fp)
 				}
 
-				fn("creating model layer")
-				file, err := os.Open(fp)
-				if err != nil {
-					return fmt.Errorf("failed to open file: %v", err)
-				}
-				defer file.Close()
+				if _, err := os.Stat(fp); err != nil {
+					// the model file does not exist, try pulling it
+					if errors.Is(err, os.ErrNotExist) {
+						fn(api.ProgressResponse{Status: "pulling model file"})
+						if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
+							return err
+						}
+						mf, err = GetManifest(ParseModelPath(c.Args))
+						if err != nil {
+							return fmt.Errorf("failed to open file after pull: %v", err)
+						}
+
+					} else {
+						return err
+					}
+				} else {
+					// create a model from this specified file
+					fn(api.ProgressResponse{Status: "creating model layer"})
 
-				l, err := CreateLayer(file)
-				if err != nil {
-					return fmt.Errorf("failed to create layer: %v", err)
+					file, err := os.Open(fp)
+					if err != nil {
+						return fmt.Errorf("failed to open file: %v", err)
+					}
+					defer file.Close()
+
+					l, err := CreateLayer(file)
+					if err != nil {
+						return fmt.Errorf("failed to create layer: %v", err)
+					}
+					l.MediaType = "application/vnd.ollama.image.model"
+					layers = append(layers, l)
 				}
-				l.MediaType = "application/vnd.ollama.image.model"
-				layers = append(layers, l)
-			} else {
+			}
+			if mf != nil {
 				log.Printf("manifest = %#v", mf)
 				for _, l := range mf.Layers {
 					newLayer, err := GetLayerWithBufferFromLayer(l)
@@ -253,7 +273,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
 				}
 			}
 		case "license", "template", "system", "prompt":
-			fn(fmt.Sprintf("creating %s layer", c.Name))
+			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
 			// remove the prompt layer if one exists
 			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
 			layers = removeLayerFromLayers(layers, mediaType)
@@ -272,7 +292,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
 
 	// Create a single layer for the parameters
 	if len(params) > 0 {
-		fn("creating parameter layer")
+		fn(api.ProgressResponse{Status: "creating parameter layer"})
 		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
 		paramData, err := paramsToReader(params)
 		if err != nil {
@@ -297,7 +317,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
 	}
 
 	// Create a layer for the config object
-	fn("creating config layer")
+	fn(api.ProgressResponse{Status: "creating config layer"})
 	cfg, err := createConfigLayer(digests)
 	if err != nil {
 		return err
@@ -310,13 +330,13 @@ func CreateModel(name string, path string, fn func(status string)) error {
 	}
 
 	// Create the manifest
-	fn("writing manifest")
+	fn(api.ProgressResponse{Status: "writing manifest"})
 	err = CreateManifest(name, cfg, manifestLayers)
 	if err != nil {
 		return err
 	}
 
-	fn("success")
+	fn(api.ProgressResponse{Status: "success"})
 	return nil
 }
 
@@ -331,7 +351,7 @@ func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerRead
 	return layers[:j]
 }
 
-func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error {
+func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
 	// Write each of the layers to disk
 	for _, layer := range layers {
 		fp, err := GetBlobsPath(layer.Digest)
@@ -341,7 +361,8 @@ func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error
 
 		_, err = os.Stat(fp)
 		if os.IsNotExist(err) || force {
-			fn(fmt.Sprintf("writing layer %s", layer.Digest))
+			fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
+
 			out, err := os.Create(fp)
 			if err != nil {
 				log.Printf("couldn't create %s", fp)
@@ -354,7 +375,7 @@ func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error
 			}
 
 		} else {
-			fn(fmt.Sprintf("using already created layer %s", layer.Digest))
+			fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
 		}
 	}
 

+ 2 - 4
server/routes.go

@@ -147,10 +147,8 @@ func CreateModelHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		fn := func(status string) {
-			ch <- api.CreateProgress{
-				Status: status,
-			}
+		fn := func(resp api.ProgressResponse) {
+			ch <- resp
 		}
 
 		if err := CreateModel(req.Name, req.Path, fn); err != nil {