
revert GroupLayers

Michael Yang, 2 months ago
parent
commit a3e0df1a5d
2 changed files with 41 additions and 39 deletions
  1. fs/ggml/ggml.go (+7 −9)
  2. fs/ggml/ggml_test.go (+34 −30)

+ 7 - 9
fs/ggml/ggml.go

@@ -157,15 +157,13 @@ func (ts Tensors) GroupLayers() map[string]Layer {
 	layers := make(map[string]Layer)
 	for _, t := range ts.items {
 		parts := strings.Split(t.Name, ".")
-		if i := slices.Index(parts, "blk"); i > 0 {
-			parts = append([]string{
-				strings.Join(parts[:i], "."),
-				strings.Join(parts[i:i+2], "."),
-			}, parts[i+2:]...)
-		} else if i == 0 {
-			parts = append([]string{
-				strings.Join(parts[i:i+2], "."),
-			}, parts[i+2:]...)
+		if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
+			if len(parts) > index+2 {
+				// blk and mm should have a number after them, join it
+				parts = append(
+					[]string{strings.Join(parts[:index+2], ".")},
+					parts[index+2:]...)
+			}
 		}
 
 		if _, ok := layers[parts[0]]; !ok {

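The restored rule is easiest to see in isolation. Below is a minimal standalone sketch of the grouping logic from the hunk above; groupLayers here is an illustrative stand-in (the real method is Tensors.GroupLayers, which returns a map[string]Layer keyed to the actual tensors), and the sample names are taken from the tests below.

	package main

	import (
		"fmt"
		"slices"
		"strings"
	)

	// groupLayers mirrors the restored rule: when a name contains a "blk" or
	// "mm" segment, the layer name runs through the number that follows it,
	// and the remainder of the name becomes the key inside that layer.
	func groupLayers(names []string) map[string][]string {
		layers := make(map[string][]string)
		for _, name := range names {
			parts := strings.Split(name, ".")
			if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 && len(parts) > index+2 {
				// blk and mm should have a number after them; fold it into
				// the layer prefix, e.g. "v.blk.0" or "mm.0".
				parts = append(
					[]string{strings.Join(parts[:index+2], ".")},
					parts[index+2:]...)
			}
			layers[parts[0]] = append(layers[parts[0]], strings.Join(parts[1:], "."))
		}
		return layers
	}

	func main() {
		names := []string{"mm.0.bias", "v.blk.0.attn_q.weight", "v.patch_embd.weight"}
		for layer, keys := range groupLayers(names) {
			fmt.Printf("layer %q -> keys %q\n", layer, keys) // e.g. layer "mm.0" -> keys ["bias"]
		}
	}

Running this yields layers "mm.0", "v.blk.0", and "v" with keys "bias", "attn_q.weight", and "patch_embd.weight" respectively, matching the updated test expectations below.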
+ 34 - 30
fs/ggml/ggml_test.go

@@ -85,23 +85,25 @@ func TestTensorLayers(t *testing.T) {
 				}
 			}),
 			want: map[string]Layer{
-				"mm": {
-					"0.bias":   tensors["mm.0.bias"],
-					"0.weight": tensors["mm.0.weight"],
+				"mm.0": {
+					"bias":   tensors["mm.0.bias"],
+					"weight": tensors["mm.0.weight"],
+				},
+				"v.blk.0": {
+					"attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
+					"attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
+					"attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
+					"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
+					"attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
+					"ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
+					"ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
+					"ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
+					"ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
 				},
 				"v": {
-					"blk.0.attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
-					"blk.0.attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
-					"blk.0.attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
-					"blk.0.attn_output.weight": tensors["v.blk.0.attn_output.weight"],
-					"blk.0.attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
-					"blk.0.ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
-					"blk.0.ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
-					"blk.0.ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
-					"blk.0.ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
-					"patch_embd.weight":        tensors["v.patch_embd.weight"],
-					"position_embd.gate":       tensors["v.position_embd.gate"],
-					"position_embd.weight":     tensors["v.position_embd.weight"],
+					"patch_embd.weight":    tensors["v.patch_embd.weight"],
+					"position_embd.gate":   tensors["v.position_embd.gate"],
+					"position_embd.weight": tensors["v.position_embd.weight"],
 				},
 			},
 		},
@@ -122,23 +124,25 @@ func TestTensorLayers(t *testing.T) {
 				},
 				"token_embd":  {"weight": tensors["token_embd.weight"]},
 				"output_norm": {"weight": tensors["output_norm.weight"]},
-				"mm": {
-					"0.bias":   tensors["mm.0.bias"],
-					"0.weight": tensors["mm.0.weight"],
+				"mm.0": {
+					"bias":   tensors["mm.0.bias"],
+					"weight": tensors["mm.0.weight"],
+				},
+				"v.blk.0": {
+					"attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
+					"attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
+					"attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
+					"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
+					"attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
+					"ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
+					"ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
+					"ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
+					"ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
 				},
 				"v": {
-					"blk.0.attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
-					"blk.0.attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
-					"blk.0.attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
-					"blk.0.attn_output.weight": tensors["v.blk.0.attn_output.weight"],
-					"blk.0.attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
-					"blk.0.ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
-					"blk.0.ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
-					"blk.0.ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
-					"blk.0.ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
-					"patch_embd.weight":        tensors["v.patch_embd.weight"],
-					"position_embd.gate":       tensors["v.position_embd.gate"],
-					"position_embd.weight":     tensors["v.position_embd.weight"],
+					"patch_embd.weight":    tensors["v.patch_embd.weight"],
+					"position_embd.gate":   tensors["v.position_embd.gate"],
+					"position_embd.weight": tensors["v.position_embd.weight"],
 				},
 			},
 		},
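In both test cases the revert moves the numbered segment out of the key and into the layer name: "mm.0.bias" previously landed under layer "mm" with key "0.bias" and now lands under layer "mm.0" with key "bias", and the "v.blk.0.*" tensors likewise move out of the "v" layer into their own "v.blk.0" layer, leaving "v" with only the un-numbered embedding tensors.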