@@ -149,9 +149,14 @@ func (llm *llama) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
-	if tokens := llm.tokenize(prompt); tokens != nil {
-		return llm.generate(tokens, fn)
+func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	if input := llm.tokenize(prompt); input != nil {
+		embd := make([]C.llama_token, len(ctx))
+		for i := range ctx {
+			embd[i] = C.llama_token(ctx[i])
+		}
+
+		return llm.generate(append(embd, input...), fn)
 	}
 
 	return errors.New("llama: tokenize")
@@ -194,6 +199,11 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 
 	output := deque[C.llama_token]{capacity: llm.NumCtx}
 
+	context := deque[int]{capacity: llm.NumCtx / 2}
+	for _, in := range input {
+		context.PushLeft(int(in))
+	}
+
 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
 		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
 			return errors.New("llama: eval")
@@ -212,6 +222,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 		})
 
 		output.PushLeft(token)
+		context.PushLeft(int(token))
 
 		input = []C.llama_token{token}
 	}
@@ -228,6 +239,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 	timings := C.llama_get_timings(llm.ctx)
 	fn(api.GenerateResponse{
 		Done:               true,
+		Context:            context.Data(),
 		PromptEvalCount:    int(timings.n_p_eval),
 		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
 		EvalCount:          int(timings.n_eval),
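The diff threads conversation state through the API: Predict now accepts the token ids from a previous response, converts them back to C.llama_token values, and evaluates them ahead of the newly tokenized prompt, while generate records the prompt and generated tokens in a bounded deque and hands the accumulated list back through the new Context field on the final (Done) response. A minimal caller-side sketch follows; continueConversation is a hypothetical helper, not part of this diff, and it is assumed to live in the same package as Predict so it needs no additional imports.

func continueConversation(llm *llama, prompts []string) error {
	// Token ids carried over from the previous turn; empty on the first call.
	var prior []int

	for _, prompt := range prompts {
		err := llm.Predict(prior, prompt, func(r api.GenerateResponse) {
			if r.Done {
				// The final callback carries the accumulated context; feeding it
				// into the next Predict call continues the same conversation
				// instead of starting from an empty history.
				prior = r.Context
			}
		})
		if err != nil {
			return err
		}
	}

	return nil
}

Each turn is then evaluated on top of the earlier turns' tokens, subject to the NumCtx window enforced inside generate.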