Browse Source

roughly count gemma3 graph

the largest operation is by far (q @ k) so just count that for
simplicity
Michael Yang 1 month ago
parent
commit
a422ba39c9
1 changed files with 16 additions and 18 deletions
  1. 16 18
      fs/ggml/ggml.go

+ 16 - 18
fs/ggml/ggml.go

@@ -587,34 +587,32 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
 		}
 	}
 
-	switch llm.KV().Architecture() {
-	case "mllama":
-		kv := func(n string) uint64 {
-			if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
-				return uint64(v)
-			}
-
-			return 0
-		}
+	imageSize := uint64(llm.KV().Uint("vision.image_size"))
+	patchSize := uint64(llm.KV().Uint("vision.patch_size"))
 
-		imageSize := kv("image_size")
+	numPatches := (imageSize / patchSize) * (imageSize / patchSize)
+	if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
+		numPatches++
+	}
 
-		maxNumTiles := kv("max_num_tiles")
-		embeddingLength := kv("embedding_length")
-		headCount := kv("attention.head_count")
+	headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
 
-		numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
-		if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
-			numPatches++
-		}
+	switch llm.KV().Architecture() {
+	case "mllama":
 
 		numPaddedPatches := numPatches + 8 - (numPatches%8)%8
 
+		maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
+		numChannels := uint64(llm.KV().Uint("vision.num_channels"))
+		embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
+
 		graphSize = 4 * (8 +
-			imageSize*imageSize*kv("num_channels")*maxNumTiles +
+			imageSize*imageSize*numChannels*maxNumTiles +
 			embeddingLength*numPatches*maxNumTiles +
 			9*embeddingLength*numPaddedPatches*maxNumTiles +
 			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
+	case "gemma3":
+		graphSize = 4 * (numPatches * numPatches * headCount)
 	}
 
 	return weights, graphSize