ggml.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. package llm
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "strings"
  8. )
// GGML is a decoded model file: the container format it was read from
// (e.g. GGUF) together with the decoded model, whose KV metadata and
// tensor listing are exposed via the embedded interfaces.
type GGML struct {
	container
	model
}
// File-type (quantization) identifiers stored under "general.file_type".
// The numbering mirrors llama.cpp's LLAMA_FTYPE_* values — presumably the
// two missing values (5, 6) were removed formats; confirm against llama.cpp
// before renumbering.
const (
	fileTypeF32 uint32 = iota // 0
	fileTypeF16
	fileTypeQ4_0
	fileTypeQ4_1
	fileTypeQ4_1_F16 // 4
	// iota here is 5, so iota+2 resumes the sequence at 7, skipping 5 and 6.
	fileTypeQ8_0 uint32 = iota + 2 // 7
	fileTypeQ5_0
	fileTypeQ5_1
	fileTypeQ2_K
	fileTypeQ3_K_S
	fileTypeQ3_K_M
	fileTypeQ3_K_L
	fileTypeQ4_K_S
	fileTypeQ4_K_M
	fileTypeQ5_K_S
	fileTypeQ5_K_M
	fileTypeQ6_K
	fileTypeIQ2_XXS
	fileTypeIQ2_XS
	fileTypeQ2_K_S
	fileTypeQ3_K_XS
	fileTypeIQ3_XXS // 22
)
  37. func fileType(fileType uint32) string {
  38. switch fileType {
  39. case fileTypeF32:
  40. return "F32"
  41. case fileTypeF16:
  42. return "F16"
  43. case fileTypeQ4_0:
  44. return "Q4_0"
  45. case fileTypeQ4_1:
  46. return "Q4_1"
  47. case fileTypeQ4_1_F16:
  48. return "Q4_1_F16"
  49. case fileTypeQ8_0:
  50. return "Q8_0"
  51. case fileTypeQ5_0:
  52. return "Q5_0"
  53. case fileTypeQ5_1:
  54. return "Q5_1"
  55. case fileTypeQ2_K:
  56. return "Q2_K"
  57. case fileTypeQ3_K_S:
  58. return "Q3_K_S"
  59. case fileTypeQ3_K_M:
  60. return "Q3_K_M"
  61. case fileTypeQ3_K_L:
  62. return "Q3_K_L"
  63. case fileTypeQ4_K_S:
  64. return "Q4_K_S"
  65. case fileTypeQ4_K_M:
  66. return "Q4_K_M"
  67. case fileTypeQ5_K_S:
  68. return "Q5_K_S"
  69. case fileTypeQ5_K_M:
  70. return "Q5_K_M"
  71. case fileTypeQ6_K:
  72. return "Q6_K"
  73. case fileTypeIQ2_XXS:
  74. return "IQ2_XXS"
  75. case fileTypeIQ2_XS:
  76. return "IQ2_XS"
  77. case fileTypeQ2_K_S:
  78. return "Q2_K_S"
  79. case fileTypeQ3_K_XS:
  80. return "Q3_K_XS"
  81. case fileTypeIQ3_XXS:
  82. return "IQ3_XXS"
  83. default:
  84. return "unknown"
  85. }
  86. }
