瀏覽代碼

fixes for gguf (#3863)

Patrick Devine 1 年之前
父節點
當前提交
14476d48cc
共有 1 個文件被更改,包括 10 次插入6 次删除
  1. 10 6
      llm/gguf.go

+ 10 - 6
llm/gguf.go

@@ -190,8 +190,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 		llm.kv[k] = v
 	}
 
-	slog.Debug(fmt.Sprintf("general.architecture = %s", llm.kv["general.architecture"]))
-
 	// decode tensors
 	for i := 0; uint64(i) < llm.numTensor(); i++ {
 		name, err := readGGUFString(llm, rs)
@@ -465,11 +463,13 @@ var ggufKVOrder = map[string][]string{
 		"llama.embedding_length",
 		"llama.block_count",
 		"llama.feed_forward_length",
-		"llama.rope.dimension_count",
 		"llama.attention.head_count",
 		"llama.attention.head_count_kv",
 		"llama.attention.layer_norm_rms_epsilon",
 		"llama.rope.freq_base",
+		"llama.rope.dimension_count",
+		"llama.expert_count",
+		"llama.expert_used_count",
 		"gemma.context_length",
 		"gemma.embedding_length",
 		"gemma.block_count",
@@ -577,6 +577,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 					return err
 				}
 			}
+		default:
+			return fmt.Errorf("improper type for '%s'", k)
 		}
 		if err != nil {
 			return err
@@ -598,9 +600,11 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 			return err
 		}
 
-		dims := 1
-		if tensor.Shape[1] > 0 {
-			dims = 2
+		dims := 0
+		for cnt := 0; cnt < len(tensor.Shape); cnt++ {
+			if tensor.Shape[cnt] > 0 {
+				dims++
+			}
 		}
 
 		if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {