|
@@ -61,7 +61,7 @@ type Sequence struct {
|
|
crossAttention bool
|
|
crossAttention bool
|
|
|
|
|
|
// channel to send responses over
|
|
// channel to send responses over
|
|
- responses chan string
|
|
|
|
|
|
+ responses chan CompletionResponse
|
|
|
|
|
|
// channel to stop decoding (such as if the remote connection is closed)
|
|
// channel to stop decoding (such as if the remote connection is closed)
|
|
quit chan bool
|
|
quit chan bool
|
|
@@ -153,7 +153,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
|
startProcessingTime: startTime,
|
|
startProcessingTime: startTime,
|
|
numPredict: params.numPredict,
|
|
numPredict: params.numPredict,
|
|
pendingResponses: make([]string, 0),
|
|
pendingResponses: make([]string, 0),
|
|
- responses: make(chan string, 100),
|
|
|
|
|
|
+ responses: make(chan CompletionResponse, 100),
|
|
quit: make(chan bool, 1),
|
|
quit: make(chan bool, 1),
|
|
embedding: make(chan []float32, 1),
|
|
embedding: make(chan []float32, 1),
|
|
samplingCtx: sc,
|
|
samplingCtx: sc,
|
|
@@ -281,7 +281,7 @@ func flushPending(seq *Sequence) bool {
|
|
if len(seq.pendingResponses) == 0 {
|
|
if len(seq.pendingResponses) == 0 {
|
|
return true
|
|
return true
|
|
}
|
|
}
|
|
- joined := strings.Join(seq.pendingResponses, "")
|
|
|
|
|
|
+ content := strings.Join(seq.pendingResponses, "")
|
|
seq.pendingResponses = []string{}
|
|
seq.pendingResponses = []string{}
|
|
|
|
|
|
// Check if there are any partial UTF-8 characters remaining.
|
|
// Check if there are any partial UTF-8 characters remaining.
|
|
@@ -290,8 +290,8 @@ func flushPending(seq *Sequence) bool {
|
|
// - Sequence is ending, e.g. generation limit has been hit
|
|
// - Sequence is ending, e.g. generation limit has been hit
|
|
// - Invalid characters in the middle of a string
|
|
// - Invalid characters in the middle of a string
|
|
// This is a stricter check to ensure we never output invalid Unicode.
|
|
// This is a stricter check to ensure we never output invalid Unicode.
|
|
- for !utf8.ValidString(joined) {
|
|
|
|
- joined = joined[:len(joined)-1]
|
|
|
|
|
|
+ for !utf8.ValidString(content) {
|
|
|
|
+ content = content[:len(content)-1]
|
|
}
|
|
}
|
|
|
|
|
|
// Add logits if requested and available
|
|
// Add logits if requested and available
|
|
@@ -302,7 +302,9 @@ func flushPending(seq *Sequence) bool {
|
|
}
|
|
}
|
|
|
|
|
|
select {
|
|
select {
|
|
- case seq.responses <- joined:
|
|
|
|
|
|
+ case seq.responses <- CompletionResponse{
|
|
|
|
+ Content: content,
|
|
|
|
+ }:
|
|
return true
|
|
return true
|
|
case <-seq.quit:
|
|
case <-seq.quit:
|
|
return false
|
|
return false
|
|
@@ -755,11 +757,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|
case <-r.Context().Done():
|
|
case <-r.Context().Done():
|
|
close(seq.quit)
|
|
close(seq.quit)
|
|
return
|
|
return
|
|
- case content, ok := <-seq.responses:
|
|
|
|
|
|
+ case resp, ok := <-seq.responses:
|
|
if ok {
|
|
if ok {
|
|
- if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
|
|
|
- Content: content,
|
|
|
|
- }); err != nil {
|
|
|
|
|
|
+ if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
|
close(seq.quit)
|
|
close(seq.quit)
|
|
return
|
|
return
|