runner.go: Only allocate 1 element embedding batches for mllama

Mllama has large embeddings (100 MB per image) and each embedding is
represented as 1 token when passed to llama.cpp. Batches are pre-
allocated for the size of the tokens times the batch size, so this
results in allocations of over 50 GB at the default batch size.
On some systems, these mallocs will fail.

Since an image is represented as a single token and mllama doesn't
support more than 1 image per request, we only need to allocate a
batch size of 1, which is much more reasonable. In addition, for
non-multimodal models, we don't need to allocate the embedding
batches at all.

Fixes #7464
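
For a rough sense of the arithmetic (a minimal sketch; the ~100 MB figure comes from the message above, and the batch size of 512 is an assumed default used only for illustration):

package main

import "fmt"

func main() {
	const embeddingBytes = 100 << 20 // ~100 MB per mllama image embedding (from the text above)
	const assumedBatchSize = 512     // assumed default batch size, for illustration only

	// Old behavior: the embedding batch reserves space for every slot up front.
	fmt.Printf("old: ~%d GB\n", int64(assumedBatchSize)*embeddingBytes>>30)
	// New behavior for mllama: one slot is enough, since one image is one token.
	fmt.Printf("new: ~%d MB\n", int64(1)*embeddingBytes>>20)
}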
Jesse Gross, 6 months ago
Commit a103dae01e
3 changed files, 54 insertions and 21 deletions
  1. llama/llama.go (+24, -14)
  2. llama/runner/image.go (+17, -0)
  3. llama/runner/runner.go (+13, -7)

+ 24 - 14
llama/llama.go

@@ -315,20 +315,30 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
 type Batch struct {
 	c         C.struct_llama_batch
 	batchSize int
+	maxSeq    int
 	embedSize int
 }
 
-// Creates a new batch for either word tokens if embed is 0 or
-// image embeddings if embed is specified. Batches cannot contain
-// both types at the same time
-func NewBatch(nTokens int, embed int, maxSeq int) *Batch {
+// Creates a new batch for either word tokens or image embeddings (if embedSize is non-zero).
+// Batches cannot contain both types at the same time. batchSize is the maximum number of entries
+// that can be added per sequence
+func NewBatch(batchSize int, maxSeq int, embedSize int) *Batch {
 	return &Batch{
-		c:         C.llama_batch_init(C.int(nTokens), C.int(embed), C.int(maxSeq)),
-		batchSize: nTokens,
-		embedSize: embed,
+		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
+		batchSize: batchSize,
+		maxSeq:    maxSeq,
+		embedSize: embedSize,
 	}
 }
 
+func (b *Batch) Size() int {
+	return b.batchSize
+}
+
+func (b *Batch) allocSize() int {
+	return b.batchSize * b.maxSeq
+}
+
 func (b *Batch) NumTokens() int {
 	return int(b.c.n_tokens)
 }
@@ -341,21 +351,21 @@ func (b *Batch) IsEmbedding() bool {
 // when the batch was initialized. The other argument will be ignored. Adds to the
 // batch with the given position for the given sequence ids, and optionally instructs
 // to include logits.
-func (b *Batch) Add(token int, embed []float32, pos int, seqIds []int, logits bool) {
+func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
 	if !b.IsEmbedding() {
-		unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
+		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
 	} else {
-		copy(unsafe.Slice((*float32)(b.c.embd), b.batchSize*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
+		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
 	}
-	unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
-	unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
+	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
+	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))
 
 	for i, s := range seqIds {
-		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
+		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
 	}
 
 	if logits {
-		unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
+		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
 	}
 
 	b.c.n_tokens += 1
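
As a quick illustration of the new calling convention (a minimal sketch, not code from this commit; the token, position, and sequence id values are placeholders, and the import path assumes the ollama module layout):

package main

import "github.com/ollama/ollama/llama"

// fillBatches sketches how callers are expected to use the reworked Batch API.
func fillBatches(embedSize int, imageEmbed []float32) {
	// Token batch: up to 512 entries per sequence, 2 sequences, no embedding storage.
	tokenBatch := llama.NewBatch(512, 2, 0)
	defer tokenBatch.Free()

	// logits now comes before the sequence ids, which are variadic.
	tokenBatch.Add(42 /* token */, nil, 0 /* pos */, true /* logits */, 0 /* seq id */)

	// Embedding batch: with mllama a single entry suffices, because one image
	// is exactly one (very large) embedding token.
	embedBatch := llama.NewBatch(1, 2, embedSize)
	defer embedBatch.Free()
	embedBatch.Add(0, imageEmbed, 0, false, 0)
}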

+ 17 - 0
llama/runner/image.go

@@ -89,6 +89,23 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspect
 	return embed
 }
 
+func (c *ImageContext) BatchSize(configuredBatchSize int) int {
+	// If images are not supported, we don't need to allocate embedding batches
+	if c == nil {
+		return 0
+	}
+
+	// Mllama maps an image to 1 embedding token (llava creates many tokens)
+	// and doesn't support more than a single image per request.
+	// The embeddings are large (100 MB), so allocating a big batch can fail
+	// on some systems
+	if c.mllama != nil {
+		return 1
+	}
+
+	return configuredBatchSize
+}
+
 func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
 	if c != nil && c.mllama != nil {
 		return c.mllama.EmbedSize(llamaContext)
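
The nil case works because Go allows method calls on a nil pointer receiver as long as the method does not dereference it; a small self-contained analog of the pattern (names here are illustrative, not from the repository):

package main

import "fmt"

type imageContext struct{ mllama bool }

// batchSize mirrors the logic above: a nil receiver means no image support at all.
func (c *imageContext) batchSize(configured int) int {
	if c == nil {
		return 0 // no embedding batches needed
	}
	if c.mllama {
		return 1 // one image == one embedding token
	}
	return configured
}

func main() {
	var none *imageContext
	fmt.Println(none.batchSize(512))                          // 0
	fmt.Println((&imageContext{mllama: true}).batchSize(512)) // 1
	fmt.Println((&imageContext{}).batchSize(512))             // 512
}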

+ 13 - 7
llama/runner/runner.go

@@ -211,6 +211,7 @@ type Server struct {
 	// required for image embeddings
 	image *ImageContext
 
+	// TODO (jmorganca): make this n_batch
 	batchSize int
 
 	// parallel is the number of parallel requests to handle
@@ -302,13 +303,19 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
 func (s *Server) run(ctx context.Context) {
 	s.ready.Wait()
 
-	// logically these batches are used only within the context of processBatch
+	// Logically these batches are used only within the context of processBatch
 	// but it is better for performance to allocate them once here
-	tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
+	tokenBatch := llama.NewBatch(s.batchSize, len(s.seqs), 0)
 	defer tokenBatch.Free()
 
-	embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.image.EmbedSize(s.lc), len(s.seqs))
-	defer embedBatch.Free()
+	var embedBatch *llama.Batch
+	embedBatchSize := s.image.BatchSize(s.batchSize)
+	if embedBatchSize != 0 {
+		embedBatch = llama.NewBatch(embedBatchSize, len(s.seqs), s.image.EmbedSize(s.lc))
+		defer embedBatch.Free()
+	} else {
+		embedBatch = &llama.Batch{}
+	}
 
 	for {
 		select {
@@ -378,13 +385,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				break
 			}
 
-			// todo: make this n_batch
-			if i >= s.batchSize {
+			if i >= batch.Size() {
 				break
 			}
 
 			crossAttention = seq.crossAttention
-			batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
+			batch.Add(input.token, input.embed, seq.numPast, numInputsProcessed+1 == len(seq.inputs), seq.cache.Id)
 			seq.numPast++
 			numInputsProcessed++
 		}
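
Not shown in this hunk is the earlier part of processBatch that decides which batch an input goes into. Roughly (a paraphrased sketch, not the exact upstream code), the selection looks like the snippet below; this is why a 1-entry embedBatch is safe: an mllama request contributes at most one embedding input, and the i >= batch.Size() guard above simply defers anything that does not fit to the next call.

package main

import "github.com/ollama/ollama/llama"

// pickBatch paraphrases (not verbatim from runner.go) how processBatch chooses
// between the two pre-allocated batches: embedding inputs and token inputs are
// never mixed, so the type of the current input selects the batch. For mllama
// the embedding batch holds a single entry, which is enough because a request
// carries at most one image.
func pickBatch(tokenBatch, embedBatch *llama.Batch, embed []float32) *llama.Batch {
	if embed == nil {
		return tokenBatch
	}
	return embedBatch
}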