Josh Yan 9 months ago
parent
commit
c63b4ecbf7
3 changed files with 48 additions and 8 deletions
  1. 1 1
      llm/llama.cpp
  2. 22 1
      llm/llm.go
  3. 25 6
      server/images.go

+ 1 - 1
llm/llama.cpp

@@ -1 +1 @@
-Subproject commit a8db2a9ce64cd4417f6a312ab61858f17f0f8584
+Subproject commit 7c26775adb579e92b59c82e8084c07a1d0f75e9c

+ 22 - 1
llm/llm.go

@@ -10,6 +10,10 @@ package llm
 // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
 // #include <stdlib.h>
 // #include "llama.h"
+// bool update_quantize_progress(int progress, void* data) {
+// 	*((int*)data) = progress;
+// 	return true;
+// }
 import "C"
 import (
 	"fmt"
@@ -21,7 +25,7 @@ func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }
 
-func Quantize(infile, outfile string, ftype fileType) error {
+func Quantize(infile, outfile string, ftype fileType, count *int) error {
 	cinfile := C.CString(infile)
 	defer C.free(unsafe.Pointer(cinfile))
 
@@ -32,6 +36,23 @@ func Quantize(infile, outfile string, ftype fileType) error {
 	params.nthread = -1
 	params.ftype = ftype.Value()
 
+	// Initialize "global" to store progress
+	store := C.malloc(C.sizeof(int))
+
+	params.quantize_callback_data = store
+	params.quantize_callback = C.update_quantize_progress
+
+	go func () {
+		for {
+			time.Sleep(60 * time.Millisecond)
+			if params.quantize_callback_data == nil {
+				return
+			} else {
+				*count = int(*(*C.int)(store))
+			}
+		}
+	}()
+
 	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
 		return fmt.Errorf("llama_model_quantize: %d", rc)
 	}

+ 25 - 6
server/images.go

@@ -21,6 +21,7 @@ import (
 	"slices"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
@@ -413,6 +414,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 				return fmt.Errorf("invalid model reference: %s", c.Args)
 			}
 
+			var quantized int
+			tensorCount := 0
 			for _, baseLayer := range baseLayers {
 				if quantization != "" &&
 					baseLayer.MediaType == "application/vnd.ollama.image.model" &&
@@ -423,11 +426,27 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 						return err
 					}
 
-					tensorCount := len(baseLayer.GGML.Tensors())
-					fn(api.ProgressResponse{
-						Status:   fmt.Sprintf("quantizing model %d tensors", tensorCount),
-						Quantize: quantization,
-					})
+					tensorCount = len(baseLayer.GGML.Tensors())
+					ticker := time.NewTicker(60 * time.Millisecond)
+					done := make(chan struct{})
+					defer close(done)
+
+					go func() {
+						defer ticker.Stop()
+						for {
+							select {
+							case <-ticker.C:
+								fn(api.ProgressResponse{
+									Status: fmt.Sprintf("quantizing model %d%%", quantized*100/tensorCount),
+									Quantize: quantization})
+							case <-done:
+								fn(api.ProgressResponse{
+									Status: "quantizing model",
+									Quantize: quantization})
+								}
+								return
+							}
+					}()
 
 					ft := baseLayer.GGML.KV().FileType()
 					if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
@@ -447,7 +466,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 
 						// Quantizes per layer
 						// Save total quantized tensors
-						if err := llm.Quantize(blob, temp.Name(), want); err != nil {
+						if err := llm.Quantize(blob, temp.Name(), want, &quantized); err != nil {
 							return err
 						}