@@ -334,20 +334,18 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
 		b.WriteString(llm.Decode(int(token)))
 
-		if err := llm.checkStopConditions(b); err != nil {
-			if errors.Is(err, io.EOF) {
-				break
-			} else if errors.Is(err, errNeedMoreData) {
-				continue
-			}
-
-			return err
+		stop, endsWithStopPrefix := handleStopSequences(&b, llm.Stop)
+		if endsWithStopPrefix {
+			continue
 		}
 
 		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
 			fn(api.GenerateResponse{Response: b.String()})
 			b.Reset()
 		}
+		if stop {
+			break
+		}
 	}
 
 	embd := make([]int, len(llm.embd))
@@ -370,16 +368,31 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
 	return nil
 }
 
-func (llm *llama) checkStopConditions(b bytes.Buffer) error {
-	for _, stopCondition := range llm.Stop {
-		if stopCondition == strings.TrimSpace(b.String()) {
-			return io.EOF
-		} else if strings.HasPrefix(stopCondition, strings.TrimSpace(b.String())) {
-			return errNeedMoreData
+// handleStopSequences checks whether b contains any of the stop sequences, or ends with a prefix of
+// any stop sequence (and therefore might contain data that should not ultimately be returned to the
+// client).
+//
+// If b contains a stop sequence, it modifies b to remove the stop sequence and all subsequent data.
+func handleStopSequences(b *bytes.Buffer, stopSequences []string) (stop bool, endsWithStopPrefix bool) {
+	s := b.String()
+	for _, seq := range stopSequences {
+		// Check for an exact or substring match.
+		if i := strings.Index(s, seq); i != -1 {
+			b.Truncate(i)
+			return true, false
+		}
+
+		// Check if b ends with a prefix of the stop sequence.
+		if len(seq) > 1 {
+			for i := 1; i < len(seq); i++ {
+				if strings.HasSuffix(s, seq[:i]) {
+					return false, true
+				}
+			}
 		}
 	}
 
-	return nil
+	return false, false
 }
 
 func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {