浏览代码

runner.go: Better handle return NULL values from llama.cpp

Llama.cpp sometimes returns NULL as a return value to report an
error. We should explicitly check for this and convert it to a Go
error rather than putting NULL in our data structures and waiting
for it to blow up later.
Jesse Gross 6 月之前
父节点
当前提交
de1557a0dc
共有 3 个文件被更改,包括 37 次插入13 次删除
  1. 20 9
      llama/llama.go
  2. 9 2
      llama/runner/runner.go
  3. 8 2
      llm/server.go

+ 20 - 9
llama/llama.go

@@ -136,10 +136,6 @@ func (c *Context) Model() *Model {
 	return &Model{c: C.llama_get_model(c.c)}
 	return &Model{c: C.llama_get_model(c.c)}
 }
 }
 
 
-func (c *Context) GetLogitsIth(i int) []float32 {
-	return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int(i)))), c.Model().NumVocab())
-}
-
 func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
 func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
 	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
 	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
 }
 }
@@ -163,7 +159,12 @@ func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
 }
 }
 
 
 func (c *Context) GetEmbeddingsIth(i int) []float32 {
 func (c *Context) GetEmbeddingsIth(i int) []float32 {
-	return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))), c.Model().NEmbd())
+	embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
+	if embeddings == nil {
+		return nil
+	}
+
+	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
 }
 }
 
 
 type ModelParams struct {
 type ModelParams struct {
@@ -184,7 +185,7 @@ func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
 	return true
 	return true
 }
 }
 
 
-func LoadModelFromFile(modelPath string, params ModelParams) *Model {
+func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
 	cparams := C.llama_model_default_params()
 	cparams := C.llama_model_default_params()
 	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
 	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
 	cparams.main_gpu = C.int32_t(params.MainGpu)
 	cparams.main_gpu = C.int32_t(params.MainGpu)
@@ -214,18 +215,28 @@ func LoadModelFromFile(modelPath string, params ModelParams) *Model {
 		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
 		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
 	}
 	}
 
 
-	return &Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
+	m := Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
+	if m.c == (*C.struct_llama_model)(C.NULL) {
+		return nil, fmt.Errorf("unable to load model: %s", modelPath)
+	}
+
+	return &m, nil
 }
 }
 
 
 func FreeModel(model *Model) {
 func FreeModel(model *Model) {
 	C.llama_free_model(model.c)
 	C.llama_free_model(model.c)
 }
 }
 
 
-func NewContextWithModel(model *Model, params ContextParams) *Context {
-	return &Context{
+func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
+	c := Context{
 		c:          C.llama_new_context_with_model(model.c, params.c),
 		c:          C.llama_new_context_with_model(model.c, params.c),
 		numThreads: int(params.c.n_threads),
 		numThreads: int(params.c.n_threads),
 	}
 	}
+	if c.c == (*C.struct_llama_context)(C.NULL) {
+		return nil, errors.New("unable to create llama context")
+	}
+
+	return &c, nil
 }
 }
 
 
 func (m *Model) NumVocab() int {
 func (m *Model) NumVocab() int {

+ 9 - 2
llama/runner/runner.go

@@ -790,10 +790,17 @@ func (s *Server) loadModel(
 ) {
 ) {
 	llama.BackendInit()
 	llama.BackendInit()
 
 
-	s.model = llama.LoadModelFromFile(mpath, params)
+	var err error
+	s.model, err = llama.LoadModelFromFile(mpath, params)
+	if err != nil {
+		panic(err)
+	}
 
 
 	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
 	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
-	s.lc = llama.NewContextWithModel(s.model, ctxParams)
+	s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
+	if err != nil {
+		panic(err)
+	}
 
 
 	if lpath != "" {
 	if lpath != "" {
 		err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads)
 		err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads)

+ 8 - 2
llm/server.go

@@ -958,7 +958,10 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
 	if resp.StatusCode == http.StatusNotFound {
 	if resp.StatusCode == http.StatusNotFound {
 		if s.model == nil {
 		if s.model == nil {
 			slog.Debug("new runner detected, loading model for cgo tokenization")
 			slog.Debug("new runner detected, loading model for cgo tokenization")
-			m := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
+			m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
+			if err != nil {
+				return nil, err
+			}
 			s.model = m
 			s.model = m
 		}
 		}
 		return s.model.Tokenize(content, false, true)
 		return s.model.Tokenize(content, false, true)
@@ -1027,7 +1030,10 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
 	if resp.StatusCode == http.StatusNotFound {
 	if resp.StatusCode == http.StatusNotFound {
 		if s.model == nil {
 		if s.model == nil {
 			slog.Debug("new runner detected, loading model for cgo tokenization")
 			slog.Debug("new runner detected, loading model for cgo tokenization")
-			m := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
+			m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
+			if err != nil {
+				return "", err
+			}
 			s.model = m
 			s.model = m
 		}
 		}
 		var resp string
 		var resp string