// ggml.go
  1. package llm
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "strings"
  8. "github.com/ollama/ollama/util/bufioutil"
  9. )
// GGML is a decoded model file: the container format it was read from
// plus the decoded model content (metadata and tensors).
type GGML struct {
	container
	model
}
// model exposes the decoded contents of a model file: its key/value
// metadata and its tensor descriptions.
type model interface {
	KV() KV
	Tensors() Tensors
}
  18. type KV map[string]any
  19. func (kv KV) u64(key string) uint64 {
  20. switch v := kv[key].(type) {
  21. case uint64:
  22. return v
  23. case uint32:
  24. return uint64(v)
  25. case float64:
  26. return uint64(v)
  27. default:
  28. return 0
  29. }
  30. }
  31. func (kv KV) Architecture() string {
  32. if s, ok := kv["general.architecture"].(string); ok {
  33. return s
  34. }
  35. return "unknown"
  36. }
  37. func (kv KV) ParameterCount() uint64 {
  38. return kv.u64("general.parameter_count")
  39. }
  40. func (kv KV) FileType() fileType {
  41. if u64 := kv.u64("general.file_type"); u64 > 0 {
  42. return fileType(uint32(u64))
  43. }
  44. return fileTypeUnknown
  45. }
  46. func (kv KV) BlockCount() uint64 {
  47. return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
  48. }
  49. func (kv KV) HeadCount() uint64 {
  50. return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
  51. }
  52. func (kv KV) HeadCountKV() uint64 {
  53. if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
  54. return headCountKV
  55. }
  56. return 1
  57. }
  58. func (kv KV) EmbeddingHeadCount() uint64 {
  59. if heads := kv.HeadCount(); heads > 0 {
  60. return kv.EmbeddingLength() / kv.HeadCount()
  61. }
  62. return 0
  63. }
  64. func (kv KV) EmbeddingHeadCountK() uint64 {
  65. if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 {
  66. return k
  67. }
  68. return kv.EmbeddingHeadCount()
  69. }
  70. func (kv KV) EmbeddingHeadCountV() uint64 {
  71. if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 {
  72. return v
  73. }
  74. return kv.EmbeddingHeadCount()
  75. }
  76. func (kv KV) GQA() uint64 {
  77. return kv.HeadCount() / kv.HeadCountKV()
  78. }
  79. func (kv KV) EmbeddingLength() uint64 {
  80. return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
  81. }
  82. func (kv KV) ContextLength() uint64 {
  83. return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
  84. }
  85. func (kv KV) ChatTemplate() string {
  86. s, _ := kv["tokenizer.chat_template"].(string)
  87. return s
  88. }
  89. // Tensors type as a slice of pointers to Tensor
  90. // type Tensors []*Tensor
  91. type Tensors struct {
  92. Items []*Tensor
  93. Offset int64
  94. }
  95. // Implement the Len method
  96. func (ts Tensors) Len() int {
  97. return len(ts.Items)
  98. }
  99. // Implement the Swap method
  100. func (ts Tensors) Swap(i, j int) {
  101. ts.Items[i], ts.Items[j] = ts.Items[j], ts.Items[i]
  102. }
  103. // Implement the Less method
  104. func (ts Tensors) Less(i, j int) bool {
  105. var x, y int
  106. if n, err := fmt.Sscanf(ts.Items[i].Name, "blk.%d", &x); err != nil || n != 1 {
  107. return ts.Items[i].Name < ts.Items[j].Name
  108. } else if n, err := fmt.Sscanf(ts.Items[j].Name, "blk.%d", &y); err != nil || n != 1 {
  109. return ts.Items[i].Name < ts.Items[j].Name
  110. }
  111. return x < y
  112. }
  113. func (ts Tensors) Layers() map[string]Layer {
  114. layers := make(map[string]Layer)
  115. for _, t := range ts.Items {
  116. parts := strings.Split(t.Name, ".")
  117. if parts[0] == "blk" {
  118. // join first and second part, e.g. blk.%d
  119. parts = append([]string{fmt.Sprintf("%s.%s", parts[0], parts[1])}, parts[2:]...)
  120. }
  121. if _, ok := layers[parts[0]]; !ok {
  122. layers[parts[0]] = make(Layer)
  123. }
  124. layers[parts[0]][strings.Join(parts[1:], ".")] = t
  125. }
  126. return layers
  127. }
  128. type Layer map[string]*Tensor
  129. func (l Layer) size() (size uint64) {
  130. for _, t := range l {
  131. size += t.Size()
  132. }
  133. return size
  134. }
  135. type Tensor struct {
  136. Name string `json:"name"`
  137. Kind uint32 `json:"kind"`
  138. Offset uint64 `json:"-"`
  139. // Shape is the number of elements in each dimension
  140. Shape []uint64 `json:"shape"`
  141. io.WriterTo `json:"-"`
  142. }
  143. func (t Tensor) blockSize() uint64 {
  144. switch t.Kind {
  145. case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
  146. return 1
  147. case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
  148. return 32
  149. default: // All others
  150. return 256
  151. }
  152. }
  153. func (t Tensor) typeSize() uint64 {
  154. blockSize := t.blockSize()
  155. switch t.Kind {
  156. case 0: // FP32
  157. return 4
  158. case 1: // FP16
  159. return 2
  160. case 2: // Q4_0
  161. return 2 + blockSize/2
  162. case 3: // Q4_1
  163. return 2 + 2 + blockSize/2
  164. case 6: // Q5_0
  165. return 2 + 4 + blockSize/2
  166. case 7: // Q5_1
  167. return 2 + 2 + 4 + blockSize/2
  168. case 8: // Q8_0
  169. return 2 + blockSize
  170. case 9: // Q8_1
  171. return 4 + 4 + blockSize
  172. case 10: // Q2_K
  173. return blockSize/16 + blockSize/4 + 2 + 2
  174. case 11: // Q3_K
  175. return blockSize/8 + blockSize/4 + 12 + 2
  176. case 12: // Q4_K
  177. return 2 + 2 + 12 + blockSize/2
  178. case 13: // Q5_K
  179. return 2 + 2 + 12 + blockSize/8 + blockSize/2
  180. case 14: // Q6_K
  181. return blockSize/2 + blockSize/4 + blockSize/16 + 2
  182. case 15: // Q8_K
  183. return 2 + blockSize + 2*blockSize/16
  184. case 16: // IQ2_XXS
  185. return 2 + 2*blockSize/8
  186. case 17: // IQ2_XS
  187. return 2 + 2*blockSize/8 + blockSize/32
  188. case 18: // IQ3_XXS
  189. return 2 + blockSize/4 + blockSize/8
  190. case 19: // IQ1_S
  191. return 2 + blockSize/8 + blockSize/16
  192. case 20: // IQ4_NL
  193. return 2 + blockSize/2
  194. case 21: // IQ3_S
  195. return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
  196. case 22: // IQ2_S
  197. return 2 + blockSize/4 + blockSize/16
  198. case 23: // IQ4_XS
  199. return 2 + 2 + blockSize/2 + blockSize/64
  200. case 24: // I8
  201. return 1
  202. case 25: // I16
  203. return 2
  204. case 26: // I32
  205. return 4
  206. case 27: // I64
  207. return 8
  208. case 28: // F64
  209. return 8
  210. case 29: // IQ1_M
  211. return blockSize/8 + blockSize/16 + blockSize/32
  212. default:
  213. return 0
  214. }
  215. }
  216. func (t Tensor) parameters() uint64 {
  217. var count uint64 = 1
  218. for _, n := range t.Shape {
  219. count *= n
  220. }
  221. return count
  222. }
  223. func (t Tensor) Size() uint64 {
  224. return t.parameters() * t.typeSize() / t.blockSize()
  225. }
