@@ -69,6 +69,30 @@ func (kv KV) HeadCountKV() uint64 {
 	return 1
 }
 
+func (kv KV) EmbeddingHeadCount() uint64 {
+	if heads := kv.HeadCount(); heads > 0 {
+		return kv.EmbeddingLength() / heads
+	}
+
+	return 0
+}
+
+func (kv KV) EmbeddingHeadCountK() uint64 {
+	if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 {
+		return k
+	}
+
+	return kv.EmbeddingHeadCount()
+}
+
+func (kv KV) EmbeddingHeadCountV() uint64 {
+	if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 {
+		return v
+	}
+
+	return kv.EmbeddingHeadCount()
+}
+
 func (kv KV) GQA() uint64 {
 	return kv.HeadCount() / kv.HeadCountKV()
 }
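The three new accessors share a fallback chain: an explicit `attention.key_length` or `attention.value_length` entry in the model metadata wins when present; otherwise the per-head dimension is derived as `embedding_length / head_count`. A minimal standalone sketch of that chain (the function name and all values below are hypothetical, not taken from the patch or any model):

```go
package main

import "fmt"

// headDim mirrors the fallback in EmbeddingHeadCountK/V: prefer an
// explicit per-head length from the metadata, otherwise derive it as
// embedding_length / head_count.
func headDim(explicit, embeddingLength, headCount uint64) uint64 {
	if explicit > 0 {
		return explicit
	}
	if headCount > 0 {
		return embeddingLength / headCount
	}
	return 0
}

func main() {
	// No explicit key_length in the metadata: fall back to 4096/32 = 128.
	fmt.Println(headDim(0, 4096, 32)) // 128

	// Explicit key_length that differs from embedding/heads: it wins.
	fmt.Println(headDim(192, 4096, 32)) // 192
}
```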
@@ -299,6 +323,9 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 	headsKV := llm.KV().HeadCountKV()
 	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
 
+	embeddingHeads := llm.KV().EmbeddingHeadCount()
+	embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
+
 	layers := llm.Tensors().Layers()
 
 	switch llm.KV().Architecture() {
@@ -308,7 +335,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 		partialOffload = 4 * batch * embedding
 		partialOffload += max(
 			// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
-			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
+			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
 
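For the llama case the rewrite is value-preserving on conventional models: when no explicit head size is declared, `EmbeddingHeadCount()` returns exactly `embedding/heads`, so `embeddingHeads*headsKV` equals the old `embedding/heads*headsKV`. A quick check with illustrative numbers (not from the patch):

```go
package main

import "fmt"

func main() {
	// Illustrative grouped-query model: 4096-dim embedding,
	// 32 query heads, 8 KV heads.
	embedding, heads, headsKV := uint64(4096), uint64(32), uint64(8)
	embeddingHeads := embedding / heads // what EmbeddingHeadCount() yields here

	// Old term vs. new term: both evaluate to 1024.
	fmt.Println(embedding/heads*headsKV, embeddingHeads*headsKV) // 1024 1024
}
```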
@@ -316,15 +343,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			// mixtral 8x22b
 			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
 			partialOffload = max(
-				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
-				4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
+				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
+				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
 			)
 		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
 			// mixtral 8x7b
 			ffnGateWeight1 := ffnGateWeight.Shape[1]
 			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
 			partialOffload = max(
-				4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
+				4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
 				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
 			)
 		}
@@ -368,15 +395,14 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			fullOffload,
 		)
 	case "deepseek2":
-		keys := uint64(llm.KV()["deepseek2.attention.key_length"].(uint32))
 		fullOffload = max(
 			4*batch*(3*embedding+vocab),
-			4*batch*(3*embedding+2+context*(1+headsKV)+2*keys*headsKV),
+			4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
 		)
 
 		partialOffload = max(
 			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
-			4*batch*(2*embedding+1+2*keys*headsKV+context+context*headsKV)+4*keys*context*headsKV+embedding*keys*headsKV*9/16,
+			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
 		)
 	}
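Besides reusing the shared accessor, dropping the local `keys` lookup removes a direct `.(uint32)` type assertion, which would panic if the field were missing or stored at a different integer width; the accessor route presumably returns 0 from `kv.u64` and falls back to `EmbeddingHeadCount()` instead. A rough standalone sketch of the difference, with a hypothetical metadata map and illustrative values:

```go
package main

import "fmt"

// Hypothetical metadata map standing in for the GGUF KV store.
var kv = map[string]any{
	"deepseek2.attention.key_length": uint32(192), // illustrative value
}

// u64 sketches a tolerant accessor: it accepts the integer widths a
// GGUF file might use and returns 0 when the key is missing.
func u64(key string) uint64 {
	switch v := kv[key].(type) {
	case uint32:
		return uint64(v)
	case uint64:
		return v
	default:
		return 0
	}
}

func main() {
	// keys := uint64(kv["..."].(uint32)) would panic on a missing key;
	// the accessor route degrades gracefully instead.
	fmt.Println(u64("deepseek2.attention.key_length"))   // 192
	fmt.Println(u64("deepseek2.attention.value_length")) // 0 (missing)
}
```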