Преглед на файлове

send completion response on chan

Bruce MacDonald преди 2 месеца
родител
ревизия
6dfcdec2da
променени са 1 файла, в които са добавени 10 реда и са изтрити 10 реда
  1. 10 10
      llama/runner/runner.go

+ 10 - 10
llama/runner/runner.go

@@ -61,7 +61,7 @@ type Sequence struct {
 	crossAttention bool
 
 	// channel to send responses over
-	responses chan string
+	responses chan CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -153,7 +153,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
 		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		responses:           make(chan CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
@@ -281,7 +281,7 @@ func flushPending(seq *Sequence) bool {
 	if len(seq.pendingResponses) == 0 {
 		return true
 	}
-	joined := strings.Join(seq.pendingResponses, "")
+	content := strings.Join(seq.pendingResponses, "")
 	seq.pendingResponses = []string{}
 
 	// Check if there are any partial UTF-8 characters remaining.
@@ -290,8 +290,8 @@ func flushPending(seq *Sequence) bool {
 	// - Sequence is ending, e.g. generation limit has been hit
 	// - Invalid characters in the middle of a string
 	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
+	for !utf8.ValidString(content) {
+		content = content[:len(content)-1]
 	}
 
 	// Add logits if requested and available
@@ -302,7 +302,9 @@ func flushPending(seq *Sequence) bool {
 	}
 
 	select {
-	case seq.responses <- joined:
+	case seq.responses <- CompletionResponse{
+		Content: content,
+	}:
 		return true
 	case <-seq.quit:
 		return false
@@ -755,11 +757,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		case <-r.Context().Done():
 			close(seq.quit)
 			return
-		case content, ok := <-seq.responses:
+		case resp, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&resp); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return