// container abstracts a model file container format: it names the format
// and decodes a model (metadata + tensors) from a reader.
type container interface {
	Name() string
	Decode(io.ReadSeeker) (model, error)
}
  230. const (
  231. // Magic constant for `ggml` files (unversioned).
  232. FILE_MAGIC_GGML = 0x67676d6c
  233. // Magic constant for `ggml` files (versioned, ggmf).
  234. FILE_MAGIC_GGMF = 0x67676d66
  235. // Magic constant for `ggml` files (versioned, ggjt).
  236. FILE_MAGIC_GGJT = 0x67676a74
  237. // Magic constant for `ggla` files (LoRA adapter).
  238. FILE_MAGIC_GGLA = 0x67676C61
  239. // Magic constant for `gguf` files (versioned, gguf)
  240. FILE_MAGIC_GGUF_LE = 0x46554747
  241. FILE_MAGIC_GGUF_BE = 0x47475546
  242. )
  243. var ErrUnsupportedFormat = errors.New("unsupported model format")
  244. func DetectGGMLType(b []byte) string {
  245. switch binary.LittleEndian.Uint32(b[:4]) {
  246. case FILE_MAGIC_GGML:
  247. return "ggml"
  248. case FILE_MAGIC_GGMF:
  249. return "ggmf"
  250. case FILE_MAGIC_GGJT:
  251. return "ggjt"
  252. case FILE_MAGIC_GGLA:
  253. return "ggla"
  254. case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
  255. return "gguf"
  256. default:
  257. return ""
  258. }
  259. }
  260. // DecodeGGML decodes a GGML model from the given reader.
  261. //
  262. // It collects array values for arrays with a size less than or equal to
  263. // maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
  264. // the maxArraySize is negative, all arrays are collected.
  265. func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
  266. if maxArraySize == 0 {
  267. maxArraySize = 1024
  268. }
  269. rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
  270. var magic uint32
  271. if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
  272. return nil, 0, err
  273. }
  274. var c container
  275. switch magic {
  276. case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
  277. return nil, 0, ErrUnsupportedFormat
  278. case FILE_MAGIC_GGLA:
  279. c = &containerGGLA{}
  280. case FILE_MAGIC_GGUF_LE:
  281. c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
  282. case FILE_MAGIC_GGUF_BE:
  283. c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
  284. default:
  285. return nil, 0, errors.New("invalid file magic")
  286. }
  287. model, err := c.Decode(rs)
  288. if err != nil {
  289. return nil, 0, err
  290. }
  291. offset, err := rs.Seek(0, io.SeekCurrent)
  292. if err != nil {
  293. return nil, 0, err
  294. }
  295. // final model type
  296. return &GGML{
  297. container: c,
  298. model: model,
  299. }, offset, nil
  300. }
