Browse Source

Merge pull request #5243 from dhiltgen/modelfile_use_mmap

Fix use_mmap for modefiles
Daniel Hiltgen 10 months ago
parent
commit
ccd7785859
3 changed files with 63 additions and 93 deletions
  1. 35 70
      api/types.go
  2. 22 18
      api/types_test.go
  3. 6 5
      llm/server.go

+ 35 - 70
api/types.go

@@ -159,49 +159,18 @@ type Options struct {
 
 // Runner options which must be set when the model is loaded into memory
 type Runner struct {
-	UseNUMA   bool     `json:"numa,omitempty"`
-	NumCtx    int      `json:"num_ctx,omitempty"`
-	NumBatch  int      `json:"num_batch,omitempty"`
-	NumGPU    int      `json:"num_gpu,omitempty"`
-	MainGPU   int      `json:"main_gpu,omitempty"`
-	LowVRAM   bool     `json:"low_vram,omitempty"`
-	F16KV     bool     `json:"f16_kv,omitempty"`
-	LogitsAll bool     `json:"logits_all,omitempty"`
-	VocabOnly bool     `json:"vocab_only,omitempty"`
-	UseMMap   TriState `json:"use_mmap,omitempty"`
-	UseMLock  bool     `json:"use_mlock,omitempty"`
-	NumThread int      `json:"num_thread,omitempty"`
-}
-
-type TriState int
-
-const (
-	TriStateUndefined TriState = -1
-	TriStateFalse     TriState = 0
-	TriStateTrue      TriState = 1
-)
-
-func (b *TriState) UnmarshalJSON(data []byte) error {
-	var v bool
-	if err := json.Unmarshal(data, &v); err != nil {
-		return err
-	}
-	if v {
-		*b = TriStateTrue
-	}
-	*b = TriStateFalse
-	return nil
-}
-
-func (b *TriState) MarshalJSON() ([]byte, error) {
-	if *b == TriStateUndefined {
-		return nil, nil
-	}
-	var v bool
-	if *b == TriStateTrue {
-		v = true
-	}
-	return json.Marshal(v)
+	UseNUMA   bool  `json:"numa,omitempty"`
+	NumCtx    int   `json:"num_ctx,omitempty"`
+	NumBatch  int   `json:"num_batch,omitempty"`
+	NumGPU    int   `json:"num_gpu,omitempty"`
+	MainGPU   int   `json:"main_gpu,omitempty"`
+	LowVRAM   bool  `json:"low_vram,omitempty"`
+	F16KV     bool  `json:"f16_kv,omitempty"`
+	LogitsAll bool  `json:"logits_all,omitempty"`
+	VocabOnly bool  `json:"vocab_only,omitempty"`
+	UseMMap   *bool `json:"use_mmap,omitempty"`
+	UseMLock  bool  `json:"use_mlock,omitempty"`
+	NumThread int   `json:"num_thread,omitempty"`
 }
 
 // EmbeddingRequest is the request passed to [Client.Embeddings].
@@ -444,19 +413,6 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 				continue
 			}
 
-			if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
-				val, ok := val.(bool)
-				if !ok {
-					return fmt.Errorf("option %q must be of type boolean", key)
-				}
-				if val {
-					field.SetInt(int64(TriStateTrue))
-				} else {
-					field.SetInt(int64(TriStateFalse))
-				}
-				continue
-			}
-
 			switch field.Kind() {
 			case reflect.Int:
 				switch t := val.(type) {
@@ -503,6 +459,17 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 					slice[i] = str
 				}
 				field.Set(reflect.ValueOf(slice))
+			case reflect.Pointer:
+				var b bool
+				if field.Type() == reflect.TypeOf(&b) {
+					val, ok := val.(bool)
+					if !ok {
+						return fmt.Errorf("option %q must be of type boolean", key)
+					}
+					field.Set(reflect.ValueOf(&val))
+				} else {
+					return fmt.Errorf("unknown type loading config params: %v %v", field.Kind(), field.Type())
+				}
 			default:
 				return fmt.Errorf("unknown type loading config params: %v", field.Kind())
 			}
@@ -545,7 +512,7 @@ func DefaultOptions() Options {
 			LowVRAM:   false,
 			F16KV:     true,
 			UseMLock:  false,
-			UseMMap:   TriStateUndefined,
+			UseMMap:   nil,
 			UseNUMA:   false,
 		},
 	}
