
llama.go: Make batch memory allocation match configuration

Batch size defaults to 512 but is configurable. However, llama.go uses
a fixed-size buffer, causing crashes if the batch size is increased.
This changes the array size to follow the configuration.
Jesse Gross 8 months ago
commit ed19fad862
1 changed file with 12 additions and 7 deletions
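
To make the failure mode concrete outside of cgo, here is a minimal, hypothetical Go sketch (batch, newBatch, and add are illustrative stand-ins, not the code in this commit): a raw pointer into a buffer is re-sliced using the length the buffer was actually allocated with, instead of a hardcoded 512, so indexing stays within the real allocation.

package main

import (
	"fmt"
	"unsafe"
)

// Hypothetical stand-in for the C-allocated llama_batch buffers: the real
// code receives raw pointers from llama_batch_init via cgo. A Go allocation
// is used here only to make the slicing pattern runnable.
type batch struct {
	tokens    *int32 // first element of a buffer with batchSize entries
	batchSize int    // length the buffer was actually allocated with
	nTokens   int
}

func newBatch(nTokens int) batch {
	buf := make([]int32, nTokens)
	return batch{tokens: &buf[0], batchSize: nTokens}
}

// add re-slices the raw pointer with the recorded batchSize, so the view
// always matches the allocation. With a hardcoded 512, this line panics
// once nTokens reaches 512 even when the buffer is large enough, and it
// could write past the end of a buffer allocated smaller than 512.
func (b *batch) add(token int32) {
	unsafe.Slice(b.tokens, b.batchSize)[b.nTokens] = token
	b.nTokens++
}

func main() {
	b := newBatch(1024) // batch size configured larger than the old fixed 512
	for i := int32(0); i < 1024; i++ {
		b.add(i)
	}
	fmt.Println("wrote", b.nTokens, "tokens")
}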

llama/llama.go  +12 -7

@@ -209,11 +209,15 @@ func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath
 }
 
 type Batch struct {
-	c C.struct_llama_batch
+	c         C.struct_llama_batch
+	batchSize int
 }
 
 func NewBatch(nTokens int, embd int, maxSeq int) Batch {
-	return Batch{c: C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq))}
+	return Batch{
+		c:         C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq)),
+		batchSize: nTokens,
+	}
 }
 
 func (b *Batch) NumTokens() int {
@@ -223,16 +227,16 @@ func (b *Batch) NumTokens() int {
 // Add adds a token to the batch with the given position for the given
 // sequence ids, and optionally instructs to include logits.
 func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
-	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
-	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
-	unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
+	unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
+	unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
+	unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
 
 	for i, s := range seqIds {
-		unsafe.Slice((unsafe.Slice(b.c.seq_id, 512)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
+		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
 	}
 
 	if logits {
-		unsafe.Slice(b.c.logits, 512)[b.c.n_tokens] = 1
+		unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
 	}
 
 	b.c.n_tokens += 1
@@ -243,6 +247,7 @@ func (b *Batch) Clear() {
 }
 
 func (b *Batch) Free() {
+	b.batchSize = 0
 	C.llama_batch_free(b.c)
 }
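
With the batch size recorded on the struct, callers can size the batch from their configuration rather than assuming 512. A hedged usage sketch (batchSize, numParallel, tokens, and seq are illustrative names, not code from this repository):

// batchSize and numParallel come from the runner configuration.
batch := llama.NewBatch(batchSize, 0, numParallel)
defer batch.Free()

for i, tok := range tokens {
	// Add can now index up to batchSize entries instead of the fixed 512.
	batch.Add(tok, i, []int{seq}, i == len(tokens)-1)
}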