ggml.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. package llm
  2. import (
  3. "cmp"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "strings"
  9. "github.com/ollama/ollama/util/bufioutil"
  10. )
// GGML couples a decoded container (the file-format framing) with the
// model metadata and tensor descriptors read out of it. Both embedded
// interfaces are promoted, so a GGML value answers Name/Decode as well
// as KV/Tensors.
type GGML struct {
	container
	model
}
// model exposes the decoded contents of a model file: its key/value
// metadata and its tensor descriptors.
type model interface {
	KV() KV
	Tensors() Tensors
}
// KV holds the key/value metadata pairs decoded from a model file header.
type KV map[string]any
  20. func (kv KV) u64(key string) uint64 {
  21. switch v := kv[key].(type) {
  22. case uint64:
  23. return v
  24. case uint32:
  25. return uint64(v)
  26. case float64:
  27. return uint64(v)
  28. default:
  29. return 0
  30. }
  31. }
  32. func (kv KV) Architecture() string {
  33. if s, ok := kv["general.architecture"].(string); ok {
  34. return s
  35. }
  36. return "unknown"
  37. }
// ParameterCount returns the model's total weight count, or 0 if unset.
func (kv KV) ParameterCount() uint64 {
	return kv.u64("general.parameter_count")
}
  41. func (kv KV) FileType() fileType {
  42. if u64 := kv.u64("general.file_type"); u64 > 0 {
  43. return fileType(uint32(u64))
  44. }
  45. return fileTypeUnknown
  46. }
// BlockCount returns the number of transformer blocks (layers) for the
// model's architecture, or 0 if unset.
func (kv KV) BlockCount() uint64 {
	return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
}
// HeadCount returns the number of attention (query) heads, or 0 if unset.
func (kv KV) HeadCount() uint64 {
	return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
}
  53. func (kv KV) HeadCountKV() uint64 {
  54. if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
  55. return headCountKV
  56. }
  57. return 1
  58. }
  59. func (kv KV) EmbeddingHeadCount() uint64 {
  60. if heads := kv.HeadCount(); heads > 0 {
  61. return kv.EmbeddingLength() / kv.HeadCount()
  62. }
  63. return 0
  64. }
  65. func (kv KV) EmbeddingHeadCountK() uint64 {
  66. if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 {
  67. return k
  68. }
  69. return kv.EmbeddingHeadCount()
  70. }
  71. func (kv KV) EmbeddingHeadCountV() uint64 {
  72. if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 {
  73. return v
  74. }
  75. return kv.EmbeddingHeadCount()
  76. }
// GQA returns the grouped-query-attention factor: query heads per KV
// head. Division is safe because HeadCountKV reports at least 1.
func (kv KV) GQA() uint64 {
	return kv.HeadCount() / kv.HeadCountKV()
}
// EmbeddingLength returns the model's hidden/embedding size, or 0 if unset.
func (kv KV) EmbeddingLength() uint64 {
	return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
}
// ContextLength returns the model's trained context window size, or 0 if
// unset.
func (kv KV) ContextLength() uint64 {
	return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
}
  86. func (kv KV) ChatTemplate() string {
  87. s, _ := kv["tokenizer.chat_template"].(string)
  88. return s
  89. }
// Tensors is a list of tensor descriptors; Less/Len/Swap below implement
// sort.Interface, ordering tensors by block ("blk.N") number.
type Tensors []*Tensor
  91. func (ts Tensors) Less(i, j int) bool {
  92. var x, y int
  93. if n, err := fmt.Sscanf(ts[i].Name, "blk.%d", &x); err != nil || n != 1 {
  94. return cmp.Less(ts[i].Name, ts[j].Name)
  95. } else if n, err := fmt.Sscanf(ts[j].Name, "blk.%d", &y); err != nil || n != 1 {
  96. return cmp.Less(ts[i].Name, ts[j].Name)
  97. }
  98. return cmp.Less(x, y)
  99. }
// Len reports the number of tensors; required by sort.Interface.
func (ts Tensors) Len() int {
	return len(ts)
}
  103. func (ts Tensors) Swap(i, j int) {
  104. var temp Tensor
  105. }
  106. func (ts Tensors) Layers() map[string]Layer {
  107. layers := make(map[string]Layer)
  108. for _, t := range ts {
  109. parts := strings.Split(t.Name, ".")
  110. if parts[0] == "blk" {
  111. // join first and second part, e.g. blk.%d
  112. parts = append([]string{fmt.Sprintf("%s.%s", parts[0], parts[1])}, parts[2:]...)
  113. }
  114. if _, ok := layers[parts[0]]; !ok {
  115. layers[parts[0]] = make(Layer)
  116. }
  117. layers[parts[0]][strings.Join(parts[1:], ".")] = t
  118. }
  119. return layers
  120. }
// Layer maps a tensor-name suffix (e.g. "attn_q.weight") to its tensor
// within a single layer grouping produced by Tensors.Layers.
type Layer map[string]*Tensor
  122. func (l Layer) size() (size uint64) {
  123. for _, t := range l {
  124. size += t.Size()
  125. }
  126. return size
  127. }
