|
@@ -360,7 +360,7 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
|
|
}, offset, nil
|
|
|
}
|
|
|
|
|
|
-func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
|
|
|
+func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffload uint64) {
|
|
|
embedding := llm.KV().EmbeddingLength()
|
|
|
heads := llm.KV().HeadCount()
|
|
|
headsKV := llm.KV().HeadCountKV()
|
|
@@ -368,9 +368,12 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
|
|
|
|
|
embeddingHeads := llm.KV().EmbeddingHeadCount()
|
|
|
embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
|
|
|
+ embeddingHeadsV := llm.KV().EmbeddingHeadCountV()
|
|
|
|
|
|
layers := llm.Tensors().Layers()
|
|
|
|
|
|
+ kv = 2 * context * llm.KV().BlockCount() * (embeddingHeadsK + embeddingHeadsV) * headsKV
|
|
|
+
|
|
|
switch llm.KV().Architecture() {
|
|
|
case "llama":
|
|
|
fullOffload = max(
|
|
@@ -400,6 +403,42 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
|
|
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
|
|
|
)
|
|
|
}
|
|
|
+ case "mllama":
|
|
|
+ var visionTokens, tiles uint64 = 1601, 4
|
|
|
+
|
|
|
+ if crossAttentionLayers, ok := llm.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
|
|
|
+ kv = headsKV *
|
|
|
+ (embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
|
|
|
+ (2* // sizeof(float16)
|
|
|
+ (llm.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
|
|
|
+ context +
|
|
|
+ 4* // sizeof(float32)
|
|
|
+ uint64(crossAttentionLayers.size)* // num cross attention layers
|
|
|
+ visionTokens*
|
|
|
+ tiles)
|
|
|
+ }
|
|
|
+
|
|
|
+ fullOffload = max(
|
|
|
+ 4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
|
|
|
+ // vocab graph
|
|
|
+ 4*batch*(embedding+vocab),
|
|
|
+ )
|
|
|
+
|
|
|
+ var ropeFreqsCount uint64
|
|
|
+ if ropeFreqs, ok := llm.Tensors().Layers()["rope_freqs"]; ok {
|
|
|
+ if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
|
|
|
+ ropeFreqsCount = ropeFreqsWeights.parameters()
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ partialOffload = max(
|
|
|
+ 4*(batch*
|
|
|
+ (2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
|
|
|
+ ropeFreqsCount+
|
|
|
+ embeddingHeadsK*context*headsKV),
|
|
|
+ // vocab graph
|
|
|
+ 4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
|
|
+ )
|
|
|
case "gemma", "gemma2":
|
|
|
fullOffload = max(
|
|
|
4*batch*(embedding+vocab),
|