|
@@ -38,6 +38,9 @@ type Sequence struct {
|
|
// channel to send responses over
|
|
// channel to send responses over
|
|
responses chan string
|
|
responses chan string
|
|
|
|
|
|
|
|
+ // channel to stop decoding (such as if the remote connection is closed)
|
|
|
|
+ quit chan bool
|
|
|
|
+
|
|
// number of tokens to predict
|
|
// number of tokens to predict
|
|
numPredict int
|
|
numPredict int
|
|
|
|
|
|
@@ -106,6 +109,7 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence
|
|
n_prompt_tokens: len(tokens),
|
|
n_prompt_tokens: len(tokens),
|
|
numPredict: params.numPredict,
|
|
numPredict: params.numPredict,
|
|
responses: make(chan string, 1),
|
|
responses: make(chan string, 1),
|
|
|
|
+ quit: make(chan bool, 1),
|
|
embedding: make(chan []float32, 1),
|
|
embedding: make(chan []float32, 1),
|
|
samplingCtx: sc,
|
|
samplingCtx: sc,
|
|
embeddingOnly: params.embedding,
|
|
embeddingOnly: params.embedding,
|
|
@@ -344,7 +348,11 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
|
|
truncated := truncateStop(pieces[i], stop)
|
|
truncated := truncateStop(pieces[i], stop)
|
|
|
|
|
|
for _, p := range truncated {
|
|
for _, p := range truncated {
|
|
- seq.responses <- p
|
|
|
|
|
|
+ select {
|
|
|
|
+ case seq.responses <- p:
|
|
|
|
+ case <-seq.quit:
|
|
|
|
+ break
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
s.removeSequence(i, &pieces, "stop")
|
|
s.removeSequence(i, &pieces, "stop")
|
|
@@ -356,7 +364,12 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
|
|
}
|
|
}
|
|
|
|
|
|
for _, p := range pieces[i] {
|
|
for _, p := range pieces[i] {
|
|
- seq.responses <- p
|
|
|
|
|
|
+ select {
|
|
|
|
+ case seq.responses <- p:
|
|
|
|
+ case <-seq.quit:
|
|
|
|
+ s.removeSequence(i, &pieces, "connection")
|
|
|
|
+ break
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
pieces[i] = []string{}
|
|
pieces[i] = []string{}
|
|
@@ -475,12 +488,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|
Content: content,
|
|
Content: content,
|
|
}); err != nil {
|
|
}); err != nil {
|
|
log.Println("Failed to encode result:", err)
|
|
log.Println("Failed to encode result:", err)
|
|
|
|
+ close(seq.quit)
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
if !ok {
|
|
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
|
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
|
|
|
+ close(seq.quit)
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|