@@ -1,4 +1,4 @@
-package llama
+package llm
 
 /*
 #cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
@@ -105,7 +105,7 @@ import (
 //go:embed ggml-metal.metal
 var fs embed.FS
 
-type LLM struct {
+type llama struct {
 	params *C.struct_llama_context_params
 	model *C.struct_llama_model
 	ctx *C.struct_llama_context
@@ -120,12 +120,28 @@ type LLM struct {
 	api.Options
 }
 
-func New(model string, opts api.Options) (*LLM, error) {
+type llamaHyperparameters struct {
+	// NumVocab is the size of the model's vocabulary.
+	NumVocab uint32
+
+	// NumEmbd is the size of the model's embedding layer.
+	NumEmbd uint32
+	NumMult uint32
+	NumHead uint32
+
+	// NumLayer is the number of layers in the model.
+	NumLayer uint32
+	NumRot uint32
+	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
+	FileType
+}
+
+func newLlama(model string, opts api.Options) (*llama, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
 
-	llm := LLM{Options: opts}
+	llm := llama{Options: opts}
 
 	C.llama_backend_init(C.bool(llm.UseNUMA))
 
@@ -168,7 +184,7 @@ func New(model string, opts api.Options) (*LLM, error) {
 	return &llm, nil
 }
 
-func (llm *LLM) Close() {
+func (llm *llama) Close() {
 	llm.gc = true
 
 	llm.mu.Lock()
@@ -180,17 +196,16 @@ func (llm *LLM) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
+func (llm *llama) SetOptions(opts api.Options) {
+	llm.Options = opts
+}
+
 var errNeedMoreData = errors.New("need more data")
 
-func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
 	C.llama_reset_timings(llm.ctx)
 
-	tokens := make([]C.llama_token, len(ctx))
-	for i := range tokens {
-		tokens[i] = C.llama_token(ctx[i])
-	}
-
-	llm.marshalPrompt(tokens, prompt)
+	llm.marshalPrompt(ctx, prompt)
 
 	C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
 
@@ -205,7 +220,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 			return err
 		}
 
-		b.WriteString(llm.Decode(token))
+		b.WriteString(llm.Decode(int(token)))
 
 		if err := llm.checkStopConditions(b); err != nil {
 			if errors.Is(err, io.EOF) {
@@ -243,7 +258,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 	return nil
 }
 
-func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
+func (llm *llama) checkStopConditions(b bytes.Buffer) error {
 	for _, stopCondition := range llm.Stop {
 		if stopCondition == strings.TrimSpace(b.String()) {
 			return io.EOF
@@ -255,12 +270,17 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
 	return nil
 }
 
-func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
+func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {
 	tokens := append(ctx, llm.Encode(prompt)...)
 	if llm.NumKeep < 0 {
 		llm.NumKeep = len(tokens)
 	}
 
+	cTokens := make([]C.llama_token, len(tokens))
+	for i := range tokens {
+		cTokens[i] = C.llama_token(tokens[i])
+	}
+
 	// min(llm.NumCtx - 4, llm.NumKeep)
 	if llm.NumCtx-4 < llm.NumKeep {
 		llm.NumKeep = llm.NumCtx - 4
@@ -269,25 +289,25 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
 	if len(tokens) >= llm.NumCtx {
 		// truncate input
 		numLeft := (llm.NumCtx - llm.NumKeep) / 2
-		truncated := tokens[:llm.NumKeep]
-		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-		copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
+		truncated := cTokens[:llm.NumKeep]
+		erasedBlocks := (len(cTokens) - llm.NumKeep - numLeft - 1) / numLeft
+		truncated = append(truncated, cTokens[llm.NumKeep+erasedBlocks*numLeft:]...)
+		copy(llm.last, cTokens[len(cTokens)-llm.NumCtx:])
 
-		tokens = truncated
+		cTokens = truncated
 		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
 	} else {
-		llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
-		llm.last = append(llm.last, tokens...)
+		llm.last = make([]C.llama_token, llm.NumCtx-len(cTokens))
+		llm.last = append(llm.last, cTokens...)
 	}
 
 	var i int
-	for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
+	for i = 0; i < len(llm.embd) && i < len(cTokens) && llm.embd[i] == cTokens[i]; i++ {
 		// noop
 	}
 
-	llm.embd = tokens
-	if i == len(tokens) {
+	llm.embd = cTokens
+	if i == len(cTokens) {
 		// evaluate at least one token to generate logits
 		i--
 	}
@@ -295,31 +315,36 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
 	llm.cursor = i
 
 	log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
-	return tokens
+	return cTokens
 }
 
-func (llm *LLM) Encode(prompt string) []C.llama_token {
+func (llm *llama) Encode(prompt string) []int {
 	cPrompt := C.CString(prompt)
 	defer C.free(unsafe.Pointer(cPrompt))
 
-	tokens := make([]C.llama_token, len(prompt)+1)
-	if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(tokens), C.int(len(tokens)), true); n > 0 {
-		return tokens[:n]
+	cTokens := make([]C.llama_token, len(prompt)+1)
+	if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(cTokens), C.int(len(cTokens)), true); n > 0 {
+		tokens := make([]int, n)
+		for i := range cTokens[:n] {
+			tokens[i] = int(cTokens[i])
+		}
+
+		return tokens
 	}
 
 	return nil
 }
 
-func (llm *LLM) Decode(tokens ...C.llama_token) string {
+func (llm *llama) Decode(tokens ...int) string {
 	var sb strings.Builder
 	for _, token := range tokens {
-		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
+		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, C.llama_token(token))))
 	}
 
 	return sb.String()
 }
 
-func (llm *LLM) next() (C.llama_token, error) {
+func (llm *llama) next() (C.llama_token, error) {
 	llm.mu.Lock()
 	defer llm.mu.Unlock()
 
@@ -410,7 +435,7 @@ func (llm *LLM) next() (C.llama_token, error) {
 	return token, nil
 }
 
-func (llm *LLM) Embedding(input string) ([]float64, error) {
+func (llm *llama) Embedding(input string) ([]float64, error) {
 	if !llm.EmbeddingOnly {
 		return nil, errors.New("llama: embedding not enabled")
 	}
@@ -420,7 +445,12 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
 		return nil, errors.New("llama: tokenize embedding")
 	}
 
-	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
+	cTokens := make([]C.llama_token, len(tokens))
+	for i := range tokens {
+		cTokens[i] = C.llama_token(tokens[i])
+	}
+
+	retval := C.llama_eval(llm.ctx, unsafe.SliceData(cTokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
 	if retval != 0 {
 		return nil, errors.New("llama: eval")
 	}