Browse Source

count all vision tensors

Michael Yang 1 month ago
parent
commit
d2ec22371e
1 changed files with 9 additions and 12 deletions
  1. 9 12
      fs/ggml/ggml.go

+ 9 - 12
fs/ggml/ggml.go

@@ -579,12 +579,16 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 }
 
 func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
-	switch llm.KV().Architecture() {
-	case "mllama":
-		for _, layer := range llm.Tensors().GroupLayers()["v"] {
-			weights += layer.Size()
+	for name, layer := range llm.Tensors().GroupLayers() {
+		if strings.HasPrefix(name, "v.") {
+			for _, tensor := range layer {
+				weights += tensor.Size()
+			}
 		}
+	}
 
+	switch llm.KV().Architecture() {
+	case "mllama":
 		kv := func(n string) uint64 {
 			if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
 				return uint64(v)
@@ -611,15 +615,8 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
 			embeddingLength*numPatches*maxNumTiles +
 			9*embeddingLength*numPaddedPatches*maxNumTiles +
 			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
-	case "gemma3":
-		for name, layer := range llm.Tensors().GroupLayers() {
-			if strings.HasPrefix(name, "v.") {
-				for _, tensor := range layer {
-					weights += tensor.Size()
-				}
-			}
-		}
 	}
+
 	return weights, graphSize
 }