ggml.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. package llm
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "io"
  7. )
  8. type GGML struct {
  9. container
  10. model
  11. }
  // File types recorded under a model's general.file_type metadata key;
// fileType maps each one to its display name. The numbering deliberately
// skips 5 and 6.
// NOTE(review): the gap presumably mirrors file types retired upstream in
// llama.cpp — confirm against llama.h before adding new entries.
const (
	fileTypeF32      uint32 = 0
	fileTypeF16      uint32 = 1
	fileTypeQ4_0     uint32 = 2
	fileTypeQ4_1     uint32 = 3
	fileTypeQ4_1_F16 uint32 = 4

	// 5 and 6 are intentionally unused.

	fileTypeQ8_0    uint32 = 7
	fileTypeQ5_0    uint32 = 8
	fileTypeQ5_1    uint32 = 9
	fileTypeQ2_K    uint32 = 10
	fileTypeQ3_K_S  uint32 = 11
	fileTypeQ3_K_M  uint32 = 12
	fileTypeQ3_K_L  uint32 = 13
	fileTypeQ4_K_S  uint32 = 14
	fileTypeQ4_K_M  uint32 = 15
	fileTypeQ5_K_S  uint32 = 16
	fileTypeQ5_K_M  uint32 = 17
	fileTypeQ6_K    uint32 = 18
	fileTypeIQ2_XXS uint32 = 19
	fileTypeIQ2_XS  uint32 = 20
	fileTypeQ2_K_S  uint32 = 21
	fileTypeQ3_K_XS uint32 = 22
	fileTypeIQ3_XXS uint32 = 23
)

// fileType returns the display name of a general.file_type value, or
// "unknown" when the value matches none of the constants above.
func fileType(ft uint32) string {
	name := "unknown"
	switch ft {
	case fileTypeF32:
		name = "F32"
	case fileTypeF16:
		name = "F16"
	case fileTypeQ4_0:
		name = "Q4_0"
	case fileTypeQ4_1:
		name = "Q4_1"
	case fileTypeQ4_1_F16:
		name = "Q4_1_F16"
	case fileTypeQ8_0:
		name = "Q8_0"
	case fileTypeQ5_0:
		name = "Q5_0"
	case fileTypeQ5_1:
		name = "Q5_1"
	case fileTypeQ2_K:
		name = "Q2_K"
	case fileTypeQ3_K_S:
		name = "Q3_K_S"
	case fileTypeQ3_K_M:
		name = "Q3_K_M"
	case fileTypeQ3_K_L:
		name = "Q3_K_L"
	case fileTypeQ4_K_S:
		name = "Q4_K_S"
	case fileTypeQ4_K_M:
		name = "Q4_K_M"
	case fileTypeQ5_K_S:
		name = "Q5_K_S"
	case fileTypeQ5_K_M:
		name = "Q5_K_M"
	case fileTypeQ6_K:
		name = "Q6_K"
	case fileTypeIQ2_XXS:
		name = "IQ2_XXS"
	case fileTypeIQ2_XS:
		name = "IQ2_XS"
	case fileTypeQ2_K_S:
		name = "Q2_K_S"
	case fileTypeQ3_K_XS:
		name = "Q3_K_XS"
	case fileTypeIQ3_XXS:
		name = "IQ3_XXS"
	}
	return name
}
  86. type model interface {
  87. KV() KV
  88. Tensors() []*Tensor
  89. }
  90. type KV map[string]any
  91. func (kv KV) u64(key string) uint64 {
  92. switch v := kv[key].(type) {
  93. case uint64:
  94. return v
  95. case uint32:
  96. return uint64(v)
  97. case float64:
  98. return uint64(v)
  99. default:
  100. return 0
  101. }
  102. }
  103. func (kv KV) Architecture() string {
  104. if s, ok := kv["general.architecture"].(string); ok {
  105. return s
  106. }
  107. return "unknown"
  108. }
  109. func (kv KV) ParameterCount() uint64 {
  110. return kv.u64("general.parameter_count")
  111. }
  112. func (kv KV) FileType() string {
  113. if u64 := kv.u64("general.file_type"); u64 > 0 {
  114. return fileType(uint32(u64))
  115. }
  116. return "unknown"
  117. }
  118. func (kv KV) BlockCount() uint64 {
  119. return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
  120. }
  121. func (kv KV) HeadCount() uint64 {
  122. return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
  123. }
  124. func (kv KV) HeadCountKV() uint64 {
  125. return kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture()))
  126. }
  127. func (kv KV) GQA() uint64 {
  128. if headCountKV := kv.HeadCountKV(); headCountKV > 0 {
  129. return kv.HeadCount() / headCountKV
  130. }
  131. return 0
  132. }
  133. func (kv KV) EmbeddingLength() uint64 {
  134. return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
  135. }
  136. func (kv KV) ContextLength() uint64 {
  137. return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
  138. }
  // Tensor describes a single tensor recorded in a model file.
type Tensor struct {
	Name string `json:"name"`

	// Kind is the ggml type id of the tensor data (0 = F32, 1 = F16,
	// 2 = Q4_0, ...; typeSize lists every id this file knows the size of).
	Kind uint32 `json:"kind"`

	// Offset is excluded from JSON.
	// NOTE(review): presumably the position of the tensor data within the
	// file — confirm against the container decoders.
	Offset uint64 `json:"-"`

	// Shape is the number of elements in each dimension
	Shape []uint64 `json:"shape"`

	io.WriterTo `json:"-"`
}

