@@ -94,7 +94,6 @@ import (
 	"io"
 	"log"
 	"os"
-	"reflect"
 	"strings"
 	"sync"
 	"unicode/utf8"
@@ -421,27 +420,20 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
 		return nil, errors.New("llama: tokenize embedding")
 	}
 
-	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
+	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
 	if retval != 0 {
 		return nil, errors.New("llama: eval")
 	}
 
-	n := int(C.llama_n_embd(llm.ctx))
+	n := C.llama_n_embd(llm.ctx)
 	if n <= 0 {
 		return nil, errors.New("llama: no embeddings generated")
 	}
+	cEmbeddings := unsafe.Slice(C.llama_get_embeddings(llm.ctx), n)
 
-	embedPtr := C.llama_get_embeddings(llm.ctx)
-	if embedPtr == nil {
-		return nil, errors.New("llama: embedding retrieval failed")
+	embeddings := make([]float64, len(cEmbeddings))
+	for i, v := range cEmbeddings {
+		embeddings[i] = float64(v)
 	}
-
-	header := reflect.SliceHeader{
-		Data: uintptr(unsafe.Pointer(embedPtr)),
-		Len:  n,
-		Cap:  n,
-	}
-	embedSlice := *(*[]float64)(unsafe.Pointer(&header))
-
-	return embedSlice, nil
+	return embeddings, nil
 }
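
The substance of the second hunk is the move away from a reflect.SliceHeader cast, which reinterpreted the C float (float32) embedding buffer as a []float64, toward unsafe.Slice plus an explicit element-wise conversion into Go-owned memory. The sketch below illustrates that pattern only; it is not the file above. A plain Go float32 array stands in for the buffer that C.llama_get_embeddings returns, so everything outside the unsafe.Slice call and the conversion loop is an assumption for demonstration.

// Minimal sketch of the unsafe.Slice copy-and-widen pattern, assuming a
// float32 buffer obtained from C code (simulated here with a Go array).
package main

import (
	"fmt"
	"unsafe"
)

func main() {
	// Stand-in for the C float* buffer the library would return.
	backing := [4]float32{0.1, 0.2, 0.3, 0.4}
	ptr := &backing[0]
	n := len(backing)

	// unsafe.Slice views the pointer as []float32 without changing the
	// element type, unlike a SliceHeader cast to []float64, which would
	// read 8 bytes per element from 4-byte data.
	view := unsafe.Slice(ptr, n)

	// Copy into Go-owned memory, widening each element to float64.
	out := make([]float64, len(view))
	for i, v := range view {
		out[i] = float64(v)
	}
	fmt.Println(out)
}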