
add stop conditions

Michael Yang 1 year ago
parent
commit fadf75f99d
2 changed files with 39 additions and 13 deletions
  1. api/types.go (+14 -13)
  2. llama/llama.go (+25 -0)

api/types.go (+14 -13)

@@ -165,19 +165,20 @@ type Options struct {
 	EmbeddingOnly bool `json:"embedding_only,omitempty"`
 
 	// Predict options
-	RepeatLastN      int     `json:"repeat_last_n,omitempty"`
-	RepeatPenalty    float32 `json:"repeat_penalty,omitempty"`
-	FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float32 `json:"presence_penalty,omitempty"`
-	Temperature      float32 `json:"temperature,omitempty"`
-	TopK             int     `json:"top_k,omitempty"`
-	TopP             float32 `json:"top_p,omitempty"`
-	TFSZ             float32 `json:"tfs_z,omitempty"`
-	TypicalP         float32 `json:"typical_p,omitempty"`
-	Mirostat         int     `json:"mirostat,omitempty"`
-	MirostatTau      float32 `json:"mirostat_tau,omitempty"`
-	MirostatEta      float32 `json:"mirostat_eta,omitempty"`
-	PenalizeNewline  bool    `json:"penalize_newline,omitempty"`
+	RepeatLastN      int      `json:"repeat_last_n,omitempty"`
+	RepeatPenalty    float32  `json:"repeat_penalty,omitempty"`
+	FrequencyPenalty float32  `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float32  `json:"presence_penalty,omitempty"`
+	Temperature      float32  `json:"temperature,omitempty"`
+	TopK             int      `json:"top_k,omitempty"`
+	TopP             float32  `json:"top_p,omitempty"`
+	TFSZ             float32  `json:"tfs_z,omitempty"`
+	TypicalP         float32  `json:"typical_p,omitempty"`
+	Mirostat         int      `json:"mirostat,omitempty"`
+	MirostatTau      float32  `json:"mirostat_tau,omitempty"`
+	MirostatEta      float32  `json:"mirostat_eta,omitempty"`
+	PenalizeNewline  bool     `json:"penalize_newline,omitempty"`
+	StopConditions   []string `json:"stop_conditions,omitempty"`
 
 	NumThread int `json:"num_thread,omitempty"`
 }
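
With the struct realigned, StopConditions rides alongside the existing sampling options and serializes as stop_conditions. A minimal sketch of a caller populating it; the module path github.com/jmorganca/ollama and the field values are assumptions for illustration, not part of this commit:

package main

import (
	"fmt"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// Ask generation to halt when either string is produced;
	// unset fields keep their zero values and are omitted from JSON.
	opts := api.Options{
		Temperature:    0.8,
		StopConditions: []string{"\n\n", "### End"},
	}
	fmt.Printf("%+v\n", opts)
}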

llama/llama.go (+25 -0)

@@ -172,6 +172,8 @@ func (llm *LLM) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
+var errNeedMoreData = errors.New("need more data")
+
 func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
 	C.llama_reset_timings(llm.ctx)
 
@@ -200,6 +202,17 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 		}
 
 		b.WriteString(llm.detokenize(token))
+
+		if err := llm.checkStopConditions(b); err != nil {
+			if errors.Is(err, io.EOF) {
+				break
+			} else if errors.Is(err, errNeedMoreData) {
+				continue
+			}
+
+			return err
+		}
+
 		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
 			fn(api.GenerateResponse{Response: b.String()})
 			b.Reset()
@@ -228,6 +241,18 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 	return nil
 }
 
+func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
+	for _, stopCondition := range llm.StopConditions {
+		if stopCondition == b.String() {
+			return io.EOF
+		} else if strings.HasPrefix(stopCondition, b.String()) {
+			return errNeedMoreData
+		}
+	}
+
+	return nil
+}
+
 func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
 	tokens := append(ctx, llm.tokenize(prompt)...)
 	if llm.NumKeep < 0 {
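
Taken together, Predict and checkStopConditions make a three-way decision per token: an exact match against a stop condition ends generation (io.EOF), a prefix match holds the buffer until more tokens arrive (errNeedMoreData, so the utf8.Valid flush below is skipped and the partial stop string never reaches the callback), and no match lets the buffered text flush as usual. A standalone sketch of that decision logic, mirroring the diff but trimmed of the C bindings and written for illustration only:

package main

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"strings"
)

var errNeedMoreData = errors.New("need more data")

// checkStopConditions mirrors the logic added in this commit: an exact
// match ends generation, a prefix match asks the caller to buffer more
// tokens, and no match lets the buffered text flush.
func checkStopConditions(stops []string, b *bytes.Buffer) error {
	for _, stop := range stops {
		if stop == b.String() {
			return io.EOF
		} else if strings.HasPrefix(stop, b.String()) {
			return errNeedMoreData
		}
	}
	return nil
}

func main() {
	stops := []string{"###"}
	var b bytes.Buffer
	// Feed the stop string one token at a time, as Predict would.
	for _, token := range []string{"#", "#", "#"} {
		b.WriteString(token)
		err := checkStopConditions(stops, &b)
		fmt.Printf("buffer=%q -> %v\n", b.String(), err)
	}
	// Output:
	// buffer="#" -> need more data
	// buffer="##" -> need more data
	// buffer="###" -> EOF
}

Note that the prefix comparison works only because the buffer is reset after every flush: the accumulated bytes are always a fresh suffix of the output, so a stop string spanning multiple tokens is detected before any part of it is emitted.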