@@ -615,19 +582,6 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
 		} else {
 			field := valueOpts.FieldByName(opt.Name)
 			if field.IsValid() && field.CanSet() {
-				if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
-					boolVal, err := strconv.ParseBool(vals[0])
-					if err != nil {
-						return nil, fmt.Errorf("invalid bool value %s", vals)
-					}
-					if boolVal {
-						out[key] = TriStateTrue
-					} else {
-						out[key] = TriStateFalse
-					}
-					continue
-				}
-
 				switch field.Kind() {
 				case reflect.Float32:
 					floatVal, err := strconv.ParseFloat(vals[0], 32)
@@ -655,6 +609,17 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
 				case reflect.Slice:
 					// TODO: only string slices are supported right now
 					out[key] = vals
+				case reflect.Pointer:
+					var b bool
+					if field.Type() == reflect.TypeOf(&b) {
+						boolVal, err := strconv.ParseBool(vals[0])
+						if err != nil {
+							return nil, fmt.Errorf("invalid bool value %s", vals)
+						}
+						out[key] = &boolVal
+					} else {
+						return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
+					}
 				default:
 					return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
 				}

+ 22 - 18
api/types_test.go

@@ -108,25 +108,27 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
 }
 
 func TestUseMmapParsingFromJSON(t *testing.T) {
+	tr := true
+	fa := false
 	tests := []struct {
 		name string
 		req  string
-		exp  TriState
+		exp  *bool
 	}{
 		{
 			name: "Undefined",
 			req:  `{ }`,
-			exp:  TriStateUndefined,
+			exp:  nil,
 		},
 		{
 			name: "True",
 			req:  `{ "use_mmap": true }`,
-			exp:  TriStateTrue,
+			exp:  &tr,
 		},
 		{
 			name: "False",
 			req:  `{ "use_mmap": false }`,
-			exp:  TriStateFalse,
+			exp:  &fa,
 		},
 	}
 
@@ -144,50 +146,52 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
 }
 
 func TestUseMmapFormatParams(t *testing.T) {
+	tr := true
+	fa := false
 	tests := []struct {
 		name string
 		req  map[string][]string
-		exp  TriState
+		exp  *bool
 		err  error
 	}{
 		{
 			name: "True",
 			req: map[string][]string{
-				"use_mmap": []string{"true"},
+				"use_mmap": {"true"},
 			},
-			exp: TriStateTrue,
+			exp: &tr,
 			err: nil,
 		},
 		{
 			name: "False",
 			req: map[string][]string{
-				"use_mmap": []string{"false"},
+				"use_mmap": {"false"},
 			},
-			exp: TriStateFalse,
+			exp: &fa,
 			err: nil,
 		},
 		{
 			name: "Numeric True",
 			req: map[string][]string{
-				"use_mmap": []string{"1"},
+				"use_mmap": {"1"},
 			},
-			exp: TriStateTrue,
+			exp: &tr,
 			err: nil,
 		},
 		{
 			name: "Numeric False",
 			req: map[string][]string{
-				"use_mmap": []string{"0"},
+				"use_mmap": {"0"},
 			},
-			exp: TriStateFalse,
+			exp: &fa,
 			err: nil,
 		},
 		{
 			name: "invalid string",
 			req: map[string][]string{
-				"use_mmap": []string{"foo"},
+				"use_mmap": {"foo"},
 			},
-			exp: TriStateUndefined,
+			exp: nil,
 			err: fmt.Errorf("invalid bool value [foo]"),
 		},
 	}
@@ -195,11 +199,11 @@ func TestUseMmapFormatParams(t *testing.T) {
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			resp, err := FormatParams(test.req)
-			require.Equal(t, err, test.err)
+			require.Equal(t, test.err, err)
 			respVal, ok := resp["use_mmap"]
-			if test.exp != TriStateUndefined {
+			if test.exp != nil {
 				assert.True(t, ok, "resp: %v", resp)
-				assert.Equal(t, test.exp, respVal)
+				assert.Equal(t, *test.exp, *respVal.(*bool))
 			}
 		})
 	}

+ 6 - 5
llm/server.go

@@ -221,7 +221,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		if g.Library == "metal" &&
 			uint64(opts.NumGPU) > 0 &&
 			uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
-			opts.UseMMap = api.TriStateFalse
+			opts.UseMMap = new(bool)
+			*opts.UseMMap = false
 		}
 	}
 
@@ -232,10 +233,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	// Windows CUDA should not use mmap for best performance
 	// Linux  with a model larger than free space, mmap leads to thrashing
 	// For CPU loads we want the memory to be allocated, not FS cache
-	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == api.TriStateUndefined) ||
-		(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == api.TriStateUndefined) ||
-		(gpus[0].Library == "cpu" && opts.UseMMap == api.TriStateUndefined) ||
-		opts.UseMMap == api.TriStateFalse {
+	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) ||
+		(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) ||
+		(gpus[0].Library == "cpu" && opts.UseMMap == nil) ||
+		(opts.UseMMap != nil && !*opts.UseMMap) {
 		params = append(params, "--no-mmap")
 	}