Browse Source

cgo quantize

Michael Yang 1 year ago
parent
commit
9502e5661f
6 changed files with 126 additions and 32 deletions
  1. 28 27
      api/types.go
  2. 4 1
      cmd/cmd.go
  3. 71 0
      llm/llm.go
  4. 21 2
      server/images.go
  5. 1 1
      server/routes.go
  6. 1 1
      server/routes_test.go

+ 28 - 27
api/types.go

@@ -109,19 +109,19 @@ type Options struct {
 
 // Runner options which must be set when the model is loaded into memory
 type Runner struct {
-	UseNUMA            bool    `json:"numa,omitempty"`
-	NumCtx             int     `json:"num_ctx,omitempty"`
-	NumBatch           int     `json:"num_batch,omitempty"`
-	NumGQA             int     `json:"num_gqa,omitempty"`
-	NumGPU             int     `json:"num_gpu,omitempty"`
-	MainGPU            int     `json:"main_gpu,omitempty"`
-	LowVRAM            bool    `json:"low_vram,omitempty"`
-	F16KV              bool    `json:"f16_kv,omitempty"`
-	LogitsAll          bool    `json:"logits_all,omitempty"`
-	VocabOnly          bool    `json:"vocab_only,omitempty"`
-	UseMMap            bool    `json:"use_mmap,omitempty"`
-	UseMLock           bool    `json:"use_mlock,omitempty"`
-	NumThread          int     `json:"num_thread,omitempty"`
+	UseNUMA   bool `json:"numa,omitempty"`
+	NumCtx    int  `json:"num_ctx,omitempty"`
+	NumBatch  int  `json:"num_batch,omitempty"`
+	NumGQA    int  `json:"num_gqa,omitempty"`
+	NumGPU    int  `json:"num_gpu,omitempty"`
+	MainGPU   int  `json:"main_gpu,omitempty"`
+	LowVRAM   bool `json:"low_vram,omitempty"`
+	F16KV     bool `json:"f16_kv,omitempty"`
+	LogitsAll bool `json:"logits_all,omitempty"`
+	VocabOnly bool `json:"vocab_only,omitempty"`
+	UseMMap   bool `json:"use_mmap,omitempty"`
+	UseMLock  bool `json:"use_mlock,omitempty"`
+	NumThread int  `json:"num_thread,omitempty"`
 }
 
 type EmbeddingRequest struct {
@@ -137,10 +137,11 @@ type EmbeddingResponse struct {
 }
 
 type CreateRequest struct {
-	Model     string `json:"model"`
-	Path      string `json:"path"`
-	Modelfile string `json:"modelfile"`
-	Stream    *bool  `json:"stream,omitempty"`
+	Model        string `json:"model"`
+	Path         string `json:"path"`
+	Modelfile    string `json:"modelfile"`
+	Stream       *bool  `json:"stream,omitempty"`
+	Quantization string `json:"quantization,omitempty"`
 
 	// Name is deprecated, see Model
 	Name string `json:"name"`
@@ -380,16 +381,16 @@ func DefaultOptions() Options {
 
 		Runner: Runner{
 			// options set when the model is loaded
-			NumCtx:             2048,
-			NumBatch:           512,
-			NumGPU:             -1, // -1 here indicates that NumGPU should be set dynamically
-			NumGQA:             1,
-			NumThread:          0, // let the runtime decide
-			LowVRAM:            false,
-			F16KV:              true,
-			UseMLock:           false,
-			UseMMap:            true,
-			UseNUMA:            false,
+			NumCtx:    2048,
+			NumBatch:  512,
+			NumGPU:    -1, // -1 here indicates that NumGPU should be set dynamically
+			NumGQA:    1,
+			NumThread: 0, // let the runtime decide
+			LowVRAM:   false,
+			F16KV:     true,
+			UseMLock:  false,
+			UseMMap:   true,
+			UseNUMA:   false,
 		},
 	}
 }

+ 4 - 1
cmd/cmd.go

@@ -194,7 +194,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return nil
 	}
 
-	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile)}
+	quantization, _ := cmd.Flags().GetString("quantization")
+
+	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
 	if err := client.Create(cmd.Context(), &request, fn); err != nil {
 		return err
 	}
@@ -943,6 +945,7 @@ func NewCLI() *cobra.Command {
 	}
 
 	createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