// blockSize returns how many elements share one storage block for t.Kind.
// The result is 1 for the unquantized F32/F16 ids (0-1), 32 for the legacy
// quantized ids (2-9), and 256 for everything else (the K and IQ types).
func (t Tensor) blockSize() uint64 {
	switch {
	case t.Kind < 2:
		return 1
	case t.Kind < 10:
		return 32
	default:
		return 256
	}
}

// typeSize returns the number of bytes one block of t.Kind occupies.
// Unknown kinds yield 0, which makes size report 0 for them.
// Each sum below spells out the fields of the matching ggml block struct.
// fp16 scales count as 2 bytes, and fp32 scales as 4 bytes.
func (t Tensor) typeSize() uint64 {
	blockSize := t.blockSize()

	switch t.Kind {
	case 0: // FP32
		return 4
	case 1: // FP16
		return 2
	case 2: // Q4_0
		return 2 + blockSize/2
	case 3: // Q4_1
		return 2 + 2 + blockSize/2
	case 6: // Q5_0
		return 2 + 4 + blockSize/2
	case 7: // Q5_1
		return 2 + 2 + 4 + blockSize/2
	case 8: // Q8_0
		return 2 + blockSize
	case 9: // Q8_1
		return 4 + 4 + blockSize
	case 10: // Q2_K
		return blockSize/16 + blockSize/4 + 2 + 2
	case 11: // Q3_K
		return blockSize/8 + blockSize/4 + 12 + 2
	case 12: // Q4_K
		return 2 + 2 + 12 + blockSize/2
	case 13: // Q5_K
		return 2 + 2 + 12 + blockSize/8 + blockSize/2
	case 14: // Q6_K
		return blockSize/2 + blockSize/4 + blockSize/16 + 2
	case 15: // Q8_K
		// Q8_K's scale is a 4-byte float (not fp16). It is followed by
		// blockSize int8 quants and blockSize/16 int16 block sums.
		return 4 + blockSize + 2*blockSize/16
	case 16: // IQ2_XXS
		return 2 + 2*blockSize/8
	case 17: // IQ2_XS
		return 2 + 2*blockSize/8 + blockSize/32
	case 18: // IQ3_XXS
		return 2 + 3*blockSize/8
	default:
		return 0
	}
}

// parameters returns the total element count, i.e. the product of Shape.
// An empty Shape yields 1.
func (t Tensor) parameters() uint64 {
	var count uint64 = 1
	for _, n := range t.Shape {
		count *= n
	}
	return count
}

// size returns the number of bytes the tensor's data occupies. It is
// computed as the element count times bytes-per-block, divided by
// elements-per-block.
func (t Tensor) size() uint64 {
	return t.parameters() * t.typeSize() / t.blockSize()
}
  208. type container interface {
  209. Name() string
  210. Decode(io.ReadSeeker) (model, error)
  211. }
  // Magic numbers identifying model file formats. DecodeGGML reads the
// first four bytes of a file as a little-endian uint32 and compares the
// result against these values. Written most-significant byte first, the
// first four values spell their format's tag in ASCII.
const (
	// Unversioned ggml files ("ggml").
	FILE_MAGIC_GGML = 0x67676d6c
	// Versioned ggmf files ("ggmf").
	FILE_MAGIC_GGMF = 0x67676d66
	// Versioned ggjt files ("ggjt").
	FILE_MAGIC_GGJT = 0x67676a74
	// ggla files, i.e. LoRA adapters ("ggla").
	FILE_MAGIC_GGLA = 0x67676c61

	// Versioned gguf files. The bytes "GGUF" read little-endian give
	// FILE_MAGIC_GGUF_LE; FILE_MAGIC_GGUF_BE is the same value
	// byte-swapped, and DecodeGGML decodes such files big-endian.
	FILE_MAGIC_GGUF_LE = 0x46554747
	FILE_MAGIC_GGUF_BE = 0x47475546
)
  // ErrUnsupportedFormat is returned by DecodeGGML for files whose magic
// number identifies a legacy ggml, ggmf, or ggjt model. These formats are
// recognized but not decoded.
var (
	ErrUnsupportedFormat = errors.New("unsupported model format")
)
  226. func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
  227. var magic uint32
  228. if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
  229. return nil, 0, err
  230. }
  231. var c container
  232. switch magic {
  233. case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
  234. return nil, 0, ErrUnsupportedFormat
  235. case FILE_MAGIC_GGLA:
  236. c = &containerGGLA{}
  237. case FILE_MAGIC_GGUF_LE:
  238. c = &containerGGUF{ByteOrder: binary.LittleEndian}
  239. case FILE_MAGIC_GGUF_BE:
  240. c = &containerGGUF{ByteOrder: binary.BigEndian}
  241. default:
  242. return nil, 0, errors.New("invalid file magic")
  243. }
  244. model, err := c.Decode(rs)
  245. if errors.Is(err, io.EOF) {
  246. // noop
  247. } else if err != nil {
  248. return nil, 0, err
  249. }
  250. offset, err := rs.Seek(0, io.SeekCurrent)
  251. if err != nil {
  252. return nil, 0, err
  253. }
  254. // final model type
  255. return &GGML{
  256. container: c,
  257. model: model,
  258. }, offset, nil
  259. }