Browse Source

pass model and predict options

Patrick Devine 1 year ago
parent
commit
3f1b7177f2
3 changed files with 151 additions and 16 deletions
  1. 86 0
      api/types.go
  2. 2 6
      llama/llama.go
  3. 63 10
      server/routes.go

+ 86 - 0
api/types.go

@@ -31,6 +31,92 @@ type PullProgress struct {
 type GenerateRequest struct {
 	Model  string `json:"model"`
 	Prompt string `json:"prompt"`
+
+	ModelOptions   `json:"model_opts"`
+	PredictOptions `json:"predict_opts"`
+}
+
+type ModelOptions struct {
+	ContextSize int    `json:"context_size"`
+	Seed        int    `json:"seed"`
+	NBatch      int    `json:"n_batch"`
+	F16Memory   bool   `json:"memory_f16"`
+	MLock       bool   `json:"mlock"`
+	MMap        bool   `json:"mmap"`
+	VocabOnly   bool   `json:"vocab_only"`
+	LowVRAM     bool   `json:"low_vram"`
+	Embeddings  bool   `json:"embeddings"`
+	NUMA        bool   `json:"numa"`
+	NGPULayers  int    `json:"gpu_layers"`
+	MainGPU     string `json:"main_gpu"`
+	TensorSplit string `json:"tensor_split"`
+}
+
+type PredictOptions struct {
+	Seed        int     `json:"seed"`
+	Threads     int     `json:"threads"`
+	Tokens      int     `json:"tokens"`
+	TopK        int     `json:"top_k"`
+	Repeat      int     `json:"repeat"`
+	Batch       int     `json:"batch"`
+	NKeep       int     `json:"nkeep"`
+	TopP        float64 `json:"top_p"`
+	Temperature float64 `json:"temp"`
+	Penalty     float64 `json:"penalty"`
+	F16KV       bool
+	DebugMode   bool
+	StopPrompts []string
+	IgnoreEOS   bool `json:"ignore_eos"`
+
+	TailFreeSamplingZ float64 `json:"tfs_z"`
+	TypicalP          float64 `json:"typical_p"`
+	FrequencyPenalty  float64 `json:"freq_penalty"`
+	PresencePenalty   float64 `json:"pres_penalty"`
+	Mirostat          int     `json:"mirostat"`
+	MirostatETA       float64 `json:"mirostat_lr"`
+	MirostatTAU       float64 `json:"mirostat_ent"`
+	PenalizeNL        bool    `json:"penalize_nl"`
+	LogitBias         string  `json:"logit_bias"`
+
+	PathPromptCache string
+	MLock           bool `json:"mlock"`
+	MMap            bool `json:"mmap"`
+	PromptCacheAll  bool
+	PromptCacheRO   bool
+	MainGPU         string
+	TensorSplit     string
+}
+
+var DefaultModelOptions ModelOptions = ModelOptions{
+	ContextSize: 128,
+	Seed:        0,
+	F16Memory:   true,
+	MLock:       false,
+	Embeddings:  true,
+	MMap:        true,
+	LowVRAM:     false,
+}
+
+var DefaultPredictOptions PredictOptions = PredictOptions{
+	Seed:              -1,
+	Threads:           -1,
+	Tokens:            512,
+	Penalty:           1.1,
+	Repeat:            64,
+	Batch:             512,
+	NKeep:             64,
+	TopK:              90,
+	TopP:              0.86,
+	TailFreeSamplingZ: 1.0,
+	TypicalP:          1.0,
+	Temperature:       0.8,
+	FrequencyPenalty:  0.0,
+	PresencePenalty:   0.0,
+	Mirostat:          0,
+	MirostatTAU:       5.0,
+	MirostatETA:       0.1,
+	MMap:              true,
+	StopPrompts:       []string{"llama"},
 }
 
 type GenerateResponse struct {

+ 2 - 6
llama/llama.go

@@ -42,9 +42,7 @@ type LLama struct {
 	contextSize int
 }
 
-func New(model string, opts ...ModelOption) (*LLama, error) {
-	mo := NewModelOptions(opts...)
-
+func New(model string, mo ModelOptions) (*LLama, error) {
 	modelPath := C.CString(model)
 	defer C.free(unsafe.Pointer(modelPath))
 
@@ -108,9 +106,7 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error {
 	return nil
 }
 
-func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
-	po := NewPredictOptions(opts...)
-
+func (l *LLama) Predict(text string, po PredictOptions) (string, error) {
 	if po.TokenCallback != nil {
 		setCallback(l.ctx, po.TokenCallback)
 	}

+ 63 - 10
server/routes.go

@@ -26,12 +26,9 @@ var templatesFS embed.FS
 var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
 
 func generate(c *gin.Context) {
-	// TODO: these should be request parameters
-	gpulayers := 1
-	tokens := 512
-	threads := runtime.NumCPU()
-
 	var req api.GenerateRequest
+	req.ModelOptions = api.DefaultModelOptions
+	req.PredictOptions = api.DefaultPredictOptions
 	if err := c.ShouldBindJSON(&req); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
 		return
@@ -41,7 +38,10 @@ func generate(c *gin.Context) {
 		req.Model = remoteModel.FullName()
 	}
 
-	model, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers))
+	modelOpts := getModelOpts(req)
+	modelOpts.NGPULayers = 1  // hard-code this for now
+
+	model, err := llama.New(req.Model, modelOpts)
 	if err != nil {
 		fmt.Println("Loading the model failed:", err.Error())
 		return
@@ -65,13 +65,16 @@ func generate(c *gin.Context) {
 	}
 
 	ch := make(chan string)
+	model.SetTokenCallback(func(token string) bool {
+		ch <- token
+		return true
+	})
+
+	predictOpts := getPredictOpts(req)
 
 	go func() {
 		defer close(ch)
-		_, err := model.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool {
-			ch <- token
-			return true
-		}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
+		_, err := model.Predict(req.Prompt, predictOpts)
 		if err != nil {
 			panic(err)
 		}
@@ -161,3 +164,53 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
 
 	return
 }
+
+func getModelOpts(req api.GenerateRequest) llama.ModelOptions {
+	var opts llama.ModelOptions
+	opts.ContextSize = req.ModelOptions.ContextSize
+	opts.Seed = req.ModelOptions.Seed
+	opts.F16Memory = req.ModelOptions.F16Memory
+	opts.MLock = req.ModelOptions.MLock
+	opts.Embeddings = req.ModelOptions.Embeddings
+	opts.MMap = req.ModelOptions.MMap
+	opts.LowVRAM = req.ModelOptions.LowVRAM
+
+	opts.NBatch = req.ModelOptions.NBatch
+	opts.VocabOnly = req.ModelOptions.VocabOnly
+	opts.NUMA = req.ModelOptions.NUMA
+	opts.NGPULayers = req.ModelOptions.NGPULayers
+	opts.MainGPU = req.ModelOptions.MainGPU
+	opts.TensorSplit = req.ModelOptions.TensorSplit
+
+	return opts
+}
+
+func getPredictOpts(req api.GenerateRequest) llama.PredictOptions {
+	var opts llama.PredictOptions
+
+	if req.PredictOptions.Threads == -1 {
+		opts.Threads = runtime.NumCPU()
+	} else {
+		opts.Threads = req.PredictOptions.Threads
+	}
+
+	opts.Seed = req.PredictOptions.Seed
+	opts.Tokens = req.PredictOptions.Tokens
+	opts.Penalty = req.PredictOptions.Penalty
+	opts.Repeat = req.PredictOptions.Repeat
+	opts.Batch = req.PredictOptions.Batch
+	opts.NKeep = req.PredictOptions.NKeep
+	opts.TopK = req.PredictOptions.TopK
+	opts.TopP = req.PredictOptions.TopP
+	opts.TailFreeSamplingZ = req.PredictOptions.TailFreeSamplingZ
+	opts.TypicalP = req.PredictOptions.TypicalP
+	opts.Temperature = req.PredictOptions.Temperature
+	opts.FrequencyPenalty = req.PredictOptions.FrequencyPenalty
+	opts.PresencePenalty = req.PredictOptions.PresencePenalty
+	opts.Mirostat = req.PredictOptions.Mirostat
+	opts.MirostatTAU = req.PredictOptions.MirostatTAU
+	opts.MirostatETA = req.PredictOptions.MirostatETA
+	opts.MMap = req.PredictOptions.MMap
+
+	return opts
+}