package llm

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"strings"
)
// GGML bundles a decoded model with the container format it was read
// from, exposing both interfaces through embedding.
type GGML struct {
	container
	model
}
// model is the decoded content of a model file: its key/value
// metadata, its tensor descriptors, and a byte offset (presumably
// where tensor data begins — confirm against the container decoders).
type model interface {
	KV() KV
	Tensors() Tensors
	Offset() int64
}
// KV holds a model's key/value metadata, keyed by dotted names such as
// "general.architecture".
type KV map[string]any
  19. func (kv KV) u64(key string) uint64 {
  20. switch v := kv[key].(type) {
  21. case uint64:
  22. return v
  23. case uint32:
  24. return uint64(v)
  25. case float64:
  26. return uint64(v)
  27. default:
  28. return 0
  29. }
  30. }
  31. func (kv KV) Architecture() string {
  32. if s, ok := kv["general.architecture"].(string); ok {
  33. return s
  34. }
  35. return "unknown"
  36. }
// ParameterCount returns the model's declared total weight count, or 0
// when the key is missing.
func (kv KV) ParameterCount() uint64 {
	return kv.u64("general.parameter_count")
}
  40. func (kv KV) FileType() string {
  41. if u64 := kv.u64("general.file_type"); u64 > 0 {
  42. return fileType(uint32(u64)).String()
  43. }
  44. return "unknown"
  45. }
  46. func (kv KV) BlockCount() uint64 {
  47. return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
  48. }
  49. func (kv KV) HeadCount() uint64 {
  50. return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
  51. }
  52. func (kv KV) HeadCountKV() uint64 {
  53. if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
  54. return headCountKV
  55. }
  56. return 1
  57. }
// GQA returns the grouped-query-attention factor: query heads per
// key/value head. HeadCountKV never returns 0, so this cannot divide
// by zero.
func (kv KV) GQA() uint64 {
	return kv.HeadCount() / kv.HeadCountKV()
}
  61. func (kv KV) EmbeddingLength() uint64 {
  62. return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
  63. }
  64. func (kv KV) ContextLength() uint64 {
  65. return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
  66. }
// Tensors is the flat list of tensor descriptors read from a model file.
type Tensors []*Tensor
  68. func (ts Tensors) Layers() map[string]Layer {
  69. layers := make(map[string]Layer)
  70. for _, t := range ts {
  71. parts := strings.Split(t.Name, ".")
  72. if parts[0] == "blk" {
  73. // join first and second part, e.g. blk.%d
  74. parts = append([]string{fmt.Sprintf("%s.%s", parts[0], parts[1])}, parts[2:]...)
  75. }
  76. if _, ok := layers[parts[0]]; !ok {
  77. layers[parts[0]] = make(Layer)
  78. }
  79. layers[parts[0]][strings.Join(parts[1:], ".")] = t
  80. }
  81. return layers
  82. }
// Layer maps a tensor-name suffix (the part after the group prefix) to
// its tensor descriptor.
type Layer map[string]*Tensor
  84. func (l Layer) size() (size uint64) {
  85. for _, t := range l {
  86. size += t.size()
  87. }
  88. return size
  89. }
// Tensor describes a single tensor in a model file: its dotted name,
// its quantization kind (GGUF tensor-type enum), its byte offset in the
// tensor-data section, and its shape. The embedded WriterTo supplies
// the tensor's data when serializing; it is excluded from JSON.
type Tensor struct {
	Name   string `json:"name"`
	Kind   uint32 `json:"kind"`
	Offset uint64 `json:"-"`

	// Shape is the number of elements in each dimension
	Shape []uint64 `json:"shape"`

	io.WriterTo `json:"-"`
}
  98. func (t Tensor) blockSize() uint64 {
  99. switch {
  100. case t.Kind < 2:
  101. return 1
  102. case t.Kind < 10:
  103. return 32
  104. default:
  105. return 256
  106. }
  107. }
// typeSize returns the byte size of one block of this tensor's data
// type (one element for the unquantized FP32/FP16 kinds). Each
// quantized case sums fixed header bytes (scales/mins) plus the packed
// quantized weights — presumably mirroring llama.cpp's block struct
// layouts; verify against ggml-common.h when adding kinds. Unknown
// kinds return 0, which makes size() report 0 bytes.
func (t Tensor) typeSize() uint64 {
	blockSize := t.blockSize()
	switch t.Kind {
	case 0: // FP32
		return 4
	case 1: // FP16
		return 2
	case 2: // Q4_0
		return 2 + blockSize/2
	case 3: // Q4_1
		return 2 + 2 + blockSize/2
	case 6: // Q5_0
		return 2 + 4 + blockSize/2
	case 7: // Q5_1
		return 2 + 2 + 4 + blockSize/2
	case 8: // Q8_0
		return 2 + blockSize
	case 9: // Q8_1
		return 4 + 4 + blockSize
	case 10: // Q2_K
		return blockSize/16 + blockSize/4 + 2 + 2
	case 11: // Q3_K
		return blockSize/8 + blockSize/4 + 12 + 2
	case 12: // Q4_K
		return 2 + 2 + 12 + blockSize/2
	case 13: // Q5_K
		return 2 + 2 + 12 + blockSize/8 + blockSize/2
	case 14: // Q6_K
		return blockSize/2 + blockSize/4 + blockSize/16 + 2
	case 15: // Q8_K
		return 2 + blockSize + 2*blockSize/16
	case 16: // IQ2_XXS
		return 2 + 2*blockSize/8
	case 17: // IQ2_XS
		return 2 + 2*blockSize/8 + blockSize/32
	case 18: // IQ3_XXS
		return 2 + 3*blockSize/8
	default:
		return 0
	}
}
  149. func (t Tensor) parameters() uint64 {
  150. var count uint64 = 1
  151. for _, n := range t.Shape {
  152. count *= n
  153. }
  154. return count
  155. }
