ggml.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. package ggml
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "slices"
  9. "strings"
  10. "github.com/ollama/ollama/fs/util/bufioutil"
  11. )
  12. type GGML struct {
  13. container
  14. model
  15. }
  16. type model interface {
  17. KV() KV
  18. Tensors() Tensors
  19. }
  20. type KV map[string]any
  21. func (kv KV) Architecture() string {
  22. return kv.String("general.architecture", "unknown")
  23. }
  24. func (kv KV) Kind() string {
  25. return kv.String("general.type", "unknown")
  26. }
  27. func (kv KV) ParameterCount() uint64 {
  28. return keyValue[uint64](kv, "general.parameter_count")
  29. }
  30. func (kv KV) FileType() fileType {
  31. if t := kv.Uint("general.file_type"); t > 0 {
  32. return fileType(t)
  33. }
  34. return fileTypeUnknown
  35. }
  36. func (kv KV) BlockCount() uint64 {
  37. return uint64(kv.Uint("block_count"))
  38. }
  39. func (kv KV) EmbeddingLength() uint64 {
  40. return uint64(kv.Uint("embedding_length"))
  41. }
  42. func (kv KV) HeadCount() uint64 {
  43. return uint64(kv.Uint("attention.head_count"))
  44. }
  45. func (kv KV) HeadCountKV() uint64 {
  46. return uint64(kv.Uint("attention.head_count_kv", 1))
  47. }
  48. func (kv KV) EmbeddingHeadCount() uint64 {
  49. if heads := kv.HeadCount(); heads > 0 {
  50. return kv.EmbeddingLength() / heads
  51. }
  52. return 0
  53. }
  54. func (kv KV) EmbeddingHeadCountK() uint64 {
  55. return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
  56. }
  57. func (kv KV) EmbeddingHeadCountV() uint64 {
  58. return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
  59. }
  60. func (kv KV) GQA() uint64 {
  61. return kv.HeadCount() / kv.HeadCountKV()
  62. }
  63. func (kv KV) ContextLength() uint64 {
  64. return uint64(kv.Uint("context_length"))
  65. }
  66. func (kv KV) ChatTemplate() string {
  67. return kv.String("tokenizer.chat_template")
  68. }
  69. func (kv KV) String(key string, defaultValue ...string) string {
  70. return keyValue(kv, key, append(defaultValue, "")...)
  71. }
  72. func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
  73. return keyValue(kv, key, append(defaultValue, 0)...)
  74. }
  75. func (kv KV) Float(key string, defaultValue ...float32) float32 {
  76. return keyValue(kv, key, append(defaultValue, 0)...)
  77. }
  78. func (kv KV) Strings(key string, defaultValue ...[]string) []string {
  79. r := keyValue(kv, key, &array{})
  80. s := make([]string, r.size)
  81. for i := range r.size {
  82. s[i] = r.values[i].(string)
  83. }
  84. return s
  85. }
  86. func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
  87. r := keyValue(kv, key, &array{})
  88. s := make([]uint32, r.size)
  89. for i := range r.size {
  90. s[i] = uint32(r.values[i].(int32))
  91. }
  92. return s
  93. }
  94. func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
  95. if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
  96. key = kv.Architecture() + "." + key
  97. }
  98. if val, ok := kv[key]; ok {
  99. return val.(T)
  100. }
  101. slog.Warn("key not found", "key", key, "default", defaultValue[0])
  102. return defaultValue[0]
  103. }
  104. type Tensors struct {
  105. items []*Tensor
  106. Offset uint64
  107. }
  108. func (s Tensors) Items(prefix ...string) []*Tensor {
  109. if len(prefix) == 0 {
  110. return s.items
  111. }
  112. var items []*Tensor
  113. for _, t := range s.items {
  114. if strings.HasPrefix(t.Name, prefix[0]) {
  115. items = append(items, t)
  116. }
  117. }
  118. return items
  119. }
  120. func (ts Tensors) Layers() map[string]Layer {
  121. layers := make(map[string]Layer)
  122. for _, t := range ts.items {
  123. parts := strings.Split(t.Name, ".")
  124. if i := slices.Index(parts, "blk"); i > 0 {
  125. parts = append([]string{
  126. strings.Join(parts[:i], "."),
  127. strings.Join(parts[i:i+2], "."),
  128. }, parts[i+2:]...)
  129. } else if i == 0 {
  130. parts = append([]string{
  131. strings.Join(parts[i:i+2], "."),
  132. }, parts[i+2:]...)
  133. }
  134. if _, ok := layers[parts[0]]; !ok {
  135. layers[parts[0]] = make(Layer)
  136. }
  137. layers[parts[0]][strings.Join(parts[1:], ".")] = t
  138. }
  139. return layers
  140. }
  141. type Layer map[string]*Tensor
  142. func (l Layer) Size() (size uint64) {
  143. for _, t := range l {
  144. size += t.Size()
  145. }
  146. return size
  147. }
  148. type Tensor struct {
  149. Name string `json:"name"`
  150. Kind uint32 `json:"kind"`
  151. Offset uint64 `json:"-"`
  152. // Shape is the number of elements in each dimension
  153. Shape []uint64 `json:"shape"`
  154. io.WriterTo `json:"-"`
  155. }
  156. func (t Tensor) block() (n int) {
  157. if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
  158. return -1
  159. }
  160. return
  161. }
  162. func (t Tensor) blockSize() uint64 {
  163. switch t.Kind {
  164. case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
  165. return 1
  166. case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
  167. return 32
  168. default: // All others
  169. return 256
  170. }
  171. }
  172. func (t Tensor) typeSize() uint64 {
  173. blockSize := t.blockSize()
  174. switch t.Kind {
  175. case 0: // FP32
  176. return 4
  177. case 1: // FP16
  178. return 2
  179. case 2: // Q4_0
  180. return 2 + blockSize/2
  181. case 3: // Q4_1
  182. return 2 + 2 + blockSize/2
  183. case 6: // Q5_0
  184. return 2 + 4 + blockSize/2
  185. case 7: // Q5_1
  186. return 2 + 2 + 4 + blockSize/2
  187. case 8: // Q8_0
  188. return 2 + blockSize
  189. case 9: // Q8_1
  190. return 4 + 4 + blockSize
  191. case 10: // Q2_K
  192. return blockSize/16 + blockSize/4 + 2 + 2
  193. case 11: // Q3_K
  194. return blockSize/8 + blockSize/4 + 12 + 2
  195. case 12: // Q4_K
  196. return 2 + 2 + 12 + blockSize/2
  197. case 13: // Q5_K
  198. return 2 + 2 + 12 + blockSize/8 + blockSize/2
  199. case 14: // Q6_K
  200. return blockSize/2 + blockSize/4 + blockSize/16 + 2
  201. case 15: // Q8_K
  202. return 2 + blockSize + 2*blockSize/16
  203. case 16: // IQ2_XXS
  204. return 2 + 2*blockSize/8
  205. case 17: // IQ2_XS
  206. return 2 + 2*blockSize/8 + blockSize/32
  207. case 18: // IQ3_XXS
  208. return 2 + blockSize/4 + blockSize/8
  209. case 19: // IQ1_S
  210. return 2 + blockSize/8 + blockSize/16
  211. case 20: // IQ4_NL
  212. return 2 + blockSize/2
  213. case 21: // IQ3_S
  214. return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
  215. case 22: // IQ2_S
  216. return 2 + blockSize/4 + blockSize/16
  217. case 23: // IQ4_XS
  218. return 2 + 2 + blockSize/2 + blockSize/64
  219. case 24: // I8
  220. return 1
  221. case 25: // I16
  222. return 2
  223. case 26: // I32
  224. return 4
  225. case 27: // I64
  226. return 8
  227. case 28: // F64
  228. return 8
  229. case 29: // IQ1_M
  230. return blockSize/8 + blockSize/16 + blockSize/32
  231. default:
  232. return 0
  233. }
  234. }
  235. func (t Tensor) parameters() uint64 {
  236. var count uint64 = 1
  237. for _, n := range t.Shape {
  238. count *= n
  239. }
  240. return count
  241. }
  242. func (t Tensor) Size() uint64 {
  243. return t.parameters() * t.typeSize() / t.blockSize()
  244. }
  245. type container interface {
  246. Name() string
  247. Decode(io.ReadSeeker) (model, error)
  248. }
  249. const (
  250. // Magic constant for `ggml` files (unversioned).
  251. FILE_MAGIC_GGML = 0x67676d6c
  252. // Magic constant for `ggml` files (versioned, ggmf).
  253. FILE_MAGIC_GGMF = 0x67676d66
  254. // Magic constant for `ggml` files (versioned, ggjt).
  255. FILE_MAGIC_GGJT = 0x67676a74
  256. // Magic constant for `ggla` files (LoRA adapter).
  257. FILE_MAGIC_GGLA = 0x67676C61
  258. // Magic constant for `gguf` files (versioned, gguf)
  259. FILE_MAGIC_GGUF_LE = 0x46554747
  260. FILE_MAGIC_GGUF_BE = 0x47475546
  261. )
  262. var ErrUnsupportedFormat = errors.New("unsupported model format")
  263. func DetectContentType(b []byte) string {
  264. switch binary.LittleEndian.Uint32(b[:4]) {
  265. case FILE_MAGIC_GGML:
  266. return "ggml"
  267. case FILE_MAGIC_GGMF:
  268. return "ggmf"
  269. case FILE_MAGIC_GGJT:
  270. return "ggjt"
  271. case FILE_MAGIC_GGLA:
  272. return "ggla"
  273. case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
  274. return "gguf"
  275. default:
  276. return ""
  277. }
  278. }
  279. // Decode decodes a GGML model from the given reader.
  280. //
  281. // It collects array values for arrays with a size less than or equal to
  282. // maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
  283. // the maxArraySize is negative, all arrays are collected.
  284. func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
  285. if maxArraySize == 0 {
  286. maxArraySize = 1024
  287. }
  288. rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
  289. var magic uint32
  290. if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
  291. return nil, 0, err
  292. }
  293. var c container
  294. switch magic {
  295. case FILE_MAGIC_GGUF_LE:
  296. c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
  297. case FILE_MAGIC_GGUF_BE:
  298. c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
  299. default:
  300. return nil, 0, errors.New("invalid file magic")
  301. }
  302. model, err := c.Decode(rs)
  303. if err != nil {
  304. return nil, 0, err
  305. }
  306. offset, err := rs.Seek(0, io.SeekCurrent)
  307. if err != nil {
  308. return nil, 0, err
  309. }
  310. // final model type
  311. return &GGML{
  312. container: c,
  313. model: model,
  314. }, offset, nil
  315. }
  316. func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
  317. embedding := llm.KV().EmbeddingLength()
  318. heads := llm.KV().HeadCount()
  319. headsKV := llm.KV().HeadCountKV()
  320. vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)
  321. embeddingHeads := llm.KV().EmbeddingHeadCount()
  322. embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
  323. embeddingHeadsV := llm.KV().EmbeddingHeadCountV()
  324. layers := llm.Tensors().Layers()
  325. bytesPerElement := kvCacheBytesPerElement(kvCacheType)
  326. kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
  327. switch llm.KV().Architecture() {
  328. case "llama":
  329. fullOffload = max(
  330. 4*batch*(1+4*embedding+context*(1+heads)),
  331. 4*batch*(embedding+vocab),
  332. )
  333. partialOffload = 4 * batch * embedding
  334. partialOffload += max(
  335. 4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
  336. 4*batch*(embedding+vocab)+embedding*vocab*105/128,
  337. )
  338. if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
  339. // mixtral 8x22b
  340. ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
  341. partialOffload = max(
  342. 3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
  343. 4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
  344. )
  345. } else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
  346. // mixtral 8x7b
  347. ffnGateWeight1 := ffnGateWeight.Shape[1]
  348. fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
  349. partialOffload = max(
  350. 4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
  351. 4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
  352. )
  353. }
  354. case "mllama":
  355. var visionTokens, tiles uint64 = 1601, 4
  356. if crossAttentionLayers, ok := llm.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
  357. kv = headsKV *
  358. (embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
  359. (2* // sizeof(float16)
  360. (llm.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
  361. context +
  362. 4* // sizeof(float32)
  363. uint64(crossAttentionLayers.size)* // num cross attention layers
  364. visionTokens*
  365. tiles)
  366. }
  367. fullOffload = max(
  368. 4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
  369. // vocab graph
  370. 4*batch*(embedding+vocab),
  371. )
  372. var ropeFreqsCount uint64
  373. if ropeFreqs, ok := llm.Tensors().Layers()["rope_freqs"]; ok {
  374. if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
  375. ropeFreqsCount = ropeFreqsWeights.parameters()
  376. }
  377. }
  378. partialOffload = max(
  379. 4*(batch*
  380. (2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
  381. ropeFreqsCount+
  382. embeddingHeadsK*context*headsKV),
  383. // vocab graph
  384. 4*batch*(embedding+vocab)+embedding*vocab*105/128,
  385. )
  386. case "gemma", "gemma2":
  387. fullOffload = max(
  388. 4*batch*(embedding+vocab),
  389. 4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
  390. )
  391. partialOffload = max(
  392. 4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
  393. 4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
  394. 4*embeddingHeadsK*context*8+
  395. embedding*embeddingHeadsK*heads*9/16,
  396. )
  397. case "command-r":
  398. fullOffload = max(
  399. 4*batch*(embedding+vocab),
  400. 4*batch*(2+4*embedding+context*(1+heads)),
  401. )
  402. partialOffload = max(
  403. 4*batch*(embedding+vocab)+embedding*vocab*105/128,
  404. 4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
  405. )
  406. case "qwen2":
  407. fullOffload = max(
  408. 4*batch*(embedding+vocab),
  409. 4*batch*(1+2*embedding+context+context*heads),
  410. )
  411. partialOffload = max(
  412. 4*batch*(embedding+vocab)+embedding*vocab*105/128,
  413. 4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
  414. )
  415. case "phi2":
  416. fullOffload = max(
  417. 4*batch*(embedding+vocab),
  418. 4*batch*(1+4*embedding+context+context*heads),
  419. )
  420. partialOffload = max(
  421. 4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
  422. 4*batch*(2+3*embedding+context+context*heads),
  423. )
  424. case "stablelm":
  425. fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
  426. partialOffload = max(
  427. 4*batch*(vocab+2*embedding),
  428. fullOffload,
  429. )
  430. case "deepseek2":
  431. fullOffload = max(
  432. 4*batch*(3*embedding+vocab),
  433. 4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
  434. )
  435. partialOffload = max(
  436. 4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
  437. 4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
  438. )
  439. case "chatglm":
  440. fullOffload = 4 * batch * (embedding + vocab)
  441. partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
  442. if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
  443. fullOffload = max(
  444. fullOffload,
  445. 4*batch*(2+
  446. 2*embedding+
  447. context+
  448. context*heads+
  449. embeddingHeadsK*heads+
  450. qkvBias.Shape[0]),
  451. )
  452. partialOffload = max(
  453. partialOffload,
  454. 4*batch*(1+
  455. 2*embedding+
  456. embeddingHeadsK*heads+
  457. context+
  458. context*heads)+
  459. 4*embeddingHeadsK*context+
  460. 4*context*embeddingHeadsK+
  461. 4*qkvBias.Shape[0],
  462. )
  463. }
  464. }
  465. return
  466. }
  467. // SupportsKVCacheType checks if the requested cache type is supported
  468. func (llm GGML) SupportsKVCacheType(cacheType string) bool {
  469. return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
  470. }
  471. // SupportsFlashAttention checks if the model supports flash attention
  472. func (llm GGML) SupportsFlashAttention() bool {
  473. _, isEmbedding := llm.KV()[fmt.Sprintf("%s.pooling_type", llm.KV().Architecture())]
  474. if isEmbedding {
  475. return false
  476. }
  477. // Check head counts match and are non-zero
  478. headCountK := llm.KV().EmbeddingHeadCountK()
  479. headCountV := llm.KV().EmbeddingHeadCountV()
  480. return headCountK != 0 && headCountV != 0 && headCountK == headCountV
  481. }
  482. // kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
  483. func kvCacheBytesPerElement(cacheType string) float64 {
  484. switch cacheType {
  485. case "q8_0":
  486. return 1 // 1/2 of fp16
  487. case "q4_0":
  488. return 0.5 // 1/4 of fp16
  489. default:
  490. return 2 // f16 (default)
  491. }
  492. }