runner.go: Don't set cross attention before sending embeddings

Currently if an input has embeddings at any point then we will set
cross attention to true from the beginning. This means that any
tokens before the embeddings are sent will incorrectly have cross
attention layers applied.

This only sets cross attention when we have an embedding, either
previously in this sequence or in the cache. It also makes cross
attention capable of supporting parallelism at the runner level,
though the mllama implementation doesn't support that yet.
Jesse Gross, 6 months ago
commit 26acdcf44e
2 changed files with 23 additions and 9 deletions
  1. llama/runner/image.go (+11 -0)
  2. llama/runner/runner.go (+12 -9)
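
As the commit message describes, cross attention is no longer a context-wide switch flipped when a request arrives: each sequence carries its own flag, sequences whose flag disagrees with the batch being built are deferred to a later batch, and the flag is applied to the llama context once, immediately before the batch is decoded. Below is a minimal, self-contained Go sketch of that scheduling rule; every name in it (sequence, buildBatch, and so on) is an illustrative stand-in, not the runner's actual API.

package main

import "fmt"

// sequence is a stand-in for the runner's Sequence type, reduced to the one
// property that matters for this sketch.
type sequence struct {
	id             int
	crossAttention bool
}

// buildBatch collects sequences that agree on cross attention. The first
// sequence added fixes the mode for the whole batch; mismatching sequences
// are skipped and picked up by a later batch.
func buildBatch(seqs []sequence) (batch []sequence, crossAttention bool) {
	for _, seq := range seqs {
		if len(batch) == 0 {
			crossAttention = seq.crossAttention
		} else if seq.crossAttention != crossAttention {
			continue
		}
		batch = append(batch, seq)
	}
	return batch, crossAttention
}

func main() {
	seqs := []sequence{{id: 0}, {id: 1, crossAttention: true}, {id: 2}}
	batch, cross := buildBatch(seqs)

	// In the real runner this is where the cross-attention flag would be
	// applied, once per batch, right before decoding.
	fmt.Println(len(batch), cross) // prints "2 false": the image sequence waits for its own batch
}

Keeping a batch homogeneous matters because the cross-attention setting applies to the whole decode call, so mixing sequences that need it with ones that do not would reintroduce the original bug for the text-only sequences.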

llama/runner/image.go (+11 -0)

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"hash/maphash"
 	"log/slog"
+	"slices"
 	"sync"
 	"time"
 
@@ -96,6 +97,16 @@ func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
 	}
 }
 
+func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
+	if c == nil || c.mllama == nil {
+		return false
+	}
+
+	return slices.ContainsFunc(inputs, func(input input) bool {
+		return input.embed != nil
+	})
+}
+
 type imageCache struct {
 	key      uint64
 	val      [][]float32
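
A note on the helper added above: it is variadic and nil-safe, so the runner can pass a single new input or spread an entire cached slot (inputs...), and it returns false when there is no image context or the model is not mllama-based. The standalone sketch below shows the same slices.ContainsFunc pattern with a stand-in input type; it is illustrative only and not part of the patch.

package main

import (
	"fmt"
	"slices"
)

// input is a stand-in for the runner's input type: either a text token or an
// image embedding.
type input struct {
	token int
	embed []float32
}

// needCrossAttention mirrors the check in the patch: any input carrying an
// embedding means cross-attention layers are required.
func needCrossAttention(inputs ...input) bool {
	return slices.ContainsFunc(inputs, func(in input) bool {
		return in.embed != nil
	})
}

func main() {
	textOnly := []input{{token: 1}, {token: 2}}
	withImage := append(textOnly, input{embed: []float32{0.1, 0.2}})

	fmt.Println(needCrossAttention(textOnly...))  // false
	fmt.Println(needCrossAttention(withImage...)) // true
}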

llama/runner/runner.go (+12 -9)

@@ -52,6 +52,10 @@ type Sequence struct {
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
+	// does this sequence require cross-attention layers to be processed? - if we have seen
+	// an image for certain multi-modal models
+	crossAttention bool
+
 	// channel to send responses over
 	responses chan string
 
@@ -287,7 +291,6 @@ func flushPending(seq *Sequence) bool {
 func (s *Server) removeSequence(seqIndex int, reason string) {
 	seq := s.seqs[seqIndex]
 
-	s.lc.SetCrossAttention(false)
 	flushPending(seq)
 	seq.doneReason = reason
 	close(seq.responses)
@@ -334,6 +337,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	defer s.mu.Unlock()
 
 	var batch *llama.Batch
+	crossAttention := false
 
 	seqIdx := s.nextSeq - 1
 	for range s.seqs {
@@ -367,8 +371,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 					batch = tokenBatch
 				} else {
 					batch = embedBatch
+					seq.crossAttention = s.image.NeedCrossAttention(input)
 				}
-			} else if embedding != batch.IsEmbedding() {
+			} else if embedding != batch.IsEmbedding() || crossAttention != seq.crossAttention {
 				s.nextSeq = seqIdx
 				break
 			}
@@ -378,6 +383,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				break
 			}
 
+			crossAttention = seq.crossAttention
 			batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
 			seq.numPast++
 			numInputsProcessed++
@@ -394,6 +400,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		return
 	}
 
+	s.lc.SetCrossAttention(crossAttention)
+
 	err := s.lc.Decode(batch)
 	if err != nil {
 		slog.Error("failed to decode batch", "error", err)
@@ -605,13 +613,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
-			for _, input := range seq.inputs {
-				if input.embed != nil {
-					s.lc.SetCrossAttention(true)
-					break
-				}
-			}
-
 			seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
 			if err != nil {
 				s.mu.Unlock()
@@ -619,6 +620,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				return
 			}
 
+			seq.crossAttention = s.image.NeedCrossAttention(seq.cache.Inputs...)
+
 			s.seqs[i] = seq
 			s.cond.Signal()
 			break