
automatically set num_keep if num_keep < 0

num_keep defines how many tokens to keep in the context when truncating
inputs. If left at its default value of -1, the server will calculate
num_keep so that it covers the length of the system instructions.
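A rough sketch of the idea, with strings.Fields standing in for the real tokenizer and made-up prompt strings: render the prompt once with the system instructions and once without, and take the token-count difference plus one.

package main

import (
	"fmt"
	"strings"
)

func main() {
	// hypothetical renderings of the same request with and without the
	// system instructions; strings.Fields is a stand-in tokenizer
	promptWithSystem := "[INST] <<SYS>> You are a helpful assistant. <</SYS>> hello [/INST]"
	promptNoSystem := "[INST] hello [/INST]"

	tokensWithSystem := strings.Fields(promptWithSystem)
	tokensNoSystem := strings.Fields(promptNoSystem)

	// keep enough leading tokens to cover the system instructions
	numKeep := len(tokensWithSystem) - len(tokensNoSystem) + 1
	fmt.Println("num_keep =", numKeep)
}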
Michael Yang, 1 year ago
commit 4dc5b117dd
3 changed files: 28 additions and 14 deletions

  1. api/types.go (+1 -0)
  2. llama/llama.go (+8 -14)
  3. server/routes.go (+19 -0)

+ 1 - 0
api/types.go

@@ -264,6 +264,7 @@ func DefaultOptions() Options {
 		UseNUMA: false,
 
 		NumCtx:             2048,
+		NumKeep:            -1,
 		NumBatch:           512,
 		NumGPU:             1,
 		NumGQA:             1,

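The new default of -1 means "let the server decide". A caller can still pin the value explicitly; a minimal sketch, assuming the module path of the time:

package main

import (
	"fmt"

	"github.com/jmorganca/ollama/api"
)

func main() {
	opts := api.DefaultOptions()
	fmt.Println(opts.NumKeep) // -1: the server derives it from the system instructions

	// setting it explicitly pins how many leading tokens survive truncation
	opts.NumKeep = 24
	fmt.Println(opts.NumKeep)
}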
+ 8 - 14
llama/llama.go

@@ -189,10 +189,6 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 		tokens[i] = C.llama_token(ctx[i])
 	}
 
-	if len(tokens) == 0 {
-		tokens = llm.tokenize(" ")
-	}
-
 	llm.marshalPrompt(tokens, prompt)
 
 	C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
@@ -208,7 +204,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 			return err
 		}
 
-		b.WriteString(llm.detokenize(token))
+		b.WriteString(llm.Decode(token))
 
 		if err := llm.checkStopConditions(b); err != nil {
 			if errors.Is(err, io.EOF) {
@@ -226,17 +222,15 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 		}
 	}
 
-	last := make([]int, 0, len(llm.last))
-	for _, i := range llm.last {
-		if i != 0 {
-			last = append(last, int(i))
-		}
+	embd := make([]int, len(llm.embd))
+	for i := range llm.embd {
+		embd[i] = int(llm.embd[i])
 	}
 
 	timings := C.llama_get_timings(llm.ctx)
 	fn(api.GenerateResponse{
 		Done:               true,
-		Context:            last,
+		Context:            embd,
 		SampleCount:        int(timings.n_sample),
 		SampleDuration:     parseDurationMs(float64(timings.t_sample_ms)),
 		PromptEvalCount:    int(timings.n_p_eval),
@@ -261,7 +255,7 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
 }
 
 func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
-	tokens := append(ctx, llm.tokenize(prompt)...)
+	tokens := append(ctx, llm.Encode(prompt)...)
 	if llm.NumKeep < 0 {
 		llm.NumKeep = len(tokens)
 	}
@@ -303,7 +297,7 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
 	return tokens
 }
 
-func (llm *LLM) tokenize(prompt string) []C.llama_token {
+func (llm *LLM) Encode(prompt string) []C.llama_token {
 	cPrompt := C.CString(prompt)
 	defer C.free(unsafe.Pointer(cPrompt))
 
@@ -315,7 +309,7 @@ func (llm *LLM) tokenize(prompt string) []C.llama_token {
 	return nil
 }
 
-func (llm *LLM) detokenize(tokens ...C.llama_token) string {
+func (llm *LLM) Decode(tokens ...C.llama_token) string {
 	var sb strings.Builder
 	for _, token := range tokens {
 		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))

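NumKeep matters in marshalPrompt: when the combined context and prompt overflow the window, the first NumKeep tokens are preserved and the rest of the window is filled from the most recent tokens. A simplified sketch of that behaviour with a hypothetical truncate helper, not the exact arithmetic in llama.go:

package main

import "fmt"

// truncate keeps the first numKeep tokens and fills the remaining window
// from the most recent tokens
func truncate(tokens []int, numCtx, numKeep int) []int {
	if len(tokens) <= numCtx {
		return tokens
	}
	if numKeep < 0 || numKeep > numCtx {
		numKeep = numCtx
	}
	out := append([]int{}, tokens[:numKeep]...)
	return append(out, tokens[len(tokens)-(numCtx-numKeep):]...)
}

func main() {
	tokens := make([]int, 3000)
	for i := range tokens {
		tokens[i] = i
	}
	out := truncate(tokens, 2048, 29)
	fmt.Println(len(out), out[:3], out[28:31]) // 2048 [0 1 2] [28 981 982]
}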
+ 19 - 0
server/routes.go

@@ -78,6 +78,25 @@ func GenerateHandler(c *gin.Context) {
 			return
 		}
 
+		if opts.NumKeep < 0 {
+			promptWithSystem, err := model.Prompt(api.GenerateRequest{})
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+
+			promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}})
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+
+			tokensWithSystem := llm.Encode(promptWithSystem)
+			tokensNoSystem := llm.Encode(promptNoSystem)
+
+			llm.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
+		}
+
 		loaded.llm = llm
 		loaded.digest = model.Digest
 		loaded.options = opts
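The two Prompt calls differ only in whether the request carries prior context: an empty request renders the template with the system instructions, while a request with a non-empty Context renders it without them, so NumKeep ends up covering the system portion plus one extra token of slack. A hypothetical illustration of that distinction, not the actual template code:

package main

import "fmt"

// renderPrompt emits the system block only for a fresh conversation
// (empty context); this mirrors the assumption behind the two calls above
func renderPrompt(system, user string, context []int) string {
	if len(context) == 0 {
		return "[INST] <<SYS>>\n" + system + "\n<</SYS>>\n\n" + user + " [/INST]"
	}
	return "[INST] " + user + " [/INST]"
}

func main() {
	fmt.Println(renderPrompt("You are concise.", "hello", nil))      // fresh conversation
	fmt.Println(renderPrompt("You are concise.", "hello", []int{0})) // continuing one
}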