|
@@ -209,11 +209,15 @@ func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath
|
|
}
|
|
}
|
|
|
|
|
|
type Batch struct {
|
|
type Batch struct {
|
|
- c C.struct_llama_batch
|
|
|
|
|
|
+ c C.struct_llama_batch
|
|
|
|
+ batchSize int
|
|
}
|
|
}
|
|
|
|
|
|
func NewBatch(nTokens int, embd int, maxSeq int) Batch {
|
|
func NewBatch(nTokens int, embd int, maxSeq int) Batch {
|
|
- return Batch{c: C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq))}
|
|
|
|
|
|
+ return Batch{
|
|
|
|
+ c: C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq)),
|
|
|
|
+ batchSize: nTokens,
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
func (b *Batch) NumTokens() int {
|
|
func (b *Batch) NumTokens() int {
|
|
@@ -223,16 +227,16 @@ func (b *Batch) NumTokens() int {
|
|
// Add adds a token to the batch with the given position for the given
|
|
// Add adds a token to the batch with the given position for the given
|
|
// sequence ids, and optionally instructs to include logits.
|
|
// sequence ids, and optionally instructs to include logits.
|
|
func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
|
|
func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
|
|
- unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
|
|
|
|
- unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
|
|
|
|
- unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
|
|
|
|
|
|
+ unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
|
|
|
|
+ unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
|
|
|
|
+ unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
|
|
|
|
|
|
for i, s := range seqIds {
|
|
for i, s := range seqIds {
|
|
- unsafe.Slice((unsafe.Slice(b.c.seq_id, 512)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
|
|
|
|
|
|
+ unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
|
|
}
|
|
}
|
|
|
|
|
|
if logits {
|
|
if logits {
|
|
- unsafe.Slice(b.c.logits, 512)[b.c.n_tokens] = 1
|
|
|
|
|
|
+ unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
|
|
}
|
|
}
|
|
|
|
|
|
b.c.n_tokens += 1
|
|
b.c.n_tokens += 1
|
|
@@ -243,6 +247,7 @@ func (b *Batch) Clear() {
|
|
}
|
|
}
|
|
|
|
|
|
func (b *Batch) Free() {
|
|
func (b *Batch) Free() {
|
|
|
|
+ b.batchSize = 0
|
|
C.llama_batch_free(b.c)
|
|
C.llama_batch_free(b.c)
|
|
}
|
|
}
|
|
|
|
|