
runner.go: Fix resource leaks when removing sequences

There are multiple causes and paths that result in a sequence
ending. Not all of these free the sampling context or reset the
pieces slice. This factors out the removal code so that all
paths release resources.
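
For readers skimming the diff below, here is a minimal, self-contained sketch of the pattern the commit describes: every path that ends a sequence funnels through one removal helper so the per-slot resources are always released. The types and fields here are simplified stand-ins for illustration, not the actual structs from runner.go.

```go
package main

import "fmt"

// sequence is a simplified stand-in for the runner's per-slot state; the real
// struct in runner.go also holds a sampling context, KV-cache bookkeeping, etc.
type sequence struct {
	responses  chan string
	doneReason string
}

type server struct {
	seqs   []*sequence
	pieces [][]string
}

// removeSequence funnels every "sequence is done" path through one place so
// that channels are closed and per-slot buffers are dropped exactly once.
func (s *server) removeSequence(i int, reason string) {
	seq := s.seqs[i]

	seq.doneReason = reason
	close(seq.responses) // tell consumers no more pieces are coming
	s.pieces[i] = nil    // reset the pieces slice for this slot
	s.seqs[i] = nil      // free the slot for the next sequence
}

func main() {
	s := &server{
		seqs:   []*sequence{{responses: make(chan string, 1)}},
		pieces: [][]string{{"partial output"}},
	}

	// Whatever the reason (predict limit, EOG token, stop string), the same
	// helper runs, so no exit path can forget part of the cleanup.
	s.removeSequence(0, "stop")
	fmt.Println(s.seqs[0] == nil, len(s.pieces[0])) // true 0
}
```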
Jesse Gross, 8 months ago
parent
commit
0b73cca386
2 changed files with 20 additions and 22 deletions
  1. llama/llama.go (+3 −1)
  2. llama/runner/runner.go (+17 −21)

+ 3 - 1
llama/llama.go

@@ -429,7 +429,9 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
 }
 
 func (s *SamplingContext) Free() {
-	C.llama_sampling_cfree(s.c)
+	if s.c != nil {
+		C.llama_sampling_cfree(s.c)
+	}
 }
 
 func (s *SamplingContext) Reset() {

+ 17 - 21
llama/runner/runner.go

@@ -197,6 +197,18 @@ func incompleteUnicode(token string) bool {
 	return incomplete
 }
 
+func (s *Server) removeSequence(seqIndex int, pieces *[][]string, reason string) {
+	seq := s.seqs[seqIndex]
+
+	seq.doneReason = reason
+	close(seq.responses)
+	close(seq.embedding)
+	(*pieces)[seqIndex] = []string{}
+	seq.samplingCtx.Free()
+	s.lc.KvCacheSeqRm(seqIndex, 0, -1)
+	s.seqs[seqIndex] = nil
+}
+
 func (s *Server) run(ctx context.Context) {
 	// build up stop sequences as we recognize them
 	// TODO (jmorganca): simplify this
@@ -231,10 +243,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
-			seq.doneReason = "limit"
-			close(seq.responses)
-			s.lc.KvCacheSeqRm(i, 0, -1)
-			s.seqs[i] = nil
+			s.removeSequence(i, &pieces, "limit")
 			continue
 		}
 
@@ -288,9 +297,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 			}
 
 			seq.embedding <- embd
-			close(seq.embedding)
-			s.lc.KvCacheSeqRm(i, 0, -1)
-			s.seqs[i] = nil
+			s.removeSequence(i, &pieces, "")
 			continue
 		}
 
@@ -313,18 +320,12 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 		// if it's an end of sequence token, break
 		// TODO: just end this sequence
 		if s.model.TokenIsEog(token) {
-			// TODO: end the sequence instead of quitting the pool
-			s.lc.KvCacheSeqRm(i, 0, -1)
-
 			// TODO (jmorganca): we should send this back
 			// as it's important for the /api/generate context
 			// seq.responses <- piece
 
-			seq.doneReason = "stop"
-			close(seq.responses)
-			seq.samplingCtx.Free()
-			pieces[i] = []string{}
-			s.seqs[i] = nil
+			// TODO: end the sequence instead of quitting the pool
+			s.removeSequence(i, &pieces, "stop")
 			continue
 		}
 
@@ -346,12 +347,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 				seq.responses <- p
 			}
 
-			s.lc.KvCacheSeqRm(i, 0, -1)
-			seq.doneReason = "stop"
-			close(seq.responses)
-			seq.samplingCtx.Free()
-			pieces[i] = []string{}
-			s.seqs[i] = nil
+			s.removeSequence(i, &pieces, "stop")
 			continue
 		}