// model exposes a decoded model's key/value metadata and tensor listing.
// Implemented by the per-container decoders.
type model interface {
	KV() KV
	Tensors() Tensors
}
  91. type KV map[string]any
  92. func (kv KV) u64(key string) uint64 {
  93. switch v := kv[key].(type) {
  94. case uint64:
  95. return v
  96. case uint32:
  97. return uint64(v)
  98. case float64:
  99. return uint64(v)
  100. default:
  101. return 0
  102. }
  103. }
  104. func (kv KV) Architecture() string {
  105. if s, ok := kv["general.architecture"].(string); ok {
  106. return s
  107. }
  108. return "unknown"
  109. }
  110. func (kv KV) ParameterCount() uint64 {
  111. return kv.u64("general.parameter_count")
  112. }
  113. func (kv KV) FileType() string {
  114. if u64 := kv.u64("general.file_type"); u64 > 0 {
  115. return fileType(uint32(u64))
  116. }
  117. return "unknown"
  118. }
  119. func (kv KV) BlockCount() uint64 {
  120. return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
  121. }
  122. func (kv KV) HeadCount() uint64 {
  123. return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
  124. }
  125. func (kv KV) HeadCountKV() uint64 {
  126. if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
  127. return headCountKV
  128. }
  129. return 1
  130. }
  131. func (kv KV) GQA() uint64 {
  132. return kv.HeadCount() / kv.HeadCountKV()
  133. }
  134. func (kv KV) EmbeddingLength() uint64 {
  135. return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
  136. }
  137. func (kv KV) ContextLength() uint64 {
  138. return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
  139. }
  140. type Tensors []*Tensor
  141. func (ts Tensors) Layers() map[string]Layer {
  142. layers := make(map[string]Layer)
  143. for _, t := range ts {
  144. parts := strings.Split(t.Name, ".")
  145. if parts[0] == "blk" {
  146. // join first and second part, e.g. blk.%d
  147. parts = append([]string{fmt.Sprintf("%s.%s", parts[0], parts[1])}, parts[2:]...)
  148. }
  149. if _, ok := layers[parts[0]]; !ok {
  150. layers[parts[0]] = make(Layer)
  151. }
  152. layers[parts[0]][strings.Join(parts[1:], ".")] = t
  153. }
  154. return layers
  155. }
  156. type Layer map[string]*Tensor
  157. func (l Layer) size() (size uint64) {
  158. for _, t := range l {
  159. size += t.size()
  160. }
  161. return size
  162. }
  163. type Tensor struct {
  164. Name string `json:"name"`
  165. Kind uint32 `json:"kind"`
  166. Offset uint64 `json:"-"`
  167. // Shape is the number of elements in each dimension
  168. Shape []uint64 `json:"shape"`
  169. io.WriterTo `json:"-"`
  170. }
  171. func (t Tensor) blockSize() uint64 {
  172. switch {
  173. case t.Kind < 2:
  174. return 1
  175. case t.Kind < 10:
  176. return 32
  177. default:
  178. return 256
  179. }
  180. }
  181. func (t Tensor) typeSize() uint64 {
  182. blockSize := t.blockSize()
  183. switch t.Kind {
  184. case 0: // FP32
  185. return 4
  186. case 1: // FP16
  187. return 2
  188. case 2: // Q4_0
  189. return 2 + blockSize/2
  190. case 3: // Q4_1
  191. return 2 + 2 + blockSize/2
  192. case 6: // Q5_0
  193. return 2 + 4 + blockSize/2
  194. case 7: // Q5_1
  195. return 2 + 2 + 4 + blockSize/2
  196. case 8: // Q8_0
  197. return 2 + blockSize
  198. case 9: // Q8_1
  199. return 4 + 4 + blockSize
  200. case 10: // Q2_K
  201. return blockSize/16 + blockSize/4 + 2 + 2
  202. case 11: // Q3_K
  203. return blockSize/8 + blockSize/4 + 12 + 2
  204. case 12: // Q4_K
  205. return 2 + 2 + 12 + blockSize/2
  206. case 13: // Q5_K
  207. return 2 + 2 + 12 + blockSize/8 + blockSize/2
  208. case 14: // Q6_K
  209. return blockSize/2 + blockSize/4 + blockSize/16 + 2
  210. case 15: // Q8_K
  211. return 2 + blockSize + 2*blockSize/16
  212. case 16: // IQ2_XXS
  213. return 2 + 2*blockSize/8
  214. case 17: // IQ2_XS
  215. return 2 + 2*blockSize/8 + blockSize/32
  216. case 18: // IQ3_XXS
  217. return 2 + 3*blockSize/8
  218. default:
  219. return 0
  220. }
  221. }
  222. func (t Tensor) parameters() uint64 {
  223. var count uint64 = 1
  224. for _, n := range t.Shape {
  225. count *= n
  226. }
  227. return count
  228. }
  229. func (t Tensor) size() uint64 {
  230. return t.parameters() * t.typeSize() / t.blockSize()
  231. }
