@@ -91,7 +91,7 @@ import (
 	"github.com/jmorganca/ollama/api"
 )
 
-type llama struct {
+type LLM struct {
 	params *C.struct_llama_context_params
 	model  *C.struct_llama_model
 	ctx    *C.struct_llama_context
@@ -99,12 +99,12 @@ type llama struct {
 	api.Options
 }
 
-func New(model string, opts api.Options) (*llama, error) {
+func New(model string, opts api.Options) (*LLM, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
 
-	llm := llama{Options: opts}
+	llm := LLM{Options: opts}
 
 	C.llama_backend_init(C.bool(llm.UseNUMA))
 
@@ -144,14 +144,14 @@ func New(model string, opts api.Options) (*llama, error) {
 	return &llm, nil
 }
 
-func (llm *llama) Close() {
+func (llm *LLM) Close() {
 	defer C.llama_free_model(llm.model)
 	defer C.llama_free(llm.ctx)
 
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
 	if input := llm.tokenize(prompt); input != nil {
 		embd := make([]C.llama_token, len(ctx))
 		for i := range ctx {
@@ -164,7 +164,7 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
 	return errors.New("llama: tokenize")
 }
 
-func (llm *llama) tokenize(prompt string) []C.llama_token {
+func (llm *LLM) tokenize(prompt string) []C.llama_token {
 	cPrompt := C.CString(prompt)
 	defer C.free(unsafe.Pointer(cPrompt))
 
@@ -176,7 +176,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token {
 	return nil
 }
 
-func (llm *llama) detokenize(tokens ...C.llama_token) string {
+func (llm *LLM) detokenize(tokens ...C.llama_token) string {
 	var sb strings.Builder
 	for _, token := range tokens {
 		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
@@ -185,7 +185,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
 	return sb.String()
 }
 
-func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
+func (llm *LLM) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
 	var opts C.struct_llama_sample_options
 	opts.repeat_penalty = C.float(llm.RepeatPenalty)
 	opts.frequency_penalty = C.float(llm.FrequencyPenalty)
@@ -256,7 +256,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 	return nil
 }
 
-func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
+func (llm *LLM) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
 	numVocab := int(C.llama_n_vocab(llm.ctx))
 	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
 
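
For reference, a minimal sketch of how the newly exported LLM type could be driven by a caller in the same package. It only uses the entry points shown in the diff (New, Predict, Close); the exampleUsage wrapper, the model path, and the use of api.DefaultOptions and the GenerateResponse.Response field are illustrative assumptions, not part of this change.

package llama

import (
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

// exampleUsage is a hypothetical caller, shown only to illustrate the
// exported surface after the rename from llama to LLM.
func exampleUsage() {
	// New stats the model path first and returns an error if it is missing;
	// the path below is made up for the example.
	llm, err := New("/path/to/model.bin", api.DefaultOptions())
	if err != nil {
		log.Fatal(err)
	}
	// Close frees the llama.cpp model and context and prints timings.
	defer llm.Close()

	// Predict streams generated tokens back through the callback; a nil
	// context slice assumes no prior conversation tokens are carried over.
	err = llm.Predict(nil, "Why is the sky blue?", func(r api.GenerateResponse) {
		fmt.Print(r.Response)
	})
	if err != nil {
		log.Fatal(err)
	}
}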