Przeglądaj źródła

quantize percentage

Josh Yan 9 miesięcy temu
rodzic
commit
e87eafe5cd
3 zmienionych plików z 21 dodań i 3 usunięć
  1. 1 0
      api/types.go
  2. 8 0
      cmd/cmd.go
  3. 12 3
      server/images.go

+ 1 - 0
api/types.go

@@ -267,6 +267,7 @@ type PullRequest struct {
 type ProgressResponse struct {
 	Status    string `json:"status"`
 	Digest    string `json:"digest,omitempty"`
+	Quantize  string `json:"quantize,omitempty"`
 	Total     int64  `json:"total,omitempty"`
 	Completed int64  `json:"completed,omitempty"`
 }

+ 8 - 0
cmd/cmd.go

@@ -125,6 +125,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	}
 
 	bars := make(map[string]*progress.Bar)
+	var quantizeSpin *progress.Spinner
 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != "" {
 			spinner.Stop()
@@ -137,6 +138,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			}
 
 			bar.Set(resp.Completed)
+		} else if resp.Quantize != "" {
+			if quantizeSpin != nil {
+				quantizeSpin.SetMessage(resp.Status)
+			} else {
+				quantizeSpin = progress.NewSpinner(resp.Status)
+				p.Add("quantize", quantizeSpin)
+			}
 		} else if status != resp.Status {
 			spinner.Stop()
 

+ 12 - 3
server/images.go

@@ -413,7 +413,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 				return fmt.Errorf("invalid model reference: %s", c.Args)
 			}
 
-			for _, baseLayer := range baseLayers {
+			layerCount := len(baseLayers)
+			for i, baseLayer := range baseLayers {
 				if quantization != "" &&
 					baseLayer.MediaType == "application/vnd.ollama.image.model" &&
 					baseLayer.GGML != nil &&
@@ -427,8 +428,6 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 					if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
 						return errors.New("quantization is only supported for F16 and F32 models")
 					} else if want != ft {
-						fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})
-
 						blob, err := GetBlobsPath(baseLayer.Digest)
 						if err != nil {
 							return err
@@ -472,8 +471,18 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 					config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
 				}
 
+				fn(api.ProgressResponse{
+					Status:   fmt.Sprintf("quantizing model %d%%", i*100/layerCount),
+					Quantize: quantization,
+				})
+
 				layers = append(layers, baseLayer.Layer)
 			}
+
+			fn(api.ProgressResponse{
+				Status:   fmt.Sprintf("quantizing model %d%%", 100),
+				Quantize: quantization,
+			})
 		case "license", "template", "system":
 			if c.Name != "license" {
 				// replace