瀏覽代碼

Merge pull request #3836 from ollama/mxyng/mixtral

fix: mixtral graph
Michael Yang 1 年之前
父節點
當前提交
e83bcf7f9a
共有 1 個文件被更改,包括 9 次插入1 次删除
  1. 9 1
      llm/ggml.go

+ 9 - 1
llm/ggml.go

@@ -343,7 +343,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
 
-		if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
+		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
+			// mixtral 8x22b
+			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
+			partialOffload = max(
+				3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
+				4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
+			)
+		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
+			// mixtral 8x7b
 			ffnGateWeight1 := ffnGateWeight.Shape[1]
 			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
 			partialOffload = max(