|
@@ -366,9 +366,18 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
|
|
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
|
|
|
)
|
|
|
}
|
|
|
- case "gemma":
|
|
|
- fullOffload = 4 * batch * (embedding + vocab)
|
|
|
- partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
|
|
|
+ case "gemma", "gemma2":
|
|
|
+ fullOffload = max(
|
|
|
+ 4*batch*(embedding+vocab),
|
|
|
+ 4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
|
|
+ )
|
|
|
+
|
|
|
+ partialOffload = max(
|
|
|
+ 4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
|
|
|
+ 4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
|
|
|
+ 4*embeddingHeadsK*context*8+
|
|
|
+ embedding*embeddingHeadsK*heads*9/16,
|
|
|
+ )
|
|
|
case "command-r":
|
|
|
fullOffload = max(
|
|
|
4*batch*(embedding+vocab),
|