|
@@ -99,26 +99,24 @@ func (c *Context) Model() *Model {
|
|
return &Model{c: C.llama_get_model(c.c)}
|
|
return &Model{c: C.llama_get_model(c.c)}
|
|
}
|
|
}
|
|
|
|
|
|
-// TODO: break this up
|
|
|
|
-func (c *Context) SampleTokenGreedy(batch Batch, i int) int {
|
|
|
|
- nv := c.Model().NumVocab()
|
|
|
|
|
|
+func (c *Context) GetLogitsIth(i int) []float32 {
|
|
|
|
+ return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int(i)))), c.Model().NumVocab())
|
|
|
|
+}
|
|
|
|
|
|
- // TODO(jmorganca): split this up into different functions
|
|
|
|
- candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
|
|
|
|
|
|
+func (c *Context) SampleTokenGreedy(logits []float32) int {
|
|
|
|
+ candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(len(logits)) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
|
|
defer C.free(unsafe.Pointer(candidates))
|
|
defer C.free(unsafe.Pointer(candidates))
|
|
|
|
|
|
- // get most recent logits
|
|
|
|
- logits := C.llama_get_logits_ith(c.c, C.int(i))
|
|
|
|
- for i := 0; i < int(nv); i++ {
|
|
|
|
|
|
+ for i, logit := range logits {
|
|
ptr := (*C.struct_llama_token_data)(unsafe.Pointer(uintptr(unsafe.Pointer(candidates)) + uintptr(i)*unsafe.Sizeof(C.struct_llama_token_data{})))
|
|
ptr := (*C.struct_llama_token_data)(unsafe.Pointer(uintptr(unsafe.Pointer(candidates)) + uintptr(i)*unsafe.Sizeof(C.struct_llama_token_data{})))
|
|
ptr.id = C.int(i)
|
|
ptr.id = C.int(i)
|
|
- ptr.logit = unsafe.Slice(logits, nv)[i]
|
|
|
|
|
|
+ ptr.logit = C.float(logit)
|
|
ptr.p = 0.0
|
|
ptr.p = 0.0
|
|
}
|
|
}
|
|
|
|
|
|
return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
|
|
return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
|
|
data: candidates,
|
|
data: candidates,
|
|
- size: C.size_t(nv),
|
|
|
|
|
|
+ size: C.size_t(len(logits)),
|
|
sorted: C.bool(false),
|
|
sorted: C.bool(false),
|
|
}))
|
|
}))
|
|
}
|
|
}
|
|
@@ -155,6 +153,8 @@ func (b *Batch) NumTokens() int {
|
|
return int(b.c.n_tokens)
|
|
return int(b.c.n_tokens)
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+// Add adds a token to the batch with the given position for the given
|
|
|
|
+// sequence ids, and optionally instructs to include logits.
|
|
func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
|
|
func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
|
|
unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
|
|
unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
|
|
unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
|
|
unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
|
|
@@ -179,12 +179,6 @@ func (b *Batch) Free() {
|
|
C.llama_batch_free(b.c)
|
|
C.llama_batch_free(b.c)
|
|
}
|
|
}
|
|
|
|
|
|
-// LLAMA_API struct llama_batch llama_batch_get_one(
|
|
|
|
-//
|
|
|
|
-// llama_token * tokens,
|
|
|
|
-// int32_t n_tokens,
|
|
|
|
-// llama_pos pos_0,
|
|
|
|
-// llama_seq_id seq_id);
|
|
|
|
func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
|
|
func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
|
|
return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
|
|
return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
|
|
}
|
|
}
|