// Tensor describes a single tensor stored in a model file.
type Tensor struct {
	// Name is the tensor's dotted name, e.g. "blk.0.attn_q.weight".
	Name string `json:"name"`
	// Kind is the ggml data/quantization type id (see typeSize for the
	// known values).
	Kind uint32 `json:"kind"`
	// Offset is presumably the tensor data's byte offset within the
	// file — not set in this chunk; confirm against the decoder.
	Offset uint64 `json:"-"`

	// Shape is the number of elements in each dimension
	Shape []uint64 `json:"shape"`

	// WriterTo streams the tensor's raw data; populated elsewhere.
	io.WriterTo `json:"-"`
}
// blockSize returns the number of elements packed into one quantization
// block for this tensor's data type: 1 for scalar (unquantized) types,
// 32 for the classic quant formats, and 256 for the K-/IQ-series quants.
func (t Tensor) blockSize() uint64 {
	switch t.Kind {
	case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
		return 1
	case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
		return 32
	default: // All others
		return 256
	}
}
// typeSize returns the size in bytes of one block of this tensor's data
// type (for scalar types, of one element). The per-kind formulas are the
// block layouts of the corresponding ggml types: per-block scale/min
// header fields plus the packed weight payload.
// NOTE(review): constants are assumed to mirror the ggml reference
// implementation — verify against upstream when adding new kinds.
func (t Tensor) typeSize() uint64 {
	blockSize := t.blockSize()

	switch t.Kind {
	case 0: // FP32
		return 4
	case 1: // FP16
		return 2
	case 2: // Q4_0
		return 2 + blockSize/2
	case 3: // Q4_1
		return 2 + 2 + blockSize/2
	case 6: // Q5_0
		return 2 + 4 + blockSize/2
	case 7: // Q5_1
		return 2 + 2 + 4 + blockSize/2
	case 8: // Q8_0
		return 2 + blockSize
	case 9: // Q8_1
		return 4 + 4 + blockSize
	case 10: // Q2_K
		return blockSize/16 + blockSize/4 + 2 + 2
	case 11: // Q3_K
		return blockSize/8 + blockSize/4 + 12 + 2
	case 12: // Q4_K
		return 2 + 2 + 12 + blockSize/2
	case 13: // Q5_K
		return 2 + 2 + 12 + blockSize/8 + blockSize/2
	case 14: // Q6_K
		return blockSize/2 + blockSize/4 + blockSize/16 + 2
	case 15: // Q8_K
		return 2 + blockSize + 2*blockSize/16
	case 16: // IQ2_XXS
		return 2 + 2*blockSize/8
	case 17: // IQ2_XS
		return 2 + 2*blockSize/8 + blockSize/32
	case 18: // IQ3_XXS
		return 2 + blockSize/4 + blockSize/8
	case 19: // IQ1_S
		return 2 + blockSize/8 + blockSize/16
	case 20: // IQ4_NL
		return 2 + blockSize/2
	case 21: // IQ3_S
		return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
	case 22: // IQ2_S
		return 2 + blockSize/4 + blockSize/16
	case 23: // IQ4_XS
		return 2 + 2 + blockSize/2 + blockSize/64
	case 24: // I8
		return 1
	case 25: // I16
		return 2
	case 26: // I32
		return 4
	case 27: // I64
		return 8
	case 28: // F64
		return 8
	case 29: // IQ1_M
		return blockSize/8 + blockSize/16 + blockSize/32
	default:
		// Unknown kind: size cannot be computed.
		return 0
	}
}
  209. func (t Tensor) parameters() uint64 {
  210. var count uint64 = 1
  211. for _, n := range t.Shape {
  212. count *= n
  213. }
  214. return count
  215. }
