@@ -315,20 +315,30 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
 type Batch struct {
 	c         C.struct_llama_batch
 	batchSize int
+	maxSeq    int
 	embedSize int
 }
 
-// Creates a new batch for either word tokens if embed is 0 or
-// image embeddings if embed is specified. Batches cannot contain
-// both types at the same time
-func NewBatch(nTokens int, embed int, maxSeq int) *Batch {
+// Creates a new batch for either word tokens or image embeddings (if embedSize is non-zero).
+// Batches cannot contain both types at the same time. batchSize is the maximum number of entries
+// that can be added per sequence
+func NewBatch(batchSize int, maxSeq int, embedSize int) *Batch {
 	return &Batch{
-		c:         C.llama_batch_init(C.int(nTokens), C.int(embed), C.int(maxSeq)),
-		batchSize: nTokens,
-		embedSize: embed,
+		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
+		batchSize: batchSize,
+		maxSeq:    maxSeq,
+		embedSize: embedSize,
 	}
 }
 
+func (b *Batch) Size() int {
+	return b.batchSize
+}
+
+func (b *Batch) allocSize() int {
+	return b.batchSize * b.maxSeq
+}
+
 func (b *Batch) NumTokens() int {
 	return int(b.c.n_tokens)
 }
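The core of this hunk is the allocation size: llama_batch_init now receives batchSize*maxSeq entries rather than batchSize, so each of up to maxSeq sequences can contribute up to batchSize entries. A minimal sketch of the arithmetic, assuming it runs inside this package (allocSize is unexported) with import "fmt"; the sizes are illustrative values, not from the diff:

	// Illustrative sketch of the new capacity split between callers
	// (Size) and the underlying C allocation (allocSize).
	func sketchBatchCapacity() {
		// 512 entries per sequence, 4 sequences, token mode (embedSize == 0).
		b := NewBatch(512, 4, 0)

		fmt.Println(b.Size())      // 512: per-sequence capacity exposed to callers
		fmt.Println(b.allocSize()) // 2048: 512*4 entries handed to C.llama_batch_init
	}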
@@ -341,21 +351,21 @@ func (b *Batch) IsEmbedding() bool {
 // when the batch was initialized. The other argument will be ignored. Adds to the
 // batch with the given position for the given sequence ids, and optionally instructs
 // to include logits.
-func (b *Batch) Add(token int, embed []float32, pos int, seqIds []int, logits bool) {
+func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
 	if !b.IsEmbedding() {
-		unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
+		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
 	} else {
-		copy(unsafe.Slice((*float32)(b.c.embd), b.batchSize*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
+		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
 	}
-	unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
-	unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
+	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
+	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))
 
 	for i, s := range seqIds {
-		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
+		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
 	}
 
 	if logits {
-		unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
+		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
 	}
 
 	b.c.n_tokens += 1
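Besides switching every slice bound from batchSize to allocSize(), this hunk moves seqIds to a variadic trailing parameter, which tidies the common single-sequence call. A caller-side sketch, assuming the same package; the token values and sequence id are made up for the example:

	// Illustrative sketch of a call site after this change.
	batch := NewBatch(512, 1, 0) // token mode, one sequence

	tokens := []int{15339, 1917, 13} // example token ids, not from this diff
	for i, tok := range tokens {
		// embed is nil in token mode; request logits only for the last token.
		// The trailing 0 is the sequence id, passed variadically instead of as []int.
		batch.Add(tok, nil, i, i == len(tokens)-1, 0)
	}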