
Merge pull request #3463 from ollama/mxyng/graph-estimate

update graph size estimate
Michael Yang, 1 year ago
Parent commit: a0a15cfd5b
2 changed files with 52 additions and 4 deletions
  1. llm/ggml.go (+47, -0)
  2. llm/server.go (+5, -4)

llm/ggml.go (+47, -0)

@@ -303,3 +303,50 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
 		model:     model,
 	}, offset, nil
 }
+
+func (llm GGML) GraphSize(context, batch int) (int64, bool) {
+	embeddingLength := llm.KV().EmbeddingLength()
+	headCount := llm.KV().HeadCount()
+	headCountKV := llm.KV().HeadCountKV()
+	vocabLength := len(llm.KV()["tokenizer.ggml.tokens"].([]any))
+
+	var attnQKVWeight1 uint64 = 0
+	for _, t := range llm.Tensors() {
+		if strings.HasSuffix(t.Name, ".attn_qkv.weight") && len(t.Shape) >= 2 {
+			attnQKVWeight1 = t.Shape[1]
+			break
+		}
+	}
+
+	var ffnGate1 uint64 = 0
+	for _, t := range llm.Tensors() {
+		if strings.Index(t.Name, ".ffn_gate") > 0 && len(t.Shape) >= 2 {
+			ffnGate1 = t.Shape[1]
+			break
+		}
+	}
+
+	switch llm.KV().Architecture() {
+	case "gemma":
+		return 4 * int64(batch) * int64(embeddingLength+uint64(vocabLength)), true
+	case "phi2":
+		return max(
+			4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
+			4*int64(batch)*int64(1+4*embeddingLength+uint64(context)+attnQKVWeight1+uint64(context)*headCount),
+		), true
+	case "qwen2":
+		return max(
+			4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
+			4*int64(batch)*int64(1+2*embeddingLength+uint64(context)+uint64(context)*headCount),
+		), true
+	case "llama":
+		if ffnGate1 > 0 {
+			// moe
+			return 4 * int64(batch) * int64(2+3*embeddingLength+uint64(context)+uint64(context)*headCount+2*headCountKV+ffnGate1), true
+		}
+
+		return 4 * int64(batch) * int64(1+4*embeddingLength+uint64(context)+uint64(context)*headCount), true
+	}
+
+	return 0, false
+}
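
For reference, a minimal standalone sketch of the non-MoE "llama" branch of GraphSize, evaluated with hypothetical 7B-class parameters (n_embd=4096, n_head=32, a 2048-token context, and a 512-token batch); the model numbers are illustrative assumptions, not values taken from this PR:

package main

import "fmt"

// llamaGraphSize mirrors the non-MoE "llama" branch of GraphSize above;
// the leading 4 is presumably sizeof(float32) per graph element.
func llamaGraphSize(context, batch int, embeddingLength, headCount uint64) int64 {
	return 4 * int64(batch) * int64(1+4*embeddingLength+uint64(context)+uint64(context)*headCount)
}

func main() {
	// Hypothetical 7B-class model: n_embd=4096, n_head=32.
	graph := llamaGraphSize(2048, 512, 4096, 32)
	// 4 * 512 * (1 + 16384 + 2048 + 65536) bytes ≈ 164 MiB of graph scratch.
	fmt.Printf("estimated graph size: %d bytes (%.1f MiB)\n", graph, float64(graph)/(1<<20))
}

Note that server.go below passes min(opts.NumCtx, opts.NumBatch) as the batch argument, so the batch term is capped by the context length.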

llm/server.go (+5, -4)

@@ -79,10 +79,11 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
 	kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.KV().BlockCount()) * int64(ggml.KV().EmbeddingLength()) / int64(ggml.KV().HeadCount()) * int64(ggml.KV().HeadCountKV())
 
-	// this amount is the overhead + tensors in memory
-	// TODO: get this from the llama.cpp's graph calculations instead of
-	// estimating it's 1/6 * kv_cache_size * num_gqa
-	graph := int64(ggml.KV().GQA()) * kv / 6
+	graph, ok := ggml.GraphSize(opts.NumCtx, min(opts.NumCtx, opts.NumBatch))
+	if !ok {
+		graph = int64(ggml.KV().GQA()) * kv / 6
+	}
+
 	usedMemory += graph
 
 	if (usedMemory > availableMemory || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture())) && info.Library != "metal" {
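
The kv term above comes straight from the fp16 formula in the comment. A small sketch with the same hypothetical 7B-class parameters (32 layers, n_embd=4096, n_head=32, n_head_kv=32, i.e. no GQA) shows its scale; again, the model numbers are assumptions, not values from this PR:

package main

import "fmt"

// kvCacheSize mirrors the estimate above:
// (1 key + 1 value) * sizeof(float16) * n_ctx * n_layer * (n_embd / n_head) * n_head_kv.
func kvCacheSize(numCtx, blockCount, embeddingLength, headCount, headCountKV int64) int64 {
	return 2 * 2 * numCtx * blockCount * embeddingLength / headCount * headCountKV
}

func main() {
	// Hypothetical 7B-class model at a 2048-token context.
	kv := kvCacheSize(2048, 32, 4096, 32, 32)
	// 4 * 2048 * 32 * (4096/32) * 32 bytes = exactly 1 GiB.
	fmt.Printf("estimated kv cache: %d bytes (%.1f GiB)\n", kv, float64(kv)/(1<<30))
}

With the old heuristic, the graph estimate was always GQA * kv / 6; after this change that heuristic is only the fallback for architectures GraphSize does not recognize.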