+	createCmd.Flags().StringP("quantization", "q", "", "Quantization level.")
 
 	showCmd := &cobra.Command{
 		Use:     "show MODEL",

+ 71 - 0
llm/llm.go

@@ -6,10 +6,81 @@ package llm
 // #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
 // #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
 // #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
+// #include <stdlib.h>
 // #include "llama.h"
 import "C"
+import (
+	"fmt"
+	"unsafe"
+)
 
 // SystemInfo is an unused example of calling llama.cpp functions using CGo
 func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }
+
+func Quantize(infile, outfile, filetype string) error {
+	cinfile := C.CString(infile)
+	defer C.free(unsafe.Pointer(cinfile))
+
+	coutfile := C.CString(outfile)
+	defer C.free(unsafe.Pointer(coutfile))
+
+	params := C.llama_model_quantize_default_params()
+	params.nthread = -1
+
+	switch filetype {
+	case "F32":
+		params.ftype = fileTypeF32
+	case "F16":
+		params.ftype = fileTypeF16
+	case "Q4_0":
+		params.ftype = fileTypeQ4_0
+	case "Q4_1":
+		params.ftype = fileTypeQ4_1
+	case "Q4_1_F16":
+		params.ftype = fileTypeQ4_1_F16
+	case "Q8_0":
+		params.ftype = fileTypeQ8_0
+	case "Q5_0":
+		params.ftype = fileTypeQ5_0
+	case "Q5_1":
+		params.ftype = fileTypeQ5_1
+	case "Q2_K":
+		params.ftype = fileTypeQ2_K
+	case "Q3_K_S":
+		params.ftype = fileTypeQ3_K_S
+	case "Q3_K_M":
+		params.ftype = fileTypeQ3_K_M
+	case "Q3_K_L":
+		params.ftype = fileTypeQ3_K_L
+	case "Q4_K_S":
+		params.ftype = fileTypeQ4_K_S
+	case "Q4_K_M":
+		params.ftype = fileTypeQ4_K_M
+	case "Q5_K_S":
+		params.ftype = fileTypeQ5_K_S
+	case "Q5_K_M":
+		params.ftype = fileTypeQ5_K_M
+	case "Q6_K":
+		params.ftype = fileTypeQ6_K
+	case "IQ2_XXS":
+		params.ftype = fileTypeIQ2_XXS
+	case "IQ2_XS":
+		params.ftype = fileTypeIQ2_XS
+	case "Q2_K_S":
+		params.ftype = fileTypeQ2_K_S
+	case "Q3_K_XS":
+		params.ftype = fileTypeQ3_K_XS
+	case "IQ3_XXS":
+		params.ftype = fileTypeIQ3_XXS
+	default:
+		return fmt.Errorf("unknown filetype: %s", filetype)
+	}
+
+	if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
+		return fmt.Errorf("llama_model_quantize: %d", retval)
+	}
+
+	return nil
+}

+ 21 - 2
server/images.go

@@ -284,7 +284,7 @@ func realpath(mfDir, from string) string {
 	return abspath
 }
 
-func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
+func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
 	deleteMap := make(map[string]struct{})
 	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
 		for _, layer := range append(manifest.Layers, manifest.Config) {
@@ -337,8 +337,27 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 
 			if ggufName != "" {
 				pathName = ggufName
-				slog.Debug(fmt.Sprintf("new image layer path: %s", pathName))
 				defer os.RemoveAll(ggufName)
+
+				if quantization != "" {
+					quantization = strings.ToUpper(quantization)
+					fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
+					tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
+					if err != nil {
+						return err
+					}
+					defer os.RemoveAll(tempfile.Name())
+
+					if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
+						return err
+					}
+
+					if err := tempfile.Close(); err != nil {
+						return err
+					}
+
+					pathName = tempfile.Name()
+				}
 			}
 
 			bin, err := os.Open(pathName)

+ 1 - 1
server/routes.go

@@ -647,7 +647,7 @@ func CreateModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
+		if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()

+ 1 - 1
server/routes_test.go

@@ -61,7 +61,7 @@ func Test_Routes(t *testing.T) {
 		fn := func(resp api.ProgressResponse) {
 			t.Logf("Status: %s", resp.Status)
 		}
-		err = CreateModel(context.TODO(), name, "", commands, fn)
+		err = CreateModel(context.TODO(), name, "", "", commands, fn)
 		assert.Nil(t, err)
 	}