ggml.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. package llm
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "strings"
  8. )
// GGML is a decoded model file: the container format it was read from
// plus the decoded model metadata and tensor listing.
type GGML struct {
container
model
}

// model exposes the key/value metadata and tensor table of a decoded model.
type model interface {
KV() KV
Tensors() Tensors
}
  17. type KV map[string]any
  18. func (kv KV) u64(key string) uint64 {
  19. switch v := kv[key].(type) {
  20. case uint64:
  21. return v
  22. case uint32:
  23. return uint64(v)
  24. case float64:
  25. return uint64(v)
  26. default:
  27. return 0
  28. }
  29. }
  30. func (kv KV) Architecture() string {
  31. if s, ok := kv["general.architecture"].(string); ok {
  32. return s
  33. }
  34. return "unknown"
  35. }
  36. func (kv KV) ParameterCount() uint64 {
  37. return kv.u64("general.parameter_count")
  38. }
  39. func (kv KV) FileType() fileType {
  40. if u64 := kv.u64("general.file_type"); u64 > 0 {
  41. return fileType(uint32(u64))
  42. }
  43. return fileTypeUnknown
  44. }
  45. func (kv KV) BlockCount() uint64 {
  46. return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
  47. }
  48. func (kv KV) HeadCount() uint64 {
  49. return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
  50. }
  51. func (kv KV) HeadCountKV() uint64 {
  52. if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
  53. return headCountKV
  54. }
  55. return 1
  56. }
  57. func (kv KV) GQA() uint64 {
  58. return kv.HeadCount() / kv.HeadCountKV()
  59. }
  60. func (kv KV) EmbeddingLength() uint64 {
  61. return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
  62. }
  63. func (kv KV) ContextLength() uint64 {
  64. return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
  65. }
  66. func (kv KV) ChatTemplate() string {
  67. s, _ := kv["tokenizer.chat_template"].(string)
  68. return s
  69. }
  70. type Tensors []*Tensor
  71. func (ts Tensors) Layers() map[string]Layer {
  72. layers := make(map[string]Layer)
  73. for _, t := range ts {
  74. parts := strings.Split(t.Name, ".")
  75. if parts[0] == "blk" {
  76. // join first and second part, e.g. blk.%d
  77. parts = append([]string{fmt.Sprintf("%s.%s", parts[0], parts[1])}, parts[2:]...)
  78. }
  79. if _, ok := layers[parts[0]]; !ok {
  80. layers[parts[0]] = make(Layer)
  81. }
  82. layers[parts[0]][strings.Join(parts[1:], ".")] = t
  83. }
  84. return layers
  85. }
  86. type Layer map[string]*Tensor
  87. func (l Layer) size() (size uint64) {
  88. for _, t := range l {
  89. size += t.Size()
  90. }
  91. return size
  92. }
// Tensor describes a single tensor entry in the model file.
type Tensor struct {
// Name is the dotted tensor name, e.g. "blk.0.attn_q.weight".
Name string `json:"name"`
// Kind is the GGML data-type id (see blockSize/typeSize for the mapping).
Kind uint32 `json:"kind"`
// Offset is the tensor's data offset; excluded from JSON output.
// NOTE(review): presumably relative to the start of the tensor data
// section — confirm against the container decoder.
Offset uint64 `json:"-"`
// Shape is the number of elements in each dimension
Shape []uint64 `json:"shape"`
// WriterTo streams the tensor's raw data; excluded from JSON output.
io.WriterTo `json:"-"`
}
  101. func (t Tensor) blockSize() uint64 {
  102. switch t.Kind {
  103. case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
  104. return 1
  105. case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
  106. return 32
  107. default: // All others
  108. return 256
  109. }
  110. }
// typeSize returns the size in bytes of one quantization block of the
// tensor's data type (for scalar types blockSize() is 1, so this is the
// per-element size). The layouts mirror GGML's block structs: usually
// one or two fp16 scales, optional packed high bits/scales, then the
// quantized weights. Unknown kinds return 0.
func (t Tensor) typeSize() uint64 {
blockSize := t.blockSize()
switch t.Kind {
case 0: // FP32
return 4
case 1: // FP16
return 2
case 2: // Q4_0
return 2 + blockSize/2
case 3: // Q4_1
return 2 + 2 + blockSize/2
case 6: // Q5_0
return 2 + 4 + blockSize/2
case 7: // Q5_1
return 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
return 2 + blockSize
case 9: // Q8_1
return 4 + 4 + blockSize
case 10: // Q2_K
return blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
return blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
return 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
return 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
return blockSize/2 + blockSize/4 + blockSize/16 + 2
case 15: // Q8_K
// NOTE(review): llama.cpp's block_q8_K stores its scale as a
// float32 (4 bytes), which would make this 4 + ...; confirm
// whether the leading 2 here is intentional.
return 2 + blockSize + 2*blockSize/16
case 16: // IQ2_XXS
return 2 + 2*blockSize/8
case 17: // IQ2_XS
return 2 + 2*blockSize/8 + blockSize/32
case 18: // IQ3_XXS
return 2 + blockSize/4 + blockSize/8
case 19: // IQ1_S
return 2 + blockSize/8 + blockSize/16
case 20: // IQ4_NL
return 2 + blockSize/2
case 21: // IQ3_S
return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
case 22: // IQ2_S
return 2 + blockSize/4 + blockSize/16
case 23: // IQ4_XS
return 2 + 2 + blockSize/2 + blockSize/64
case 24: // I8
return 1
case 25: // I16
return 2
case 26: // I32
return 4
case 27: // I64
return 8
case 28: // F64
return 8
case 29: // IQ1_M
return blockSize/8 + blockSize/16 + blockSize/32
default:
// unrecognized kinds (including removed formats 4 and 5) size to 0
return 0
}
}
  174. func (t Tensor) parameters() uint64 {
  175. var count uint64 = 1
  176. for _, n := range t.Shape {
  177. count *= n
  178. }
  179. return count
  180. }
  181. func (t Tensor) Size() uint64 {
  182. return t.parameters() * t.typeSize() / t.blockSize()
  183. }
