ggml.go

package llm

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"strings"
)
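
// GGML pairs a decoded file container with the model metadata and tensor
// listing it holds, exposing both through embedding.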
type GGML struct {
	container
	model
}
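
// LayerSize returns the total size in bytes of all tensors whose names start
// with prefix, e.g. a block prefix such as "blk.0.".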
func (ggml *GGML) LayerSize(prefix string) (n int64) {
	for _, t := range ggml.Tensors() {
		if strings.HasPrefix(t.Name, prefix) {
			n += int64(t.size())
		}
	}

	return
}
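
// Model file type (quantization scheme) constants. The explicit offset at
// fileTypeQ8_0 skips the values 5 and 6, which are not used here.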
const (
	fileTypeF32 uint32 = iota
	fileTypeF16
	fileTypeQ4_0
	fileTypeQ4_1
	fileTypeQ4_1_F16
	fileTypeQ8_0 uint32 = iota + 2
	fileTypeQ5_0
	fileTypeQ5_1
	fileTypeQ2_K
	fileTypeQ3_K_S
	fileTypeQ3_K_M
	fileTypeQ3_K_L
	fileTypeQ4_K_S
	fileTypeQ4_K_M
	fileTypeQ5_K_S
	fileTypeQ5_K_M
	fileTypeQ6_K
	fileTypeIQ2_XXS
	fileTypeIQ2_XS
	fileTypeQ2_K_S
	fileTypeQ3_K_XS
	fileTypeIQ3_XXS
)
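
// fileType returns the human-readable name for a file type constant, or
// "unknown" if the value is not recognized.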
func fileType(fileType uint32) string {
	switch fileType {
	case fileTypeF32:
		return "F32"
	case fileTypeF16:
		return "F16"
	case fileTypeQ4_0:
		return "Q4_0"
	case fileTypeQ4_1:
		return "Q4_1"
	case fileTypeQ4_1_F16:
		return "Q4_1_F16"
	case fileTypeQ8_0:
		return "Q8_0"
	case fileTypeQ5_0:
		return "Q5_0"
	case fileTypeQ5_1:
		return "Q5_1"
	case fileTypeQ2_K:
		return "Q2_K"
	case fileTypeQ3_K_S:
		return "Q3_K_S"
	case fileTypeQ3_K_M:
		return "Q3_K_M"
	case fileTypeQ3_K_L:
		return "Q3_K_L"
	case fileTypeQ4_K_S:
		return "Q4_K_S"
	case fileTypeQ4_K_M:
		return "Q4_K_M"
	case fileTypeQ5_K_S:
		return "Q5_K_S"
	case fileTypeQ5_K_M:
		return "Q5_K_M"
	case fileTypeQ6_K:
		return "Q6_K"
	case fileTypeIQ2_XXS:
		return "IQ2_XXS"
	case fileTypeIQ2_XS:
		return "IQ2_XS"
	case fileTypeQ2_K_S:
		return "Q2_K_S"
	case fileTypeQ3_K_XS:
		return "Q3_K_XS"
	case fileTypeIQ3_XXS:
		return "IQ3_XXS"
	default:
		return "unknown"
	}
}
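
// model is the minimal interface a decoded model must satisfy: access to its
// key/value metadata and to its tensor descriptions.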
type model interface {
	KV() KV
	Tensors() []*Tensor
}
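
// KV holds the model's key/value metadata, keyed by strings such as
// "general.architecture".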
type KV map[string]any
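
// u64 reads key as an unsigned integer, converting from the numeric types a
// decoder may produce; it returns 0 if the key is missing or non-numeric.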
func (kv KV) u64(key string) uint64 {
	switch v := kv[key].(type) {
	case uint64:
		return v
	case uint32:
		return uint64(v)
	case float64:
		return uint64(v)
	default:
		return 0
	}
}

func (kv KV) Architecture() string {
	if s, ok := kv["general.architecture"].(string); ok {
		return s
	}

	return "unknown"
}

func (kv KV) ParameterCount() uint64 {
	return kv.u64("general.parameter_count")
}

func (kv KV) FileType() string {
	if u64 := kv.u64("general.file_type"); u64 > 0 {
		return fileType(uint32(u64))
	}

	return "unknown"
}
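
// The remaining accessors read architecture-scoped keys, e.g.
// "llama.block_count" when the architecture is "llama".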
func (kv KV) BlockCount() uint64 {
	return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
}

func (kv KV) HeadCount() uint64 {
	return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
}

func (kv KV) HeadCountKV() uint64 {
	if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
		return headCountKV
	}

	return 1
}

func (kv KV) GQA() uint64 {
	return kv.HeadCount() / kv.HeadCountKV()
}

func (kv KV) EmbeddingLength() uint64 {
	return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
}

func (kv KV) ContextLength() uint64 {
	return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
}
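
// Tensor describes a single tensor in the model file: its name, element type
// (Kind), byte offset within the data section, and shape.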
type Tensor struct {
	Name   string `json:"name"`
	Kind   uint32 `json:"kind"`
	Offset uint64 `json:"-"`

	// Shape is the number of elements in each dimension
	Shape []uint64 `json:"shape"`

	io.WriterTo `json:"-"`
}
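
// blockSize returns the number of elements stored per block for the tensor's
// kind: 1 for unquantized types, 32 for the classic quantizations, and 256
// for the K- and IQ-quantizations.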
func (t Tensor) blockSize() uint64 {
	switch {
	case t.Kind < 2:
		return 1
	case t.Kind < 10:
		return 32
	default:
		return 256
	}
}
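
// typeSize returns the size in bytes of one block of the tensor's kind,
// derived from each quantization format's on-disk layout (scales, minima,
// and packed weights). Unknown kinds return 0.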
func (t Tensor) typeSize() uint64 {
	blockSize := t.blockSize()

	switch t.Kind {
	case 0: // FP32
		return 4
	case 1: // FP16
		return 2
	case 2: // Q4_0
		return 2 + blockSize/2
	case 3: // Q4_1
		return 2 + 2 + blockSize/2
	case 6: // Q5_0
		return 2 + 4 + blockSize/2
	case 7: // Q5_1
		return 2 + 2 + 4 + blockSize/2
	case 8: // Q8_0
		return 2 + blockSize
	case 9: // Q8_1
		return 4 + 4 + blockSize
	case 10: // Q2_K
		return blockSize/16 + blockSize/4 + 2 + 2
	case 11: // Q3_K
		return blockSize/8 + blockSize/4 + 12 + 2
	case 12: // Q4_K
		return 2 + 2 + 12 + blockSize/2
	case 13: // Q5_K
		return 2 + 2 + 12 + blockSize/8 + blockSize/2
	case 14: // Q6_K
		return blockSize/2 + blockSize/4 + blockSize/16 + 2
	case 15: // Q8_K
		return 2 + blockSize + 2*blockSize/16
	case 16: // IQ2_XXS
		return 2 + 2*blockSize/8
	case 17: // IQ2_XS
		return 2 + 2*blockSize/8 + blockSize/32
	case 18: // IQ3_XXS
		return 2 + 3*blockSize/8
	default:
		return 0
	}
}
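
// parameters returns the number of elements in the tensor, i.e. the product
// of its dimensions.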
func (t Tensor) parameters() uint64 {
	var count uint64 = 1
	for _, n := range t.Shape {
		count *= n
	}

	return count
}
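
// size returns the tensor's size in bytes: the element count divided by the
// block size gives the number of blocks, each occupying typeSize bytes. For
// illustration, a hypothetical 4096x4096 Q4_0 tensor (Kind 2) has blockSize
// 32 and typeSize 2+32/2 = 18, so its size is 4096*4096*18/32 = 9,437,184
// bytes.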
func (t Tensor) size() uint64 {
	return t.parameters() * t.typeSize() / t.blockSize()
}
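
// container abstracts over the supported on-disk formats; each implementation
// decodes its own header and metadata into a model.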
type container interface {
	Name() string
	Decode(io.ReadSeeker) (model, error)
}

const (
	// Magic constant for `ggml` files (unversioned).
	FILE_MAGIC_GGML = 0x67676d6c

	// Magic constant for `ggml` files (versioned, ggmf).
	FILE_MAGIC_GGMF = 0x67676d66

	// Magic constant for `ggml` files (versioned, ggjt).
	FILE_MAGIC_GGJT = 0x67676a74

	// Magic constant for `ggla` files (LoRA adapter).
	FILE_MAGIC_GGLA = 0x67676C61

	// Magic constant for `gguf` files (versioned, gguf).
	FILE_MAGIC_GGUF_LE = 0x46554747
	FILE_MAGIC_GGUF_BE = 0x47475546
)

var ErrUnsupportedFormat = errors.New("unsupported model format")
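
// DecodeGGML reads the magic number from rs, selects the matching container
// implementation, and decodes the model metadata. It returns the decoded
// GGML along with the reader offset at which decoding stopped. Legacy
// ggml/ggmf/ggjt files are rejected with ErrUnsupportedFormat.
//
// A minimal usage sketch (not part of the original file; assumes a local
// "model.gguf" and elides error handling on Open):
//
//	f, _ := os.Open("model.gguf")
//	defer f.Close()
//
//	ggml, offset, err := DecodeGGML(f)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(ggml.KV().Architecture(), ggml.KV().FileType(), offset)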
func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
	var magic uint32
	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
		return nil, 0, err
	}

	var c container
	switch magic {
	case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
		return nil, 0, ErrUnsupportedFormat
	case FILE_MAGIC_GGLA:
		c = &containerGGLA{}
	case FILE_MAGIC_GGUF_LE:
		c = &containerGGUF{ByteOrder: binary.LittleEndian}
	case FILE_MAGIC_GGUF_BE:
		c = &containerGGUF{ByteOrder: binary.BigEndian}
	default:
		return nil, 0, errors.New("invalid file magic")
	}

	model, err := c.Decode(rs)
	if errors.Is(err, io.EOF) {
		// noop
	} else if err != nil {
		return nil, 0, err
	}

	offset, err := rs.Seek(0, io.SeekCurrent)
	if err != nil {
		return nil, 0, err
	}

	// final model type
	return &GGML{
		container: c,
		model:     model,
	}, offset, nil
}
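
// GraphSize estimates, in bytes, the compute-graph working memory needed for
// the given context and batch sizes. The second return value is false when
// there is no estimate for the model's architecture.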
func (llm GGML) GraphSize(context, batch int) (int64, bool) {
	embeddingLength := llm.KV().EmbeddingLength()
	headCount := llm.KV().HeadCount()
	headCountKV := llm.KV().HeadCountKV()
	vocabLength := len(llm.KV()["tokenizer.ggml.tokens"].([]any))

	var attnQKVWeight1 uint64 = 0
	for _, t := range llm.Tensors() {
		if strings.HasSuffix(t.Name, ".attn_qkv.weight") && len(t.Shape) >= 2 {
			attnQKVWeight1 = t.Shape[1]
			break
		}
	}

	var ffnGate1 uint64 = 0
	for _, t := range llm.Tensors() {
		if strings.Index(t.Name, ".ffn_gate") > 0 && len(t.Shape) >= 2 {
			ffnGate1 = t.Shape[1]
			break
		}
	}

	switch llm.KV().Architecture() {
	case "gemma", "command-r":
		return 4 * int64(batch) * int64(embeddingLength+uint64(vocabLength)), true
	case "phi2":
		return max(
			4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
			4*int64(batch)*int64(1+4*embeddingLength+uint64(context)+attnQKVWeight1+uint64(context)*headCount),
		), true
	case "qwen2":
		return max(
			4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
			4*int64(batch)*int64(1+2*embeddingLength+uint64(context)+uint64(context)*headCount),
		), true
	case "llama":
		if ffnGate1 > 0 {
			// moe
			return 4 * int64(batch) * int64(2+3*embeddingLength+uint64(context)+uint64(context)*headCount+2*headCountKV+ffnGate1), true
		}

		return 4 * int64(batch) * int64(1+4*embeddingLength+uint64(context)+uint64(context)*headCount), true
	}

	return 0, false
}