Browse Source

token repeat limit for prediction requests (#3080)

Bruce MacDonald 1 year ago
parent
commit
3e22611200
1 changed files with 27 additions and 7 deletions
  1. 27 7
      llm/dyn_ext_server.go

+ 27 - 7
llm/dyn_ext_server.go

@@ -228,17 +228,14 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 		}
 
 		retryNeeded := false
+		// keep track of the last token generated, this is used to abort if the model starts looping
+		var lastToken string
+		var tokenRepeat int
 	out:
 		for {
 			select {
 			case <-ctx.Done():
-				// This handles the request cancellation
-				C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
-				if resp.id < 0 {
-					return extServerResponseToErr(resp)
-				} else {
-					return nil
-				}
+				return cancelCompletion(llm, resp)
 			default:
 				var result C.ext_server_task_result_t
 				C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
@@ -261,6 +258,20 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 					break out
 				}
 
+				switch {
+				case strings.TrimSpace(p.Content) == lastToken:
+					tokenRepeat++
+				default:
+					lastToken = strings.TrimSpace(p.Content)
+					tokenRepeat = 0
+				}
+
+				// 30 picked as an arbitrary max token repeat limit, modify as needed
+				if tokenRepeat > 30 {
+					slog.Debug("prediction aborted, token repeat limit reached")
+					return cancelCompletion(llm, resp)
+				}
+
 				if p.Content != "" {
 					fn(PredictResult{
 						Content: p.Content,
@@ -288,6 +299,15 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 	return fmt.Errorf("max retries exceeded")
 }
 
+func cancelCompletion(llm *dynExtServer, resp C.ext_server_resp_t) error {
+	C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
+	if resp.id < 0 {
+		return extServerResponseToErr(resp)
+	} else {
+		return nil
+	}
+}
+
 func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
 	if err != nil {