
Merge pull request #5192 from ollama/mxyng/kv

handle asymmetric embedding KVs
Michael Yang committed 10 months ago
commit e01e535cbb
2 changed files with 35 additions and 9 deletions
  1. llm/ggml.go (+33 −7)
  2. llm/memory.go (+2 −2)

llm/ggml.go (+33 −7)

@@ -69,6 +69,30 @@ func (kv KV) HeadCountKV() uint64 {
 	return 1
 }

+func (kv KV) EmbeddingHeadCount() uint64 {
+	if heads := kv.HeadCount(); heads > 0 {
+		return kv.EmbeddingLength() / heads
+	}
+
+	return 0
+}
+
+func (kv KV) EmbeddingHeadCountK() uint64 {
+	if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 {
+		return k
+	}
+
+	return kv.EmbeddingHeadCount()
+}
+
+func (kv KV) EmbeddingHeadCountV() uint64 {
+	if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 {
+		return v
+	}
+
+	return kv.EmbeddingHeadCount()
+}
+
 func (kv KV) GQA() uint64 {
 	return kv.HeadCount() / kv.HeadCountKV()
 }
@@ -299,6 +323,9 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 	headsKV := llm.KV().HeadCountKV()
 	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))

+	embeddingHeads := llm.KV().EmbeddingHeadCount()
+	embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
+
 	layers := llm.Tensors().Layers()

 	switch llm.KV().Architecture() {
@@ -308,7 +335,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 		partialOffload = 4 * batch * embedding
 		partialOffload += max(
 			// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
-			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
+			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)

@@ -316,15 +343,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			// mixtral 8x22b
 			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
 			partialOffload = max(
-				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
-				4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
+				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
+				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
 			)
 		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
 			// mixtral 8x7b
 			ffnGateWeight1 := ffnGateWeight.Shape[1]
 			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
 			partialOffload = max(
-				4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
+				4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
 				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
 			)
 		}
@@ -368,15 +395,14 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			fullOffload,
 		)
 	case "deepseek2":
-		keys := uint64(llm.KV()["deepseek2.attention.key_length"].(uint32))
 		fullOffload = max(
 			4*batch*(3*embedding+vocab),
-			4*batch*(3*embedding+2+context*(1+headsKV)+2*keys*headsKV),
+			4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
 		)

 		partialOffload = max(
 			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
-			4*batch*(2*embedding+1+2*keys*headsKV+context+context*headsKV)+4*keys*context*headsKV+embedding*keys*headsKV*9/16,
+			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
 		)
 	}

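The three accessors added above prefer the explicit `<arch>.attention.key_length` and `<arch>.attention.value_length` GGUF keys and fall back to the symmetric split `n_embd / n_head`; the graph-size estimates then reuse those per-head sizes instead of recomputing `embedding/heads` inline, which also lets the deepseek2 branch drop its hand-rolled `keys` lookup. A minimal standalone sketch of the fallback, using hypothetical metadata values (the `headDim` helper is illustrative, not part of the package):

```go
package main

import "fmt"

// headDim mirrors the fallback in EmbeddingHeadCountK/V: use the explicit
// per-head size from the GGUF metadata when present, otherwise assume the
// symmetric split n_embd / n_head, and 0 when the head count is unset.
func headDim(explicit, embedding, heads uint64) uint64 {
	if explicit > 0 {
		return explicit
	}
	if heads > 0 {
		return embedding / heads
	}
	return 0
}

func main() {
	// Symmetric model: no key_length/value_length keys, so K and V agree.
	fmt.Println(headDim(0, 4096, 32), headDim(0, 4096, 32)) // 128 128

	// Asymmetric model (deepseek2-style): the explicit sizes win, and the
	// old n_embd/n_head estimate would have been wrong for both K and V.
	fmt.Println(headDim(192, 5120, 128), headDim(128, 5120, 128)) // 192 128
}
```
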
llm/memory.go (+2 −2)

@@ -115,8 +115,8 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 		slog.Warn("model missing blk.0 layer size")
 	}

-	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
-	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
+	// fp16 k,v = sizeof(float16) * n_ctx * n_layer * (n_embd_head_k + n_embd_head_v) * n_head_kv
+	var kv uint64 = 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * (ggml.KV().EmbeddingHeadCountK() + ggml.KV().EmbeddingHeadCountV()) * ggml.KV().HeadCountKV()

 	// KV is proportional to the number of layers
 	layerSize += kv / ggml.KV().BlockCount()
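
The rewritten fp16 cache estimate sizes the K and V halves independently. A quick worked example with hypothetical dimensions (not taken from any real model):

```go
package main

import "fmt"

func main() {
	// Hypothetical asymmetric model: n_ctx=2048, n_layer=32, n_head_kv=16,
	// n_embd_head_k=192, n_embd_head_v=128.
	const (
		numCtx     = 2048
		blockCount = 32
		headsKV    = 16
		headK      = 192
		headV      = 128
	)
	// sizeof(float16) * n_ctx * n_layer * (n_embd_head_k + n_embd_head_v) * n_head_kv
	kv := uint64(2) * numCtx * blockCount * (headK + headV) * headsKV
	fmt.Printf("fp16 KV cache: %d bytes (%d MiB)\n", kv, kv/(1<<20)) // 671088640 bytes (640 MiB)
}
```

Under the old formula, `2 * 2 * n_ctx * n_layer * (n_embd / n_head) * n_head_kv`, K and V were assumed to share one head size, which misestimates any model where `key_length` and `value_length` differ.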