// GraphSize estimates, in bytes, the compute-graph scratch memory needed at
// the given context and batch sizes: fullOffload when every layer runs on
// the GPU, partialOffload when only some layers do. The per-architecture
// formulas below are empirically derived constants; an architecture not
// listed returns (0, 0).
func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
	embedding := llm.KV().EmbeddingLength()
	heads := llm.KV().HeadCount()
	headsKV := llm.KV().HeadCountKV()
	// NOTE(review): assumes "tokenizer.ggml.tokens" is always present and
	// decoded as *array — a model missing it would panic here; confirm callers
	// only reach this with fully decoded metadata.
	vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)
	embeddingHeads := llm.KV().EmbeddingHeadCount()
	embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
	layers := llm.Tensors().Layers()
	switch llm.KV().Architecture() {
	case "llama":
		fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
		partialOffload = 4 * batch * embedding
		partialOffload += max(
			// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)
		// The presence of expert FFN weights in block 0 identifies MoE
		// variants, which need different estimates.
		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
			// mixtral 8x22b
			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
			partialOffload = max(
				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
			)
		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
			// mixtral 8x7b
			ffnGateWeight1 := ffnGateWeight.Shape[1]
			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
			partialOffload = max(
				4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
			)
		}
	case "gemma", "gemma2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
		)
		partialOffload = max(
			4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
			4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
				4*embeddingHeadsK*context*8+
				embedding*embeddingHeadsK*heads*9/16,
		)
	case "command-r":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+4*embedding+context*(1+heads)),
		)
		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
		)
	case "qwen2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+2*embedding+context+context*heads),
		)
		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
		)
	case "phi2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+4*embedding+context+context*heads),
		)
		partialOffload = max(
			4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
			4*batch*(2+3*embedding+context+context*heads),
		)
	case "stablelm":
		fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
		partialOffload = max(
			4*batch*(vocab+2*embedding),
			fullOffload,
		)
	case "deepseek2":
		fullOffload = max(
			4*batch*(3*embedding+vocab),
			4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
		)
		partialOffload = max(
			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
		)
	case "chatglm":
		fullOffload = 4 * batch * (embedding + vocab)
		partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
		// A fused QKV bias in block 0 indicates the variant that needs the
		// larger attention scratch estimates below.
		if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
			fullOffload = max(
				fullOffload,
				4*batch*(2+
					2*embedding+
					context+
					context*heads+
					embeddingHeadsK*heads+
					qkvBias.Shape[0]),
			)
			partialOffload = max(
				partialOffload,
				4*batch*(1+
					2*embedding+
					embeddingHeadsK*heads+
					context+
					context*heads)+
					4*embeddingHeadsK*context+
					4*context*embeddingHeadsK+
					4*qkvBias.Shape[0],
			)
		}
	}
	return
}
  415. type TensorWriter struct {
  416. io.Reader
  417. }
  418. func (tw TensorWriter) WriteTo(w io.Writer) (int64, error) {
  419. return io.Copy(w, tw.Reader)
  420. }