Selaa lähdekoodia

Adjust mmap logic for cuda windows for faster model load

On Windows, recent llama.cpp changes make mmap slower in most
cases, so default to off.  This also implements a tri-state for
use_mmap so we can detect the difference between a user provided
value of true/false, or unspecified.
Daniel Hiltgen 10 kuukautta sitten
vanhempi
commit
171796791f
3 muutettua tiedostoa jossa 96 lisäystä ja 15 poistoa
  1. 57 13
      api/types.go
  2. 36 0
      api/types_test.go
  3. 3 2
      llm/server.go

+ 57 - 13
api/types.go

@@ -159,18 +159,49 @@ type Options struct {
 
 
 // Runner options which must be set when the model is loaded into memory
 // Runner options which must be set when the model is loaded into memory
 type Runner struct {
 type Runner struct {
-	UseNUMA   bool `json:"numa,omitempty"`
-	NumCtx    int  `json:"num_ctx,omitempty"`
-	NumBatch  int  `json:"num_batch,omitempty"`
-	NumGPU    int  `json:"num_gpu,omitempty"`
-	MainGPU   int  `json:"main_gpu,omitempty"`
-	LowVRAM   bool `json:"low_vram,omitempty"`
-	F16KV     bool `json:"f16_kv,omitempty"`
-	LogitsAll bool `json:"logits_all,omitempty"`
-	VocabOnly bool `json:"vocab_only,omitempty"`
-	UseMMap   bool `json:"use_mmap,omitempty"`
-	UseMLock  bool `json:"use_mlock,omitempty"`
-	NumThread int  `json:"num_thread,omitempty"`
+	UseNUMA   bool     `json:"numa,omitempty"`
+	NumCtx    int      `json:"num_ctx,omitempty"`
+	NumBatch  int      `json:"num_batch,omitempty"`
+	NumGPU    int      `json:"num_gpu,omitempty"`
+	MainGPU   int      `json:"main_gpu,omitempty"`
+	LowVRAM   bool     `json:"low_vram,omitempty"`
+	F16KV     bool     `json:"f16_kv,omitempty"`
+	LogitsAll bool     `json:"logits_all,omitempty"`
+	VocabOnly bool     `json:"vocab_only,omitempty"`
+	UseMMap   TriState `json:"use_mmap,omitempty"`
+	UseMLock  bool     `json:"use_mlock,omitempty"`
+	NumThread int      `json:"num_thread,omitempty"`
+}
+
+type TriState int
+
+const (
+	TriStateUndefined TriState = -1
+	TriStateFalse     TriState = 0
+	TriStateTrue      TriState = 1
+)
+
+func (b *TriState) UnmarshalJSON(data []byte) error {
+	var v bool
+	if err := json.Unmarshal(data, &v); err != nil {
+		return err
+	}
+	if v {
+		*b = TriStateTrue
+	}
+	*b = TriStateFalse
+	return nil
+}
+
+func (b *TriState) MarshalJSON() ([]byte, error) {
+	if *b == TriStateUndefined {
+		return nil, nil
+	}
+	var v bool
+	if *b == TriStateTrue {
+		v = true
+	}
+	return json.Marshal(v)
 }
 }
 
 
 // EmbeddingRequest is the request passed to [Client.Embeddings].
 // EmbeddingRequest is the request passed to [Client.Embeddings].
@@ -403,6 +434,19 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 				continue
 				continue
 			}
 			}
 
 
+			if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
+				val, ok := val.(bool)
+				if !ok {
+					return fmt.Errorf("option %q must be of type boolean", key)
+				}
+				if val {
+					field.SetInt(int64(TriStateTrue))
+				} else {
+					field.SetInt(int64(TriStateFalse))
+				}
+				continue
+			}
+
 			switch field.Kind() {
 			switch field.Kind() {
 			case reflect.Int:
 			case reflect.Int:
 				switch t := val.(type) {
 				switch t := val.(type) {
@@ -491,7 +535,7 @@ func DefaultOptions() Options {
 			LowVRAM:   false,
 			LowVRAM:   false,
 			F16KV:     true,
 			F16KV:     true,
 			UseMLock:  false,
 			UseMLock:  false,
-			UseMMap:   true,
+			UseMMap:   TriStateUndefined,
 			UseNUMA:   false,
 			UseNUMA:   false,
 		},
 		},
 	}
 	}

+ 36 - 0
api/types_test.go

@@ -105,3 +105,39 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
 		})
 		})
 	}
 	}
 }
 }
+
+func TestUseMmapParsingFromJSON(t *testing.T) {
+	tests := []struct {
+		name string
+		req  string
+		exp  TriState
+	}{
+		{
+			name: "Undefined",
+			req:  `{ }`,
+			exp:  TriStateUndefined,
+		},
+		{
+			name: "True",
+			req:  `{ "use_mmap": true }`,
+			exp:  TriStateTrue,
+		},
+		{
+			name: "False",
+			req:  `{ "use_mmap": false }`,
+			exp:  TriStateFalse,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			var oMap map[string]interface{}
+			err := json.Unmarshal([]byte(test.req), &oMap)
+			require.NoError(t, err)
+			opts := DefaultOptions()
+			err = opts.FromMap(oMap)
+			require.NoError(t, err)
+			assert.Equal(t, test.exp, opts.UseMMap)
+		})
+	}
+}

+ 3 - 2
llm/server.go

@@ -200,7 +200,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		if g.Library == "metal" &&
 		if g.Library == "metal" &&
 			uint64(opts.NumGPU) > 0 &&
 			uint64(opts.NumGPU) > 0 &&
 			uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
 			uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
-			opts.UseMMap = false
+			opts.UseMMap = api.TriStateFalse
 		}
 		}
 	}
 	}
 
 
@@ -208,7 +208,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--flash-attn")
 		params = append(params, "--flash-attn")
 	}
 	}
 
 
-	if !opts.UseMMap {
+	// Windows CUDA should not use mmap for best performance
+	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda") || opts.UseMMap == api.TriStateFalse {
 		params = append(params, "--no-mmap")
 		params = append(params, "--no-mmap")
 	}
 	}