// container abstracts one of the supported model file container formats,
// able to name itself and decode a model from a stream.
type container interface {
Name() string
Decode(io.ReadSeeker) (model, error)
}
  188. const (
  189. // Magic constant for `ggml` files (unversioned).
  190. FILE_MAGIC_GGML = 0x67676d6c
  191. // Magic constant for `ggml` files (versioned, ggmf).
  192. FILE_MAGIC_GGMF = 0x67676d66
  193. // Magic constant for `ggml` files (versioned, ggjt).
  194. FILE_MAGIC_GGJT = 0x67676a74
  195. // Magic constant for `ggla` files (LoRA adapter).
  196. FILE_MAGIC_GGLA = 0x67676C61
  197. // Magic constant for `gguf` files (versioned, gguf)
  198. FILE_MAGIC_GGUF_LE = 0x46554747
  199. FILE_MAGIC_GGUF_BE = 0x47475546
  200. )
  201. var ErrUnsupportedFormat = errors.New("unsupported model format")
  202. func DetectGGMLType(b []byte) string {
  203. switch binary.LittleEndian.Uint32(b[:4]) {
  204. case FILE_MAGIC_GGML:
  205. return "ggml"
  206. case FILE_MAGIC_GGMF:
  207. return "ggmf"
  208. case FILE_MAGIC_GGJT:
  209. return "ggjt"
  210. case FILE_MAGIC_GGLA:
  211. return "ggla"
  212. case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
  213. return "gguf"
  214. default:
  215. return ""
  216. }
  217. }
  218. func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
  219. var magic uint32
  220. if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
  221. return nil, 0, err
  222. }
  223. var c container
  224. switch magic {
  225. case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
  226. return nil, 0, ErrUnsupportedFormat
  227. case FILE_MAGIC_GGLA:
  228. c = &containerGGLA{}
  229. case FILE_MAGIC_GGUF_LE:
  230. c = &containerGGUF{ByteOrder: binary.LittleEndian}
  231. case FILE_MAGIC_GGUF_BE:
  232. c = &containerGGUF{ByteOrder: binary.BigEndian}
  233. default:
  234. return nil, 0, errors.New("invalid file magic")
  235. }
  236. model, err := c.Decode(rs)
  237. if errors.Is(err, io.EOF) {
  238. // noop
  239. } else if err != nil {
  240. return nil, 0, err
  241. }
  242. offset, err := rs.Seek(0, io.SeekCurrent)
  243. if err != nil {
  244. return nil, 0, err
  245. }
  246. // final model type
  247. return &GGML{
  248. container: c,
  249. model: model,
  250. }, offset, nil
  251. }
// GraphSize estimates the compute-graph memory (in bytes) required to
// run the model at the given context and batch sizes. It returns two
// estimates: partialOffload for when only some layers live on the GPU,
// and fullOffload for when all layers do. Architectures not listed in
// the switch return (0, 0).
//
// The per-architecture formulas appear to be empirical fits to observed
// allocations rather than derivations — treat the constants as tuned.
func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
embedding := llm.KV().EmbeddingLength()
heads := llm.KV().HeadCount()
headsKV := llm.KV().HeadCountKV()
// NOTE(review): this assertion panics if "tokenizer.ggml.tokens" is
// missing or not a []any — presumably guaranteed by the decoder; confirm.
vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
layers := llm.Tensors().Layers()
switch llm.KV().Architecture() {
case "llama":
fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
partialOffload = 4 * batch * embedding
partialOffload += max(
// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
// MoE variants are detected by expert-FFN tensors in the first block.
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
// mixtral 8x22b
// NOTE(review): panics if "llama.feed_forward_length" is absent
// or not uint32 — confirm decoder guarantees.
ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
partialOffload = max(
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
)
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
// mixtral 8x7b
ffnGateWeight1 := ffnGateWeight.Shape[1]
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
partialOffload = max(
4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
)
}
case "gemma":
fullOffload = 4 * batch * (embedding + vocab)
partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
case "command-r":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(2+4*embedding+context*(1+heads)),
)
partialOffload = max(
4*batch*(embedding+vocab)+embedding*vocab*105/128,
4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
)
case "qwen2":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(1+2*embedding+context+context*heads),
)
partialOffload = max(
4*batch*(embedding+vocab)+embedding*vocab*105/128,
4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
)
case "phi2":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(1+4*embedding+context+context*heads),
)
partialOffload = max(
4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
4*batch*(2+3*embedding+context+context*heads),
)
case "stablelm":
fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
partialOffload = max(
4*batch*(vocab+2*embedding),
fullOffload,
)
}
return
}