|
@@ -51,7 +51,7 @@ type Sequence struct {
|
|
pendingInputs []input
|
|
pendingInputs []input
|
|
|
|
|
|
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
|
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
|
- pendingResponses []string
|
|
|
|
|
|
+ pendingResponses []common.CompletionResponse
|
|
|
|
|
|
// input cache being used by this sequence
|
|
// input cache being used by this sequence
|
|
cache *InputCacheSlot
|
|
cache *InputCacheSlot
|
|
@@ -61,7 +61,7 @@ type Sequence struct {
|
|
crossAttention bool
|
|
crossAttention bool
|
|
|
|
|
|
// channel to send responses over
|
|
// channel to send responses over
|
|
- responses chan string
|
|
|
|
|
|
+ responses chan common.CompletionResponse
|
|
|
|
|
|
// channel to stop decoding (such as if the remote connection is closed)
|
|
// channel to stop decoding (such as if the remote connection is closed)
|
|
quit chan bool
|
|
quit chan bool
|
|
@@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|
numPromptInputs: len(inputs),
|
|
numPromptInputs: len(inputs),
|
|
startProcessingTime: startTime,
|
|
startProcessingTime: startTime,
|
|
numPredict: params.numPredict,
|
|
numPredict: params.numPredict,
|
|
- pendingResponses: make([]string, 0),
|
|
|
|
- responses: make(chan string, 100),
|
|
|
|
|
|
+ pendingResponses: make([]common.CompletionResponse, 0),
|
|
|
|
+ responses: make(chan common.CompletionResponse, 100),
|
|
quit: make(chan bool, 1),
|
|
quit: make(chan bool, 1),
|
|
embedding: make(chan []float32, 1),
|
|
embedding: make(chan []float32, 1),
|
|
samplingCtx: sc,
|
|
samplingCtx: sc,
|
|
@@ -276,29 +276,28 @@ func (s *Server) allNil() bool {
|
|
}
|
|
}
|
|
|
|
|
|
func flushPending(seq *Sequence) bool {
|
|
func flushPending(seq *Sequence) bool {
|
|
- joined := strings.Join(seq.pendingResponses, "")
|
|
|
|
- seq.pendingResponses = []string{}
|
|
|
|
-
|
|
|
|
- // Check if there are any partial UTF-8 characters remaining.
|
|
|
|
- // We already check and queue as we are generating but some may
|
|
|
|
- // still make it here:
|
|
|
|
- // - Sequence is ending, e.g. generation limit has been hit
|
|
|
|
- // - Invalid characters in the middle of a string
|
|
|
|
- // This is a stricter check to ensure we never output invalid Unicode.
|
|
|
|
- for !utf8.ValidString(joined) {
|
|
|
|
- joined = joined[:len(joined)-1]
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if len(joined) == 0 {
|
|
|
|
- return true
|
|
|
|
- }
|
|
|
|
|
|
+ pending := seq.pendingResponses
|
|
|
|
+ seq.pendingResponses = []common.CompletionResponse{}
|
|
|
|
+
|
|
|
|
+ for i, r := range pending {
|
|
|
|
+ if i == len(pending)-1 {
|
|
|
|
+ // Check and trim any trailing partial UTF-8 characters
|
|
|
|
+ content := r.Content
|
|
|
|
+ for !utf8.ValidString(content) {
|
|
|
|
+ content = content[:len(content)-1]
|
|
|
|
+ }
|
|
|
|
+ r.Content = content
|
|
|
|
+ }
|
|
|
|
|
|
- select {
|
|
|
|
- case seq.responses <- joined:
|
|
|
|
- return true
|
|
|
|
- case <-seq.quit:
|
|
|
|
- return false
|
|
|
|
|
|
+ select {
|
|
|
|
+ case seq.responses <- r:
|
|
|
|
+ return true
|
|
|
|
+ case <-seq.quit:
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
+ // no pending responses to send
|
|
|
|
+ return true
|
|
}
|
|
}
|
|
|
|
|
|
func (s *Server) removeSequence(seqIndex int, reason string) {
|
|
func (s *Server) removeSequence(seqIndex int, reason string) {
|
|
@@ -497,8 +496,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|
|
|
|
|
seq.inputs = []input{{token: token}}
|
|
seq.inputs = []input{{token: token}}
|
|
|
|
|
|
- seq.pendingResponses = append(seq.pendingResponses, piece)
|
|
|
|
- sequence := strings.Join(seq.pendingResponses, "")
|
|
|
|
|
|
+ seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
|
|
|
|
+ sequence := ""
|
|
|
|
+ for _, r := range seq.pendingResponses {
|
|
|
|
+ sequence += r.Content
|
|
|
|
+ }
|
|
|
|
|
|
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
|
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
|
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
|
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|