@@ -400,6 +400,30 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
 			)
 		}
+	case "mllama":
+		var visionTokens, tiles uint64 = 1601, 4
+
+		fullOffload = max(
+			4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
+			// vocab graph
+			4*batch*(embedding+vocab),
+		)
+
+		var ropeFreqsCount uint64
+		if ropeFreqs, ok := llm.Tensors().Layers()["rope_freqs"]; ok {
+			if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
+				ropeFreqsCount = ropeFreqsWeights.parameters()
+			}
+		}
+
+		partialOffload = max(
+			4*(batch*
+				(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
+				ropeFreqsCount+
+				embeddingHeadsK*context*headsKV),
+			// vocab graph
+			4*batch*(embedding+vocab)+embedding*vocab*105/128,
+		)
 	case "gemma", "gemma2":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
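For a rough sense of what these expressions reserve, here is a minimal standalone sketch of the two mllama offload formulas. `mllamaGraphSize` is a hypothetical helper, not part of the PR; the `rope_freqs` tensor lookup is replaced by a plain parameter, all hyperparameter values are illustrative assumptions (loosely Llama-3.2-11B-Vision-shaped) rather than numbers read from a real GGUF, and it relies on Go 1.21+ for the built-in `max`. The hunk's `visionTokens` (1601 vision tokens per tile) and `tiles` (up to 4) presumably feed KV-cache sizing elsewhere in the change, so the sketch omits them.

```go
package main

import "fmt"

// mllamaGraphSize mirrors the fullOffload/partialOffload arithmetic from the
// hunk above. Hypothetical helper for illustration only; ropeFreqsCount stands
// in for llm.Tensors().Layers()["rope_freqs"]["weights"].parameters().
func mllamaGraphSize(context, batch, embedding, heads, headsKV, embeddingHeadsK, vocab, ropeFreqsCount uint64) (partialOffload, fullOffload uint64) {
	fullOffload = max(
		4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
		// vocab graph
		4*batch*(embedding+vocab),
	)

	partialOffload = max(
		4*(batch*
			(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
			ropeFreqsCount+
			embeddingHeadsK*context*headsKV),
		// vocab graph
		4*batch*(embedding+vocab)+embedding*vocab*105/128,
	)
	return
}

func main() {
	// Assumed text-stack shape: 4096-dim embeddings, 32 heads, 8 KV heads,
	// 128-dim K heads, 128256-entry vocab, 64 rope frequencies; 2048-token
	// context and a batch of 512.
	partial, full := mllamaGraphSize(2048, 512, 4096, 32, 8, 128, 128256, 64)
	fmt.Printf("partialOffload ≈ %d MiB, fullOffload ≈ %d MiB\n", partial>>20, full>>20)
}
```

Both estimates take the `max` of the attention graph and the vocab-projection graph, so whichever dominates for a given model shape sets the reservation.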