|
@@ -148,15 +148,15 @@ func (kv KV) HeadCount() uint64 {
|
|
|
}
|
|
|
|
|
|
// HeadCountKV returns the "<architecture>.attention.head_count_kv" value read
// through kv.u64, where the architecture prefix comes from kv.Architecture().
// The removed line returned that value unconditionally; the added lines return
// it only when it is greater than zero and otherwise fall back to 1.
// NOTE(review): presumably kv.u64 yields 0 when the key is absent — confirm
// against its definition, which is not visible here.
func (kv KV) HeadCountKV() uint64 {
|
|
|
- return kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture()))
|
|
|
+ if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
|
|
|
+ return headCountKV
|
|
|
+ }
|
|
|
+
|
|
|
+ return 1
|
|
|
}
|
|
|
|
|
|
// GQA returns kv.HeadCount() divided by kv.HeadCountKV() using uint64
// (truncating) division.
// The removed lines guarded the division and returned 0 when HeadCountKV()
// was zero; the added line divides unconditionally, which relies on the added
// HeadCountKV lines never returning 0 (they fall back to 1).
func (kv KV) GQA() uint64 {
|
|
|
- if headCountKV := kv.HeadCountKV(); headCountKV > 0 {
|
|
|
- return kv.HeadCount() / headCountKV
|
|
|
- }
|
|
|
-
|
|
|
- return 0
|
|
|
+ return kv.HeadCount() / kv.HeadCountKV()
|
|
|
}
|
|
|
|
|
|
func (kv KV) EmbeddingLength() uint64 {
|