|
@@ -52,6 +52,10 @@ type Sequence struct {
|
|
// input cache being used by this sequence
|
|
// input cache being used by this sequence
|
|
cache *InputCacheSlot
|
|
cache *InputCacheSlot
|
|
|
|
|
|
|
|
+ // whether this sequence requires cross-attention layers to be processed, i.e. whether we have seen
|
|
|
|
+ // an image for certain multi-modal models
|
|
|
|
+ crossAttention bool
|
|
|
|
+
|
|
// channel to send responses over
|
|
// channel to send responses over
|
|
responses chan string
|
|
responses chan string
|
|
|
|
|
|
@@ -287,7 +291,6 @@ func flushPending(seq *Sequence) bool {
|
|
func (s *Server) removeSequence(seqIndex int, reason string) {
|
|
func (s *Server) removeSequence(seqIndex int, reason string) {
|
|
seq := s.seqs[seqIndex]
|
|
seq := s.seqs[seqIndex]
|
|
|
|
|
|
- s.lc.SetCrossAttention(false)
|
|
|
|
flushPending(seq)
|
|
flushPending(seq)
|
|
seq.doneReason = reason
|
|
seq.doneReason = reason
|
|
close(seq.responses)
|
|
close(seq.responses)
|
|
@@ -334,6 +337,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|
defer s.mu.Unlock()
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
var batch *llama.Batch
|
|
var batch *llama.Batch
|
|
|
|
+ crossAttention := false
|
|
|
|
|
|
seqIdx := s.nextSeq - 1
|
|
seqIdx := s.nextSeq - 1
|
|
for range s.seqs {
|
|
for range s.seqs {
|
|
@@ -367,8 +371,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|
batch = tokenBatch
|
|
batch = tokenBatch
|
|
} else {
|
|
} else {
|
|
batch = embedBatch
|
|
batch = embedBatch
|
|
|
|
+ seq.crossAttention = s.image.NeedCrossAttention(input)
|
|
}
|
|
}
|
|
- } else if embedding != batch.IsEmbedding() {
|
|
|
|
|
|
+ } else if embedding != batch.IsEmbedding() || crossAttention != seq.crossAttention {
|
|
s.nextSeq = seqIdx
|
|
s.nextSeq = seqIdx
|
|
break
|
|
break
|
|
}
|
|
}
|
|
@@ -378,6 +383,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|
break
|
|
break
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ crossAttention = seq.crossAttention
|
|
batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
|
|
batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
|
|
seq.numPast++
|
|
seq.numPast++
|
|
numInputsProcessed++
|
|
numInputsProcessed++
|
|
@@ -394,6 +400,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ s.lc.SetCrossAttention(crossAttention)
|
|
|
|
+
|
|
err := s.lc.Decode(batch)
|
|
err := s.lc.Decode(batch)
|
|
if err != nil {
|
|
if err != nil {
|
|
slog.Error("failed to decode batch", "error", err)
|
|
slog.Error("failed to decode batch", "error", err)
|
|
@@ -605,13 +613,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|
s.mu.Lock()
|
|
s.mu.Lock()
|
|
for i, sq := range s.seqs {
|
|
for i, sq := range s.seqs {
|
|
if sq == nil {
|
|
if sq == nil {
|
|
- for _, input := range seq.inputs {
|
|
|
|
- if input.embed != nil {
|
|
|
|
- s.lc.SetCrossAttention(true)
|
|
|
|
- break
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
|
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
|
if err != nil {
|
|
if err != nil {
|
|
s.mu.Unlock()
|
|
s.mu.Unlock()
|
|
@@ -619,6 +620,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ seq.crossAttention = s.image.NeedCrossAttention(seq.cache.Inputs...)
|
|
|
|
+
|
|
s.seqs[i] = seq
|
|
s.seqs[i] = seq
|
|
s.cond.Signal()
|
|
s.cond.Signal()
|
|
break
|
|
break
|