Browse Source

mixtral mem

Michael Yang 1 year ago
parent
commit
3397eff0cd
1 changed files with 12 additions and 1 deletions
  1. 12 1
      llm/ggml.go

+ 12 - 1
llm/ggml.go

@@ -330,6 +330,8 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 	headsKV := llm.KV().HeadCountKV()
 	headsKV := llm.KV().HeadCountKV()
 	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
 	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
 
 
+	layers := llm.Tensors().Layers()
+
 	switch llm.KV().Architecture() {
 	switch llm.KV().Architecture() {
 	case "llama":
 	case "llama":
 		fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
 		fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
@@ -339,6 +341,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
 			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
 		)
+
+		if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
+			ffnGateWeight1 := ffnGateWeight.Shape[1]
+			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
+			partialOffload = max(
+				4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
+				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
+			)
+		}
 	case "gemma":
 	case "gemma":
 		fullOffload = 4 * batch * (embedding + vocab)
 		fullOffload = 4 * batch * (embedding + vocab)
 		partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
 		partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
@@ -350,7 +361,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 
 
 		partialOffload = max(
 		partialOffload = max(
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
-			4*batch*(1+2*embedding+context*(1+heads))+ 4*embedding*context+embedding*embedding*9/16,
+			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
 		)
 		)
 	case "qwen2":
 	case "qwen2":
 		fullOffload = max(
 		fullOffload = max(