浏览代码

fix embeddings invalid values

Bruce MacDonald 1 年之前
父节点
当前提交
984c9c628c
共有 2 个文件被更改,包括 9 次插入39 次删除
  1. 7 15
      llama/llama.go
  2. 2 24
      server/images.go

+ 7 - 15
llama/llama.go

@@ -94,7 +94,6 @@ import (
 	"io"
 	"log"
 	"os"
-	"reflect"
 	"strings"
 	"sync"
 	"unicode/utf8"
@@ -421,27 +420,20 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
 		return nil, errors.New("llama: tokenize embedding")
 	}
 
-	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
+	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
 	if retval != 0 {
 		return nil, errors.New("llama: eval")
 	}
 
-	n := int(C.llama_n_embd(llm.ctx))
+	n := C.llama_n_embd(llm.ctx)
 	if n <= 0 {
 		return nil, errors.New("llama: no embeddings generated")
 	}
+	cEmbeddings := unsafe.Slice(C.llama_get_embeddings(llm.ctx), n)
 
-	embedPtr := C.llama_get_embeddings(llm.ctx)
-	if embedPtr == nil {
-		return nil, errors.New("llama: embedding retrieval failed")
+	embeddings := make([]float64, len(cEmbeddings))
+	for i, v := range cEmbeddings {
+		embeddings[i] = float64(v)
 	}
-
-	header := reflect.SliceHeader{
-		Data: uintptr(unsafe.Pointer(embedPtr)),
-		Len:  n,
-		Cap:  n,
-	}
-	embedSlice := *(*[]float64)(unsafe.Pointer(&header))
-
-	return embedSlice, nil
+	return embeddings, nil
 }

+ 2 - 24
server/images.go

@@ -11,7 +11,6 @@ import (
 	"html/template"
 	"io"
 	"log"
-	"math"
 	"net/http"
 	"os"
 	"path/filepath"
@@ -480,31 +479,10 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 						Total:     len(data) - 1,
 						Completed: i,
 					})
-					retry := 0
-				generate:
-					if retry > 3 {
-						log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
-						continue
-					}
 					embed, err := llm.Embedding(d)
 					if err != nil {
-						log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
-						retry++
-						goto generate
-					}
-					// Check for NaN and Inf in the embedding, which can't be stored
-					for _, value := range embed {
-						if math.IsNaN(value) || math.IsInf(value, 0) {
-							log.Printf("reloading model, embedding contains NaN or Inf")
-							// reload the model to get a new embedding, the seed can effect these outputs and reloading changes it
-							llm.Close()
-							llm, err = llama.New(e.model, e.opts)
-							if err != nil {
-								return nil, fmt.Errorf("load model to generate embeddings: %v", err)
-							}
-							retry++
-							goto generate
-						}
+						log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
+						continue
 					}
 					embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
 				}