@@ -303,3 +303,50 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
 		model: model,
 	}, offset, nil
 }
+
+func (llm GGML) GraphSize(context, batch int) (int64, bool) {
+	embeddingLength := llm.KV().EmbeddingLength()
+	headCount := llm.KV().HeadCount()
+	headCountKV := llm.KV().HeadCountKV()
+	vocabLength := len(llm.KV()["tokenizer.ggml.tokens"].([]any))
+
+	var attnQKVWeight1 uint64 = 0
+	for _, t := range llm.Tensors() {
+		if strings.HasSuffix(t.Name, ".attn_qkv.weight") && len(t.Shape) >= 2 {
+			attnQKVWeight1 = t.Shape[1]
+			break
+		}
+	}
+
+	var ffnGate1 uint64 = 0
+	for _, t := range llm.Tensors() {
+		if strings.Index(t.Name, ".ffn_gate") > 0 && len(t.Shape) >= 2 {
+			ffnGate1 = t.Shape[1]
+			break
+		}
+	}
+
+	switch llm.KV().Architecture() {
+	case "gemma":
+		return 4 * int64(batch) * int64(embeddingLength+uint64(vocabLength)), true
+	case "phi2":
+		return max(
+			4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
+			4*int64(batch)*int64(1+4*embeddingLength+uint64(context)+attnQKVWeight1+uint64(context)*headCount),
+		), true
+	case "qwen2":
+		return max(
+			4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
+			4*int64(batch)*int64(1+2*embeddingLength+uint64(context)+uint64(context)*headCount),
+		), true
+	case "llama":
+		if ffnGate1 > 0 {
+			// mixture of experts (MoE)
+			return 4 * int64(batch) * int64(2+3*embeddingLength+uint64(context)+uint64(context)*headCount+2*headCountKV+ffnGate1), true
+		}
+
+		return 4 * int64(batch) * int64(1+4*embeddingLength+uint64(context)+uint64(context)*headCount), true
+	}
+
+	return 0, false
+}
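
For context on how this estimate might be consumed, here is a minimal sketch (not part of the patch) of a helper in the same package. The helper name `graphFits`, the `freeVRAM` parameter, and the 1 GiB fallback are illustrative assumptions, not the repository's actual scheduling logic; the only API it relies on from the patch is `GraphSize`.

```go
// graphFits reports whether the estimated compute-graph allocation for the
// given context and batch sizes fits within freeVRAM bytes. When GraphSize
// has no per-architecture formula (ok == false), it falls back to a flat,
// conservative budget. Illustrative sketch only.
func graphFits(llm GGML, context, batch int, freeVRAM int64) bool {
	const fallbackGraphBytes int64 = 1 << 30 // assumed 1 GiB fallback, for illustration

	graph, ok := llm.GraphSize(context, batch)
	if !ok {
		graph = fallbackGraphBytes
	}

	return graph <= freeVRAM
}
```

Plugging representative numbers into the patch's own formula: for a Llama-7B-like shape (embeddingLength 4096, headCount 32, no ffn_gate tensor) at context 2048 and batch 512, the llama branch gives 4·512·(1 + 4·4096 + 2048 + 2048·32) ≈ 164 MiB, where the leading factor of 4 presumably accounts for f32-sized scratch elements.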