// container abstracts a model file format (e.g. GGUF, GGLA): it names the
// format and decodes a model's metadata from a reader positioned just past
// the magic number.
type container interface {
	Name() string
	Decode(io.ReadSeeker) (model, error)
}
// File magic numbers, as read from the first four bytes of a model file
// with binary.Read in little-endian order. Each value is the ASCII tag of
// its format interpreted as a uint32.
const (
	// Magic constant for `ggml` files (unversioned).
	FILE_MAGIC_GGML = 0x67676d6c
	// Magic constant for `ggml` files (versioned, ggmf).
	FILE_MAGIC_GGMF = 0x67676d66
	// Magic constant for `ggml` files (versioned, ggjt).
	FILE_MAGIC_GGJT = 0x67676a74
	// Magic constant for `ggla` files (LoRA adapter).
	FILE_MAGIC_GGLA = 0x67676C61
	// Magic constant for `gguf` files (versioned, gguf). Two values because
	// the on-disk byte order of the file determines how the tag reads back.
	FILE_MAGIC_GGUF_LE = 0x46554747
	FILE_MAGIC_GGUF_BE = 0x47475546
)
  249. var ErrUnsupportedFormat = errors.New("unsupported model format")
  250. func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
  251. var magic uint32
  252. if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
  253. return nil, 0, err
  254. }
  255. var c container
  256. switch magic {
  257. case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
  258. return nil, 0, ErrUnsupportedFormat
  259. case FILE_MAGIC_GGLA:
  260. c = &containerGGLA{}
  261. case FILE_MAGIC_GGUF_LE:
  262. c = &containerGGUF{ByteOrder: binary.LittleEndian}
  263. case FILE_MAGIC_GGUF_BE:
  264. c = &containerGGUF{ByteOrder: binary.BigEndian}
  265. default:
  266. return nil, 0, errors.New("invalid file magic")
  267. }
  268. model, err := c.Decode(rs)
  269. if errors.Is(err, io.EOF) {
  270. // noop
  271. } else if err != nil {
  272. return nil, 0, err
  273. }
  274. offset, err := rs.Seek(0, io.SeekCurrent)
  275. if err != nil {
  276. return nil, 0, err
  277. }
  278. // final model type
  279. return &GGML{
  280. container: c,
  281. model: model,
  282. }, offset, nil
  283. }
// GraphSize estimates the scratch memory, in bytes, needed to evaluate the
// model's compute graph at the given context and batch sizes, returning
// separate estimates for partial and full GPU offload. The per-architecture
// formulas are empirical sizing heuristics; constants such as 9/16 and
// 105/128 look like tuned fudge factors — confirm against the runner they
// were calibrated for before changing them. Architectures without a case
// return (0, 0).
func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
	embedding := llm.KV().EmbeddingLength()
	heads := llm.KV().HeadCount()
	headsKV := llm.KV().HeadCountKV()
	// NOTE(review): this type assertion panics if the metadata lacks
	// "tokenizer.ggml.tokens" or stores it as something other than []any —
	// confirm all decoded models carry this key.
	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
	layers := llm.Tensors().Layers()
	switch llm.KV().Architecture() {
	case "llama":
		fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
		partialOffload = 4 * batch * embedding
		partialOffload += max(
			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)
		// Models carrying a per-expert gate weight in layer 0 (presumably
		// mixture-of-experts variants — confirm) use the gate's second
		// dimension in a larger estimate.
		if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
			ffnGateWeight1 := ffnGateWeight.Shape[1]
			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
			partialOffload = max(
				4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
			)
		}
	case "gemma":
		fullOffload = 4 * batch * (embedding + vocab)
		partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
	case "command-r":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+4*embedding+context*(1+heads)),
		)
		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
		)
	case "qwen2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+2*embedding+context+context*heads),
		)
		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
		)
	case "phi2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+4*embedding+context+context*heads),
		)
		partialOffload = 4*batch*(2*embedding+vocab) + embedding*vocab*105/128
	case "stablelm":
		fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
		partialOffload = max(
			4*batch*(vocab+2*embedding),
			fullOffload,
		)
	}
	return
}