瀏覽代碼

runner.go: Fix deadlock if a connection is closed during decoding

If a connection is closed while a sequence is being decoded, tokens
will continue to be added to the channel without anyone to read them.
This will result in the sender blocking, which will in turn block
all other decoding and sending for other sequences.

This is not limited to just the connection between Ollama and the
runner process. If the connection to the Ollama API is closed by
the user then Ollama will close the connection to the runner,
triggering this issue.
Jesse Gross 8 月之前
父節點
當前提交
6ccd0644e1
共有 1 個文件被更改,包括 17 次插入、2 次刪除
  1. 17 2
      llama/runner/runner.go

+ 17 - 2
llama/runner/runner.go

@@ -38,6 +38,9 @@ type Sequence struct {
 	// channel to send responses over
 	responses chan string
 
+	// channel to stop decoding (such as if the remote connection is closed)
+	quit chan bool
+
 	// number of tokens to predict
 	numPredict int
 
@@ -106,6 +109,7 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence
 		n_prompt_tokens: len(tokens),
 		numPredict:      params.numPredict,
 		responses:       make(chan string, 1),
+		quit:            make(chan bool, 1),
 		embedding:       make(chan []float32, 1),
 		samplingCtx:     sc,
 		embeddingOnly:   params.embedding,
@@ -344,7 +348,11 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 			truncated := truncateStop(pieces[i], stop)
 
 			for _, p := range truncated {
-				seq.responses <- p
+				select {
+				case seq.responses <- p:
+				case <-seq.quit:
+					break
+				}
 			}
 
 			s.removeSequence(i, &pieces, "stop")
@@ -356,7 +364,12 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 		}
 
 		for _, p := range pieces[i] {
-			seq.responses <- p
+			select {
+			case seq.responses <- p:
+			case <-seq.quit:
+				s.removeSequence(i, &pieces, "connection")
+				break
+			}
 		}
 
 		pieces[i] = []string{}
@@ -475,12 +488,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			Content: content,
 		}); err != nil {
 			log.Println("Failed to encode result:", err)
+			close(seq.quit)
 			return
 		}
 
 		flusher, ok := w.(http.Flusher)
 		if !ok {
 			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
+			close(seq.quit)
 			return
 		}