瀏覽代碼

Relay default values to llama runner (#672)

* include seed in params for llama.cpp server and remove empty filter for temp

* relay default predict options to llama.cpp

- reorganize options to match predict request for readability

* omit empty stop

---------

Co-authored-by: hallh <hallh@users.noreply.github.com>
Bruce MacDonald 1 年之前
父節點
當前提交
1fbf3585d6
共有 2 個文件被更改,包括 43 次插入44 次删除
  1. 22 22
      api/types.go
  2. 21 22
      llm/llama.go

+ 22 - 22
api/types.go

@@ -280,38 +280,38 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 
 func DefaultOptions() Options {
 	return Options{
-		Seed: -1,
-
-		UseNUMA: false,
-
-		NumCtx:             2048,
-		NumKeep:            -1,
-		NumBatch:           512,
-		NumGPU:             -1, // -1 here indicates that NumGPU should be set dynamically
-		NumGQA:             1,
-		LowVRAM:            false,
-		F16KV:              true,
-		UseMMap:            true,
-		UseMLock:           false,
-		RopeFrequencyBase:  10000.0,
-		RopeFrequencyScale: 1.0,
-		EmbeddingOnly:      true,
-
-		RepeatLastN:      64,
-		RepeatPenalty:    1.1,
-		FrequencyPenalty: 0.0,
-		PresencePenalty:  0.0,
+		// options set on request to runner
+		NumPredict:       -1,
+		NumKeep:          -1,
 		Temperature:      0.8,
 		TopK:             40,
 		TopP:             0.9,
 		TFSZ:             1.0,
 		TypicalP:         1.0,
+		RepeatLastN:      64,
+		RepeatPenalty:    1.1,
+		PresencePenalty:  0.0,
+		FrequencyPenalty: 0.0,
 		Mirostat:         0,
 		MirostatTau:      5.0,
 		MirostatEta:      0.1,
 		PenalizeNewline:  true,
+		Seed:             -1,
 
-		NumThread: 0, // let the runtime decide
+		// options set when the model is loaded
+		NumCtx:             2048,
+		RopeFrequencyBase:  10000.0,
+		RopeFrequencyScale: 1.0,
+		NumBatch:           512,
+		NumGPU:             -1, // -1 here indicates that NumGPU should be set dynamically
+		NumGQA:             1,
+		NumThread:          0, // let the runtime decide
+		LowVRAM:            false,
+		F16KV:              true,
+		UseMLock:           false,
+		UseMMap:            true,
+		UseNUMA:            false,
+		EmbeddingOnly:      true,
 	}
 }
 

+ 21 - 22
llm/llama.go

@@ -417,28 +417,25 @@ type Prediction struct {
 }
 
 type PredictRequest struct {
-	Stream           bool            `json:"stream"`
-	NPredict         int             `json:"n_predict,omitempty"`
-	TopK             int             `json:"top_k,omitempty"`
-	TopP             float32         `json:"top_p,omitempty"`
-	TfsZ             float32         `json:"tfs_z,omitempty"`
-	TypicalP         float32         `json:"typical_p,omitempty"`
-	RepeatLastN      int             `json:"repeat_last_n,omitempty"`
-	Temperature      float32         `json:"temperature,omitempty"`
-	RepeatPenalty    float32         `json:"repeat_penalty,omitempty"`
-	PresencePenalty  float32         `json:"presence_penalty,omitempty"`
-	FrequencyPenalty float32         `json:"frequency_penalty,omitempty"`
-	Mirostat         int             `json:"mirostat,omitempty"`
-	MirostatTau      float32         `json:"mirostat_tau,omitempty"`
-	MirostatEta      float32         `json:"mirostat_eta,omitempty"`
-	PenalizeNl       bool            `json:"penalize_nl,omitempty"`
-	NKeep            int             `json:"n_keep,omitempty"`
-	Seed             int             `json:"seed,omitempty"`
-	Prompt           string          `json:"prompt,omitempty"`
-	NProbs           int             `json:"n_probs,omitempty"`
-	LogitBias        map[int]float32 `json:"logit_bias,omitempty"`
-	IgnoreEos        bool            `json:"ignore_eos,omitempty"`
-	Stop             []string        `json:"stop,omitempty"`
+	Prompt           string   `json:"prompt"`
+	Stream           bool     `json:"stream"`
+	NPredict         int      `json:"n_predict"`
+	NKeep            int      `json:"n_keep"`
+	Temperature      float32  `json:"temperature"`
+	TopK             int      `json:"top_k"`
+	TopP             float32  `json:"top_p"`
+	TfsZ             float32  `json:"tfs_z"`
+	TypicalP         float32  `json:"typical_p"`
+	RepeatLastN      int      `json:"repeat_last_n"`
+	RepeatPenalty    float32  `json:"repeat_penalty"`
+	PresencePenalty  float32  `json:"presence_penalty"`
+	FrequencyPenalty float32  `json:"frequency_penalty"`
+	Mirostat         int      `json:"mirostat"`
+	MirostatTau      float32  `json:"mirostat_tau"`
+	MirostatEta      float32  `json:"mirostat_eta"`
+	PenalizeNl       bool     `json:"penalize_nl"`
+	Seed             int      `json:"seed"`
+	Stop             []string `json:"stop,omitempty"`
 }
 
 func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
@@ -470,8 +467,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 		MirostatTau:      llm.MirostatTau,
 		MirostatEta:      llm.MirostatEta,
 		PenalizeNl:       llm.PenalizeNewline,
+		Seed:             llm.Seed,
 		Stop:             llm.Stop,
 	}
+
 	data, err := json.Marshal(predReq)
 	if err != nil {
 		return fmt.Errorf("error marshaling data: %v", err)