Browse Source

gemma2 graph

Michael Yang 10 months ago
parent
commit
de2163dafd
1 changed files with 12 additions and 3 deletions
  1. 12 3
      llm/ggml.go

+ 12 - 3
llm/ggml.go

@@ -366,9 +366,18 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
 			)
 		}
-	case "gemma":
-		fullOffload = 4 * batch * (embedding + vocab)
-		partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
+	case "gemma", "gemma2":
+		fullOffload = max(
+			4*batch*(embedding+vocab),
+			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
+		)
+
+		partialOffload = max(
+			4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
+			4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
+				4*embeddingHeadsK*context*8+
+				embedding*embeddingHeadsK*heads*9/16,
+		)
 	case "command-r":
 		fullOffload = max(
 			4*batch*(embedding+vocab),