// Size returns the tensor's storage size in bytes:
// element count * bytes per block / elements per block.
func (t Tensor) Size() uint64 {
	return t.parameters() * t.typeSize() / t.blockSize()
}
// container abstracts a model file format: it reports its format name and
// decodes a model from a reader positioned just past the magic number
// (see DecodeGGML, which reads the magic before calling Decode).
type container interface {
	Name() string
	Decode(io.ReadSeeker) (model, error)
}
const (
	// Magic constant for `ggml` files (unversioned).
	FILE_MAGIC_GGML = 0x67676d6c
	// Magic constant for `ggml` files (versioned, ggmf).
	FILE_MAGIC_GGMF = 0x67676d66
	// Magic constant for `ggml` files (versioned, ggjt).
	FILE_MAGIC_GGJT = 0x67676a74
	// Magic constant for `ggla` files (LoRA adapter).
	FILE_MAGIC_GGLA = 0x67676C61
	// Magic constant for `gguf` files (versioned, gguf),
	// little-endian byte order.
	FILE_MAGIC_GGUF_LE = 0x46554747
	// Magic constant for `gguf` files, big-endian byte order.
	FILE_MAGIC_GGUF_BE = 0x47475546
)
// ErrUnsupportedFormat is returned when a model file uses a legacy
// container format (ggml/ggmf/ggjt) that can no longer be decoded.
var ErrUnsupportedFormat = errors.New("unsupported model format")
  237. func DetectGGMLType(b []byte) string {
  238. switch binary.LittleEndian.Uint32(b[:4]) {
  239. case FILE_MAGIC_GGML:
  240. return "ggml"
  241. case FILE_MAGIC_GGMF:
  242. return "ggmf"
  243. case FILE_MAGIC_GGJT:
  244. return "ggjt"
  245. case FILE_MAGIC_GGLA:
  246. return "ggla"
  247. case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
  248. return "gguf"
  249. default:
  250. return ""
  251. }
  252. }
  253. // DecodeGGML decodes a GGML model from the given reader.
  254. //
  255. // It collects array values for arrays with a size less than or equal to
  256. // maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
  257. // the maxArraySize is negative, all arrays are collected.
  258. func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
  259. if maxArraySize == 0 {
  260. maxArraySize = 1024
  261. }
  262. rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
  263. var magic uint32
  264. if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
  265. return nil, 0, err
  266. }
  267. var c container
  268. switch magic {
  269. case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
  270. return nil, 0, ErrUnsupportedFormat
  271. case FILE_MAGIC_GGLA:
  272. c = &containerGGLA{}
  273. case FILE_MAGIC_GGUF_LE:
  274. c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
  275. case FILE_MAGIC_GGUF_BE:
  276. c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
  277. default:
  278. return nil, 0, errors.New("invalid file magic")
  279. }
  280. model, err := c.Decode(rs)
  281. if err != nil {
  282. return nil, 0, err
  283. }
  284. offset, err := rs.Seek(0, io.SeekCurrent)
  285. if err != nil {
  286. return nil, 0, err
  287. }
  288. // final model type
  289. return &GGML{
  290. container: c,
  291. model: model,
  292. }, offset, nil
  293. }
// GraphSize estimates the compute-graph scratch memory, in bytes, needed
// to evaluate the model at the given context and batch sizes. It returns
// two heuristic estimates: partialOffload for when only some layers are
// offloaded and fullOffload for when the entire model is offloaded.
// Architectures without a case below return 0, 0.
// NOTE(review): the per-architecture formulas are empirical; they are
// documented here only as far as the code itself shows.
func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
	embedding := llm.KV().EmbeddingLength()
	heads := llm.KV().HeadCount()
	headsKV := llm.KV().HeadCountKV()
	// Assumes the tokenizer tokens entry is always decoded as *array;
	// a missing or differently-typed value panics — TODO confirm the
	// decoder guarantees this.
	vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)

	embeddingHeads := llm.KV().EmbeddingHeadCount()
	embeddingHeadsK := llm.KV().EmbeddingHeadCountK()

	layers := llm.Tensors().Layers()

	switch llm.KV().Architecture() {
	case "llama":
		fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))

		partialOffload = 4 * batch * embedding
		partialOffload += max(
			// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)

		// MoE variants are detected by the presence of expert FFN
		// tensors in the first block.
		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
			// mixtral 8x22b
			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
			partialOffload = max(
				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
			)
		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
			// mixtral 8x7b
			ffnGateWeight1 := ffnGateWeight.Shape[1]
			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
			partialOffload = max(
				4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
			)
		}
	case "gemma", "gemma2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
		)

		partialOffload = max(
			4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
			4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
				4*embeddingHeadsK*context*8+
				embedding*embeddingHeadsK*heads*9/16,
		)
	case "command-r":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+4*embedding+context*(1+heads)),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
		)
	case "qwen2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+2*embedding+context+context*heads),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
		)
	case "phi2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+4*embedding+context+context*heads),
		)

		partialOffload = max(
			4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
			4*batch*(2+3*embedding+context+context*heads),
		)
	case "stablelm":
		fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
		partialOffload = max(
			4*batch*(vocab+2*embedding),
			fullOffload,
		)
	case "deepseek2":
		fullOffload = max(
			4*batch*(3*embedding+vocab),
			4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
		)

		partialOffload = max(
			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
		)
	case "chatglm":
		fullOffload = 4 * batch * (embedding + vocab)
		partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
		// Models with a fused QKV bias get larger estimates that
		// account for the bias term.
		if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
			fullOffload = max(
				fullOffload,
				4*batch*(2+
					2*embedding+
					context+
					context*heads+
					embeddingHeadsK*heads+
					qkvBias.Shape[0]),
			)

			partialOffload = max(
				partialOffload,
				4*batch*(1+
					2*embedding+
					embeddingHeadsK*heads+
					context+
					context*heads)+
					4*embeddingHeadsK*context+
					4*context*embeddingHeadsK+
					4*qkvBias.Shape[0],
			)
		}
	}

	return
}
  407. }