// size returns the tensor's encoded size in bytes: element count times
// bytes-per-block divided by elements-per-block.
func (t Tensor) size() uint64 {
	return t.parameters() * t.typeSize() / t.blockSize()
}
// container abstracts a model file format: a name for diagnostics and
// a decoder that reads the model from a stream.
type container interface {
	Name() string
	Decode(io.ReadSeeker) (model, error)
}
// Magic numbers identifying the supported model container formats, as
// matched against the file's first four bytes in DetectGGMLType and
// DecodeGGML.
const (
	// Magic constant for `ggml` files (unversioned).
	FILE_MAGIC_GGML = 0x67676d6c
	// Magic constant for `ggml` files (versioned, ggmf).
	FILE_MAGIC_GGMF = 0x67676d66
	// Magic constant for `ggml` files (versioned, ggjt).
	FILE_MAGIC_GGJT = 0x67676a74
	// Magic constant for `ggla` files (LoRA adapter).
	FILE_MAGIC_GGLA = 0x67676C61
	// Magic constant for `gguf` files (versioned, gguf)
	FILE_MAGIC_GGUF_LE = 0x46554747
	FILE_MAGIC_GGUF_BE = 0x47475546
)
// ErrUnsupportedFormat is returned by DecodeGGML for recognized but
// unsupported legacy containers (ggml/ggmf/ggjt).
var ErrUnsupportedFormat = errors.New("unsupported model format")
  177. func DetectGGMLType(b []byte) string {
  178. switch binary.LittleEndian.Uint32(b[:4]) {
  179. case FILE_MAGIC_GGML:
  180. return "ggml"
  181. case FILE_MAGIC_GGMF:
  182. return "ggmf"
  183. case FILE_MAGIC_GGJT:
  184. return "ggjt"
  185. case FILE_MAGIC_GGLA:
  186. return "ggla"
  187. case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
  188. return "gguf"
  189. default:
  190. return ""
  191. }
  192. }
  193. func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
  194. var magic uint32
  195. if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
  196. return nil, 0, err
  197. }
  198. var c container
  199. switch magic {
  200. case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
  201. return nil, 0, ErrUnsupportedFormat
  202. case FILE_MAGIC_GGLA:
  203. c = &containerGGLA{}
  204. case FILE_MAGIC_GGUF_LE:
  205. c = &containerGGUF{ByteOrder: binary.LittleEndian}
  206. case FILE_MAGIC_GGUF_BE:
  207. c = &containerGGUF{ByteOrder: binary.BigEndian}
  208. default:
  209. return nil, 0, errors.New("invalid file magic")
  210. }
  211. model, err := c.Decode(rs)
  212. if errors.Is(err, io.EOF) {
  213. // noop
  214. } else if err != nil {
  215. return nil, 0, err
  216. }
  217. offset, err := rs.Seek(0, io.SeekCurrent)
  218. if err != nil {
  219. return nil, 0, err
  220. }
  221. // final model type
  222. return &GGML{
  223. container: c,
  224. model: model,
  225. }, offset, nil
  226. }
// GraphSize estimates, per architecture, the scratch memory in bytes
// the compute graph needs for the given context and batch sizes.
// partialOffload is the estimate when only some layers are offloaded,
// fullOffload when all are. The formulas appear to be empirically
// tuned per architecture; unknown architectures return 0, 0.
func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
	embedding := llm.KV().EmbeddingLength()
	heads := llm.KV().HeadCount()
	headsKV := llm.KV().HeadCountKV()
	// NOTE(review): this assertion panics if tokenizer.ggml.tokens is
	// missing or not a []any — confirm callers only pass fully decoded
	// GGUF models.
	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
	layers := llm.Tensors().Layers()
	switch llm.KV().Architecture() {
	case "llama":
		fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
		partialOffload = 4 * batch * embedding
		partialOffload += max(
			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)
		// Presence of expert FFN tensors in block 0 distinguishes the
		// mixture-of-experts variants.
		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
			// mixtral 8x22b
			// NOTE(review): panics unless llama.feed_forward_length is
			// stored as uint32 — confirm against the GGUF decoder.
			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
			partialOffload = max(
				3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
				4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
			)
		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
			// mixtral 8x7b
			ffnGateWeight1 := ffnGateWeight.Shape[1]
			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
			partialOffload = max(
				4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
			)
		}
	case "gemma":
		fullOffload = 4 * batch * (embedding + vocab)
		partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
	case "command-r":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+4*embedding+context*(1+heads)),
		)
		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
		)
	case "qwen2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+2*embedding+context+context*heads),
		)
		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
		)
	case "phi2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+4*embedding+context+context*heads),
		)
		partialOffload = 4*batch*(2*embedding+vocab) + embedding*vocab*105/128
	case "stablelm":
		fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
		partialOffload = max(
			4*batch*(vocab+2*embedding),
			fullOffload,
		)
	}
	return
}