Browse Source

runner: avoid buffer overwrite when generating multiple embeddings (#8714)

Shield the code processing the embedding result
from subsequent calls that may overwrite the same
buffer to process a second input when retrieving
model embeddings.
Diego Pereira 2 months ago
parent
commit
928911bc68
1 changed files with 10 additions and 6 deletions
  1. 10 6
      llama/llama.go

+ 10 - 6
llama/llama.go

@@ -199,21 +199,25 @@ func (c *Context) KvCacheDefrag() {
 
 // Get the embeddings for a sequence id
 func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
-	embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
-	if embeddings == nil {
+	e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
+	if e == nil {
 		return nil
 	}
 
-	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
+	embeddings := make([]float32, c.Model().NEmbd())
+	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
+	return embeddings
 }
 
 func (c *Context) GetEmbeddingsIth(i int) []float32 {
-	embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
-	if embeddings == nil {
+	e := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
+	if e == nil {
 		return nil
 	}
 
-	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
+	embeddings := make([]float32, c.Model().NEmbd())
+	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
+	return embeddings
 }
 
 type ModelParams struct {