@@ -79,9 +79,11 @@ llama_token llama_sample(
 import "C"
 import (
 	"errors"
+	"fmt"
 	"io"
 	"os"
 	"strings"
+	"time"
 	"unsafe"
 
 	"github.com/jmorganca/ollama/api"
@@ -147,7 +149,7 @@ func (llm *llama) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(prompt string, fn func(string)) error {
+func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
 	if tokens := llm.tokenize(prompt); tokens != nil {
 		return llm.generate(tokens, fn)
 	}
@@ -176,7 +178,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
 	return sb.String()
 }
 
-func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
+func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
 	var opts C.struct_llama_sample_options
 	opts.repeat_penalty = C.float(llm.RepeatPenalty)
 	opts.frequency_penalty = C.float(llm.FrequencyPenalty)
@@ -190,38 +192,58 @@ func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
 	opts.mirostat_tau = C.float(llm.MirostatTau)
 	opts.mirostat_eta = C.float(llm.MirostatEta)
 
-	pastTokens := deque[C.llama_token]{capacity: llm.RepeatLastN}
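+	// tokens generated so far; handed to the sampler so repetition penalties can see them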
+	output := deque[C.llama_token]{capacity: llm.NumCtx}
 
 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
-		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
+		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
 			return errors.New("llama: eval")
 		}
 
-		token, err := llm.sample(pastTokens, &opts)
-		switch {
-		case errors.Is(err, io.EOF):
-			return nil
-		case err != nil:
+		token, err := llm.sample(output, &opts)
+		if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
 			return err
 		}
 
-		fn(llm.detokenize(token))
+		// call the callback
+		fn(api.GenerateResponse{
+			Response: llm.detokenize(token),
+		})
+
+		output.PushLeft(token)
+
+		input = []C.llama_token{token}
+	}
 
-		tokens = []C.llama_token{token}
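+	// llama.cpp reports timings as float milliseconds; convert them to time.Duration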
+	dur := func(ms float64) time.Duration {
+		d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
+		if err != nil {
+			panic(err)
+		}
 
-		pastTokens.PushLeft(token)
+		return d
 	}
 
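+	// once generation stops, emit a final response carrying Done and the timing metrics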
+	timings := C.llama_get_timings(llm.ctx)
+	fn(api.GenerateResponse{
+		Done:               true,
+		PromptEvalCount:    int(timings.n_p_eval),
+		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
+		EvalCount:          int(timings.n_eval),
+		EvalDuration:       dur(float64(timings.t_eval_ms)),
+	})
+
 	return nil
 }
 
-func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
+func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
 	numVocab := int(C.llama_n_vocab(llm.ctx))
 	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
 
-	candidates := make([]C.struct_llama_token_data, 0, numVocab)
-	for i := 0; i < numVocab; i++ {
-		candidates = append(candidates, C.llama_token_data{
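+	// one candidate entry per vocabulary id, seeded from the current logits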
+	candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
+	for i := 0; i < candidates.Cap(); i++ {
+		candidates.PushLeft(C.struct_llama_token_data{
 			id:    C.int(i),
 			logit: logits[i],
 			p:     0,
@@ -230,8 +252,8 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
 	token := C.llama_sample(
 		llm.ctx,
-		unsafe.SliceData(candidates), C.ulong(len(candidates)),
-		unsafe.SliceData(pastTokens.Data()), C.ulong(pastTokens.Len()),
+		unsafe.SliceData(candidates.Data()), C.ulong(candidates.Len()),
+		unsafe.SliceData(output.Data()), C.ulong(output.Len()),
 		opts)
 	if token != C.llama_token_eos() {
 		return token, nil