ggml.go

package ggml

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"slices"
	"strings"

	"github.com/ollama/ollama/fs/util/bufioutil"
)

type GGML struct {
	container
	model
}

type model interface {
	KV() KV
	Tensors() Tensors
}

type KV map[string]any

func (kv KV) Architecture() string {
	return kv.String("general.architecture", "unknown")
}

func (kv KV) Kind() string {
	return kv.String("general.type", "unknown")
}

func (kv KV) ParameterCount() uint64 {
	return keyValue(kv, "general.parameter_count", uint64(0))
}

func (kv KV) FileType() fileType {
	if t := kv.Uint("general.file_type"); t > 0 {
		return fileType(t)
	}

	return fileTypeUnknown
}

func (kv KV) BlockCount() uint64 {
	return uint64(kv.Uint("block_count"))
}

func (kv KV) EmbeddingLength() uint64 {
	return uint64(kv.Uint("embedding_length"))
}

func (kv KV) HeadCount() uint64 {
	return uint64(kv.Uint("attention.head_count"))
}

func (kv KV) HeadCountKV() uint64 {
	return uint64(kv.Uint("attention.head_count_kv", 1))
}

func (kv KV) EmbeddingHeadCount() uint64 {
	if heads := kv.HeadCount(); heads > 0 {
		return kv.EmbeddingLength() / heads
	}

	return 0
}

func (kv KV) EmbeddingHeadCountK() uint64 {
	return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
}

func (kv KV) EmbeddingHeadCountV() uint64 {
	return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
}

func (kv KV) GQA() uint64 {
	return kv.HeadCount() / kv.HeadCountKV()
}

func (kv KV) ContextLength() uint64 {
	return uint64(kv.Uint("context_length"))
}

func (kv KV) ChatTemplate() string {
	return kv.String("tokenizer.chat_template")
}

func (kv KV) String(key string, defaultValue ...string) string {
	return keyValue(kv, key, append(defaultValue, "")...)
}

func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
	return keyValue(kv, key, append(defaultValue, 0)...)
}

func (kv KV) Float(key string, defaultValue ...float32) float32 {
	return keyValue(kv, key, append(defaultValue, 0)...)
}

func (kv KV) Bool(key string, defaultValue ...bool) bool {
	return keyValue(kv, key, append(defaultValue, false)...)
}

func (kv KV) Strings(key string, defaultValue ...[]string) []string {
	r := keyValue(kv, key, &array{})
	s := make([]string, r.size)
	for i := range r.size {
		s[i] = r.values[i].(string)
	}

	return s
}

func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
	r := keyValue(kv, key, &array{})
	s := make([]uint32, r.size)
	for i := range r.size {
		s[i] = uint32(r.values[i].(int32))
	}

	return s
}

func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
	r := keyValue(kv, key, &array{})
	s := make([]float32, r.size)
	for i := range r.size {
		s[i] = r.values[i].(float32)
	}

	return s
}

func (kv KV) OllamaEngineRequired() bool {
	return kv.Architecture() == "gemma3"
}

// keyValue looks up key in kv, falling back to the first default value if the
// key is absent. Keys outside the "tokenizer." and "general." namespaces are
// prefixed with the model architecture, e.g. "block_count" resolves to
// "llama.block_count" for a llama model.
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
		key = kv.Architecture() + "." + key
	}

	if val, ok := kv[key]; ok {
		return val.(T)
	}

	slog.Warn("key not found", "key", key, "default", defaultValue[0])
	return defaultValue[0]
}
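// As a sketch of how the lookup behaves (the values here are illustrative,
// assuming a model whose general.architecture is "llama"):
//
//	kv := KV{
//		"general.architecture": "llama",
//		"llama.block_count":    uint32(32),
//	}
//	kv.BlockCount()       // resolves "block_count" → "llama.block_count" → 32
//	kv.Uint("missing", 7) // resolves to "llama.missing": logs a warning, returns 7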
type Tensors struct {
	items  []*Tensor
	Offset uint64
}

func (ts Tensors) Items(prefix ...string) []*Tensor {
	if len(prefix) == 0 {
		return ts.items
	}

	var items []*Tensor
	for _, t := range ts.items {
		if strings.HasPrefix(t.Name, prefix[0]) {
			items = append(items, t)
		}
	}

	return items
}

func (ts Tensors) GroupLayers() map[string]Layer {
	layers := make(map[string]Layer)
	for _, t := range ts.items {
		parts := strings.Split(t.Name, ".")
		if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
			if len(parts) > index+2 {
				// blk and mm should have a number after them, join it
				parts = append(
					[]string{strings.Join(parts[:index+2], ".")},
					parts[index+2:]...)
			}
		}

		if _, ok := layers[parts[0]]; !ok {
			layers[parts[0]] = make(Layer)
		}

		layers[parts[0]][strings.Join(parts[1:], ".")] = t
	}

	return layers
}
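// A sketch of the resulting grouping: tensor names such as
// "blk.0.attn_q.weight" and "blk.0.attn_k.weight" land under a single
// "blk.0" layer, keyed by the remaining suffix:
//
//	layers := ts.GroupLayers()
//	// layers["blk.0"]["attn_q.weight"] → *Tensor
//	// layers["blk.0"]["attn_k.weight"] → *Tensor
//	// layers["output"]["weight"]       → *Tensor (no "blk"/"mm" prefix)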
type Layer map[string]*Tensor

func (l Layer) Size() (size uint64) {
	for _, t := range l {
		size += t.Size()
	}

	return size
}

type Tensor struct {
	Name   string `json:"name"`
	Kind   uint32 `json:"kind"`
	Offset uint64 `json:"-"`

	// Shape is the number of elements in each dimension
	Shape []uint64 `json:"shape"`

	io.WriterTo `json:"-"`
}

func (t Tensor) block() (n int) {
	if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
		return -1
	}

	return
}

func (t Tensor) blockSize() uint64 {
	switch t.Kind {
	case
		0,  // F32
		1,  // F16
		24, // I8
		25, // I16
		26, // I32
		27, // I64
		28, // F64
		30: // BF16
		return 1
	case
		2,  // Q4_0
		3,  // Q4_1
		6,  // Q5_0
		7,  // Q5_1
		8,  // Q8_0
		9,  // Q8_1
		20: // IQ4_NL
		return 32
	default:
		return 256
	}
}

func (t Tensor) typeSize() uint64 {
	blockSize := t.blockSize()

	switch t.Kind {
	case 0: // F32
		return 4
	case 1: // F16
		return 2
	case 2: // Q4_0
		return 2 + blockSize/2
	case 3: // Q4_1
		return 2 + 2 + blockSize/2
	case 6: // Q5_0
		return 2 + 4 + blockSize/2
	case 7: // Q5_1
		return 2 + 2 + 4 + blockSize/2
	case 8: // Q8_0
		return 2 + blockSize
	case 9: // Q8_1
		return 2 + 2 + blockSize
	case 10: // Q2_K
		return blockSize/16 + blockSize/4 + 2 + 2
	case 11: // Q3_K
		return blockSize/8 + blockSize/4 + 12 + 2
	case 12: // Q4_K
		return 2 + 2 + 12 + blockSize/2
	case 13: // Q5_K
		return 2 + 2 + 12 + blockSize/8 + blockSize/2
	case 14: // Q6_K
		return blockSize/2 + blockSize/4 + blockSize/16 + 2
	case 15: // Q8_K
		return 4 + blockSize + 2*blockSize/16
	case 16: // IQ2_XXS
		return 2 + 2*blockSize/8
	case 17: // IQ2_XS
		return 2 + 2*blockSize/8 + blockSize/32
	case 18: // IQ3_XXS
		return 2 + blockSize/4 + blockSize/8
	case 19: // IQ1_S
		return 2 + blockSize/8 + blockSize/16
	case 20: // IQ4_NL
		return 2 + blockSize/2
	case 21: // IQ3_S
		return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
	case 22: // IQ2_S
		return 2 + blockSize/4 + blockSize/16
	case 23: // IQ4_XS
		return 2 + 2 + blockSize/2 + blockSize/64
	case 24: // I8
		return 1
	case 25: // I16
		return 2
	case 26: // I32
		return 4
	case 27: // I64
		return 8
	case 28: // F64
		return 8
	case 29: // IQ1_M
		return blockSize/8 + blockSize/16 + blockSize/32
	case 30: // BF16
		return 2
	default:
		return 0
	}
}

func (t Tensor) parameters() uint64 {
	var count uint64 = 1
	for _, n := range t.Shape {
		count *= n
	}

	return count
}

func (t Tensor) Size() uint64 {
	return t.parameters() * t.typeSize() / t.blockSize()
}
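// Worked example (a sketch, not from the original file): a 4096×4096 Q4_0
// tensor has 16,777,216 parameters. Q4_0 stores each 32-element block as a
// 2-byte f16 scale plus 32 four-bit values (blockSize/2 = 16 bytes), so
// typeSize is 18 bytes per block and:
//
//	t := Tensor{Kind: 2, Shape: []uint64{4096, 4096}}
//	t.Size() // 16777216 * 18 / 32 = 9,437,184 bytes (~9 MiB)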
func (t Tensor) Type() string {
	return fileType(t.Kind).String()
}

type container interface {
	Name() string
	Decode(io.ReadSeeker) (model, error)
}

const (
	// Magic constant for `ggml` files (unversioned).
	FILE_MAGIC_GGML = 0x67676d6c
	// Magic constant for `ggml` files (versioned, ggmf).
	FILE_MAGIC_GGMF = 0x67676d66
	// Magic constant for `ggml` files (versioned, ggjt).
	FILE_MAGIC_GGJT = 0x67676a74
	// Magic constant for `ggla` files (LoRA adapter).
	FILE_MAGIC_GGLA = 0x67676c61
	// Magic constants for `gguf` files (versioned, gguf).
	FILE_MAGIC_GGUF_LE = 0x46554747
	FILE_MAGIC_GGUF_BE = 0x47475546
)

var ErrUnsupportedFormat = errors.New("unsupported model format")

func DetectContentType(b []byte) string {
	// Guard against short input so the Uint32 read below cannot panic.
	if len(b) < 4 {
		return ""
	}

	switch binary.LittleEndian.Uint32(b[:4]) {
	case FILE_MAGIC_GGML:
		return "ggml"
	case FILE_MAGIC_GGMF:
		return "ggmf"
	case FILE_MAGIC_GGJT:
		return "ggjt"
	case FILE_MAGIC_GGLA:
		return "ggla"
	case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
		return "gguf"
	default:
		return ""
	}
}
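// A little-endian GGUF file begins with the ASCII bytes 'G','G','U','F'
// (0x47 0x47 0x55 0x46); read as a little-endian uint32 they yield
// 0x46554747, i.e. FILE_MAGIC_GGUF_LE. A sketch with a fabricated header:
//
//	DetectContentType([]byte{'G', 'G', 'U', 'F', 3, 0, 0, 0}) // "gguf"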
// Decode decodes a GGML model from the given reader.
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
// maxArraySize is negative, all arrays are collected.
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
	if maxArraySize == 0 {
		maxArraySize = 1024
	}

	rs = bufioutil.NewBufferedSeeker(rs, 32<<10)

	var magic uint32
	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
		return nil, 0, err
	}

	var c container
	switch magic {
	case FILE_MAGIC_GGUF_LE:
		c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
	case FILE_MAGIC_GGUF_BE:
		c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
	default:
		return nil, 0, errors.New("invalid file magic")
	}

	model, err := c.Decode(rs)
	if err != nil {
		return nil, 0, err
	}

	// offset is the current position after decoding the metadata, i.e. where
	// the tensor data begins.
	offset, err := rs.Seek(0, io.SeekCurrent)
	if err != nil {
		return nil, 0, err
	}

	return &GGML{
		container: c,
		model:     model,
	}, offset, nil
}
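// Typical usage (a sketch; error handling elided and the file path is
// hypothetical):
//
//	f, _ := os.Open("model.gguf")
//	defer f.Close()
//	ggml, offset, _ := Decode(f, 0) // 0 → default maxArraySize of 1024
//	fmt.Println(ggml.KV().Architecture(), "metadata ends at byte", offset)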
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
	embedding := f.KV().EmbeddingLength()
	heads := f.KV().HeadCount()
	headsKV := f.KV().HeadCountKV()
	vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)

	embeddingHeads := f.KV().EmbeddingHeadCount()
	embeddingHeadsK := f.KV().EmbeddingHeadCountK()
	embeddingHeadsV := f.KV().EmbeddingHeadCountV()

	layers := f.Tensors().GroupLayers()

	bytesPerElement := kvCacheBytesPerElement(kvCacheType)
	kv = make([]uint64, f.KV().BlockCount())
	for i := range kv {
		kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
	}

	switch f.KV().Architecture() {
	case "llama":
		fullOffload = max(
			4*batch*(1+4*embedding+context*(1+heads)),
			4*batch*(embedding+vocab),
		)

		partialOffload = 4 * batch * embedding
		partialOffload += max(
			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)

		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
			// mixtral 8x22b
			ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
			partialOffload = max(
				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
			)
		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
			// mixtral 8x7b
			ffnGateWeight1 := ffnGateWeight.Shape[1]
			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
			partialOffload = max(
				4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
			)
		}
	case "mllama":
		var visionTokens, tiles uint64 = 1601, 4

		crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
		for i := range kv {
			if slices.Contains(crossAttentionLayers, uint32(i)) {
				kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
					4 * // sizeof(float32)
					visionTokens *
					tiles
			}
		}

		fullOffload = max(
			4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
			// vocab graph
			4*batch*(embedding+vocab),
		)

		var ropeFreqsCount uint64
		if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
			if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
				ropeFreqsCount = ropeFreqsWeights.parameters()
			}
		}

		partialOffload = max(
			4*(batch*
				(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
				ropeFreqsCount+
				embeddingHeadsK*context*headsKV),
			// vocab graph
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)
	case "gemma", "gemma2", "gemma3":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
		)

		partialOffload = max(
			4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
			4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
				4*embeddingHeadsK*context*8+
				embedding*embeddingHeadsK*heads*9/16,
		)

		// Gemma2 also has sliding window attention, but we only have an
		// optimized implementation in the Ollama engine. Gemma3 always uses
		// the Ollama engine.
		if f.KV().Architecture() == "gemma3" {
			const gemma3GlobalCacheCount = 6
			slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
			for i := range kv {
				// Every 6th layer is a global layer, which uses the full
				// context size already set above. The other layers are the
				// smaller local (sliding) layers.
				if (i+1)%gemma3GlobalCacheCount != 0 {
					kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
				}
			}
		}
	case "command-r":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+4*embedding+context*(1+heads)),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
		)
	case "qwen2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+2*embedding+context+context*heads),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
		)
	case "phi2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+4*embedding+context+context*heads),
		)

		partialOffload = max(
			4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
			4*batch*(2+3*embedding+context+context*heads),
		)
	case "stablelm":
		fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
		partialOffload = max(
			4*batch*(vocab+2*embedding),
			fullOffload,
		)
	case "deepseek2":
		fullOffload = max(
			4*batch*(3*embedding+vocab),
			4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
		)

		partialOffload = max(
			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
		)
	case "chatglm":
		fullOffload = 4 * batch * (embedding + vocab)
		partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
		if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
			fullOffload = max(
				fullOffload,
				4*batch*(2+
					2*embedding+
					context+
					context*heads+
					embeddingHeadsK*heads+
					qkvBias.Shape[0]),
			)

			partialOffload = max(
				partialOffload,
				4*batch*(1+
					2*embedding+
					embeddingHeadsK*heads+
					context+
					context*heads)+
					4*embeddingHeadsK*context+
					4*context*embeddingHeadsK+
					4*qkvBias.Shape[0],
			)
		}
	}

	return
}
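// A sketch of the per-layer KV cache formula above, for a llama-style model
// with context=8192, headsKV=8, embeddingHeadsK=embeddingHeadsV=128 and an
// f16 cache (2 bytes per element):
//
//	8192 * (128 + 128) * 8 * 2 = 33,554,432 bytes (32 MiB) per layer,
//
// so a 32-layer model needs 1 GiB of KV cache at full context.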
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
	if llm.KV().Uint("vision.block_count") == 0 {
		return
	}

	for name, layer := range llm.Tensors().GroupLayers() {
		if name == "v" || strings.HasPrefix(name, "v.") {
			for _, tensor := range layer {
				weights += tensor.Size()
			}
		}
	}

	imageSize := uint64(llm.KV().Uint("vision.image_size"))
	patchSize := uint64(llm.KV().Uint("vision.patch_size"))
	if patchSize == 0 {
		slog.Warn("unknown patch size for vision model")
		return
	}

	numChannels := uint64(llm.KV().Uint("vision.num_channels"))

	numPatches := (imageSize / patchSize) * (imageSize / patchSize)
	if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
		numPatches++
	}

	headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
	embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))

	switch llm.KV().Architecture() {
	case "mllama":
		// Round the patch count up to the next multiple of 8.
		numPaddedPatches := numPatches + (8-numPatches%8)%8

		maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))

		graphSize = 4 * (8 +
			imageSize*imageSize*numChannels*maxNumTiles +
			embeddingLength*numPatches*maxNumTiles +
			9*embeddingLength*numPaddedPatches*maxNumTiles +
			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
	case "gemma3":
		graphSize = 4 * (imageSize*imageSize*numChannels +
			embeddingLength*patchSize +
			numPatches*numPatches*headCount)
	}

	return weights, graphSize
}
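// For example (a sketch with illustrative metadata): vision.image_size 560
// and vision.patch_size 14 give numPatches = (560/14)² = 1600, plus one if a
// class embedding is present = 1601, matching the visionTokens constant used
// for mllama in GraphSize above.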
// SupportsKVCacheType checks if the requested cache type is supported.
func (f GGML) SupportsKVCacheType(cacheType string) bool {
	return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
}

// SupportsFlashAttention checks if the model supports flash attention.
func (f GGML) SupportsFlashAttention() bool {
	_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
	if isEmbedding {
		return false
	}

	// Check head counts match and are non-zero.
	headCountK := f.KV().EmbeddingHeadCountK()
	headCountV := f.KV().EmbeddingHeadCountV()
	return headCountK != 0 && headCountV != 0 && headCountK == headCountV
}

// kvCacheBytesPerElement returns the number of bytes per element for a given
// KV cache type.
func kvCacheBytesPerElement(cacheType string) float64 {
	switch cacheType {
	case "q8_0":
		return 1 // 1/2 of fp16
	case "q4_0":
		return 0.5 // 1/4 of fp16
	default:
		return 2 // f16 (default)
	}
}
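// A sketch of the effect of cache quantization on the 32 MiB/layer example
// above: q8_0 halves it to 16 MiB and q4_0 quarters it to 8 MiB per layer,
// at the cost of some precision in the cached keys and values:
//
//	kvCacheBytesPerElement("f16")  // 2
//	kvCacheBytesPerElement("q8_0") // 1
//	kvCacheBytesPerElement("q4_0") // 0.5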