ggml.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. package fileutils
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "os"
  8. "slices"
  9. "strings"
  10. "sync"
  11. "github.com/ollama/ollama/util/bufioutil"
  12. )
  13. type GGML struct {
  14. container
  15. model
  16. }
  17. type model interface {
  18. KV() KV
  19. Tensors() *Tensors
  20. }
  21. type KV map[string]any
  22. func (kv KV) u64(key string) uint64 {
  23. switch v := kv[key].(type) {
  24. case uint64:
  25. return v
  26. case uint32:
  27. return uint64(v)
  28. case float64:
  29. return uint64(v)
  30. default:
  31. return 0
  32. }
  33. }
  34. func (kv KV) Architecture() string {
  35. if s, ok := kv["general.architecture"].(string); ok {
  36. return s
  37. }
  38. return "unknown"
  39. }
  40. func (kv KV) Kind() string {
  41. if s, ok := kv["general.type"].(string); ok {
  42. return s
  43. }
  44. return "unknown"
  45. }
  46. func (kv KV) ParameterCount() uint64 {
  47. return kv.u64("general.parameter_count")
  48. }
  49. func (kv KV) FileType() fileType {
  50. if u64 := kv.u64("general.file_type"); u64 > 0 {
  51. return fileType(uint32(u64))
  52. }
  53. return fileTypeUnknown
  54. }
  55. func (kv KV) BlockCount() uint64 {
  56. return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
  57. }
  58. func (kv KV) HeadCount() uint64 {
  59. return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
  60. }
  61. func (kv KV) HeadCountKV() uint64 {
  62. if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
  63. return headCountKV
  64. }
  65. return 1
  66. }
  67. func (kv KV) EmbeddingHeadCount() uint64 {
  68. if heads := kv.HeadCount(); heads > 0 {
  69. return kv.EmbeddingLength() / kv.HeadCount()
  70. }
  71. return 0
  72. }
  73. func (kv KV) EmbeddingHeadCountK() uint64 {
  74. if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 {
  75. return k
  76. }
  77. return kv.EmbeddingHeadCount()
  78. }
  79. func (kv KV) EmbeddingHeadCountV() uint64 {
  80. if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 {
  81. return v
  82. }
  83. return kv.EmbeddingHeadCount()
  84. }
  85. func (kv KV) GQA() uint64 {
  86. return kv.HeadCount() / kv.HeadCountKV()
  87. }
  88. func (kv KV) EmbeddingLength() uint64 {
  89. return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
  90. }
  91. func (kv KV) ContextLength() uint64 {
  92. return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
  93. }
  94. func (kv KV) ChatTemplate() string {
  95. s, _ := kv["tokenizer.chat_template"].(string)
  96. return s
  97. }
// Tensors is a model's tensor index: the flat list of tensor
// descriptors plus the file offset at which tensor data begins.
type Tensors struct {
	Items  []*Tensor
	Offset uint64

	// layers caches Items grouped by layer prefix; built lazily and
	// exactly once by Layers.
	layers     map[string]Layer
	layersOnce sync.Once
}

// Layers groups the tensors by layer, keyed by the leading name
// component (e.g. "blk.0", "mm.1", "output"). Tensor names like
// "blk.0.attn_q.weight" become layers["blk.0"]["attn_q.weight"].
// The grouping is computed on first call and cached; the same map is
// returned to every caller (callers must not mutate it concurrently).
func (ts *Tensors) Layers() map[string]Layer {
	ts.layersOnce.Do(func() {
		ts.layers = make(map[string]Layer)
		for _, t := range ts.Items {
			parts := strings.Split(t.Name, ".")
			// "blk" and "mm" components are followed by a layer number;
			// fold everything up to and including that number into the
			// layer key so e.g. "blk" and "0" become "blk.0".
			if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
				if len(parts) > index+2 {
					// blk and mm should have a number after them, join it
					parts = append(
						[]string{strings.Join(parts[:index+2], ".")},
						parts[index+2:]...)
				}
			}
			if _, ok := ts.layers[parts[0]]; !ok {
				ts.layers[parts[0]] = make(Layer)
			}
			// Remaining components form the within-layer tensor name.
			ts.layers[parts[0]][strings.Join(parts[1:], ".")] = t
		}
	})
	return ts.layers
}
  125. type Layer map[string]*Tensor
  126. func (l Layer) size() (size uint64) {
  127. for _, t := range l {
  128. size += t.Size()
  129. }
  130. return size
  131. }
  132. type Tensor struct {
  133. Name string `json:"name"`
  134. Kind uint32 `json:"kind"`
  135. Offset uint64 `json:"-"`
  136. // Shape is the number of elements in each dimension
  137. Shape []uint64 `json:"shape"`
  138. io.WriterTo `json:"-"`
  139. }
  140. func (t Tensor) block() (n int) {
  141. if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
  142. return -1
  143. }
  144. return
  145. }
  146. func (t Tensor) blockSize() uint64 {
  147. switch t.Kind {
  148. case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
  149. return 1
  150. case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
  151. return 32
  152. default: // All others
  153. return 256
  154. }
  155. }
// typeSize returns the encoded size in bytes of one quantization block
// of this tensor's data type (for unquantized types blockSize is 1, so
// this is the per-element size). Unknown kinds return 0, which makes
// Size report 0 rather than a bogus value.
//
// NOTE(review): the per-kind byte counts mirror the GGML block layouts
// (scales, mins, and packed weights) — confirm against the upstream
// ggml type definitions when adding kinds.
func (t Tensor) typeSize() uint64 {
	blockSize := t.blockSize()

	switch t.Kind {
	case 0: // FP32
		return 4
	case 1: // FP16
		return 2
	case 2: // Q4_0
		return 2 + blockSize/2
	case 3: // Q4_1
		return 2 + 2 + blockSize/2
	case 6: // Q5_0
		return 2 + 4 + blockSize/2
	case 7: // Q5_1
		return 2 + 2 + 4 + blockSize/2
	case 8: // Q8_0
		return 2 + blockSize
	case 9: // Q8_1
		return 4 + 4 + blockSize
	case 10: // Q2_K
		return blockSize/16 + blockSize/4 + 2 + 2
	case 11: // Q3_K
		return blockSize/8 + blockSize/4 + 12 + 2
	case 12: // Q4_K
		return 2 + 2 + 12 + blockSize/2
	case 13: // Q5_K
		return 2 + 2 + 12 + blockSize/8 + blockSize/2
	case 14: // Q6_K
		return blockSize/2 + blockSize/4 + blockSize/16 + 2
	case 15: // Q8_K
		return 2 + blockSize + 2*blockSize/16
	case 16: // IQ2_XXS
		return 2 + 2*blockSize/8
	case 17: // IQ2_XS
		return 2 + 2*blockSize/8 + blockSize/32
	case 18: // IQ3_XXS
		return 2 + blockSize/4 + blockSize/8
	case 19: // IQ1_S
		return 2 + blockSize/8 + blockSize/16
	case 20: // IQ4_NL
		return 2 + blockSize/2
	case 21: // IQ3_S
		return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
	case 22: // IQ2_S
		return 2 + blockSize/4 + blockSize/16
	case 23: // IQ4_XS
		return 2 + 2 + blockSize/2 + blockSize/64
	case 24: // I8
		return 1
	case 25: // I16
		return 2
	case 26: // I32
		return 4
	case 27: // I64
		return 8
	case 28: // F64
		return 8
	case 29: // IQ1_M
		return blockSize/8 + blockSize/16 + blockSize/32
	case 30: // BF16
		return 2
	default:
		return 0
	}
}
  221. func (t Tensor) parameters() uint64 {
  222. var count uint64 = 1
  223. for _, n := range t.Shape {
  224. count *= n
  225. }
  226. return count
  227. }
  228. func (t Tensor) Size() uint64 {
  229. return t.parameters() * t.typeSize() / t.blockSize()
  230. }
  231. type container interface {
  232. Name() string
  233. Decode(io.ReadSeeker) (model, error)
  234. }
  235. const (
  236. // Magic constant for `ggml` files (unversioned).
  237. FILE_MAGIC_GGML = 0x67676d6c
  238. // Magic constant for `ggml` files (versioned, ggmf).
  239. FILE_MAGIC_GGMF = 0x67676d66
  240. // Magic constant for `ggml` files (versioned, ggjt).
  241. FILE_MAGIC_GGJT = 0x67676a74
  242. // Magic constant for `ggla` files (LoRA adapter).
  243. FILE_MAGIC_GGLA = 0x67676C61
  244. // Magic constant for `gguf` files (versioned, gguf)
  245. FILE_MAGIC_GGUF_LE = 0x46554747
  246. FILE_MAGIC_GGUF_BE = 0x47475546
  247. )
  248. var ErrUnsupportedFormat = errors.New("unsupported model format")
  249. func DetectGGMLType(b []byte) string {
  250. switch binary.LittleEndian.Uint32(b[:4]) {
  251. case FILE_MAGIC_GGML:
  252. return "ggml"
  253. case FILE_MAGIC_GGMF:
  254. return "ggmf"
  255. case FILE_MAGIC_GGJT:
  256. return "ggjt"
  257. case FILE_MAGIC_GGLA:
  258. return "ggla"
  259. case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
  260. return "gguf"
  261. default:
  262. return ""
  263. }
  264. }
  265. // DecodeGGML decodes a GGML model from the given reader.
  266. //
  267. // It collects array values for arrays with a size less than or equal to
  268. // maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
  269. // the maxArraySize is negative, all arrays are collected.
  270. func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
  271. if maxArraySize == 0 {
  272. maxArraySize = 1024
  273. }
  274. rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
  275. var magic uint32
  276. if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
  277. return nil, 0, err
  278. }
  279. var c container
  280. switch magic {
  281. case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
  282. return nil, 0, ErrUnsupportedFormat
  283. case FILE_MAGIC_GGLA:
  284. c = &containerGGLA{}
  285. case FILE_MAGIC_GGUF_LE:
  286. c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
  287. case FILE_MAGIC_GGUF_BE:
  288. c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
  289. default:
  290. return nil, 0, errors.New("invalid file magic")
  291. }
  292. model, err := c.Decode(rs)
  293. if err != nil {
  294. return nil, 0, err
  295. }
  296. offset, err := rs.Seek(0, io.SeekCurrent)
  297. if err != nil {
  298. return nil, 0, err
  299. }
  300. // final model type
  301. return &GGML{
  302. container: c,
  303. model: model,
  304. }, offset, nil
  305. }
// GraphSize estimates the compute-graph working memory, in bytes,
// needed to run the model at the given context and batch sizes.
// partialOffload is the estimate when only some layers are offloaded;
// fullOffload when all are. Architectures without a case return (0, 0).
//
// NOTE(review): the per-architecture formulas and factors like 105/128
// and 9/16 appear to be empirically tuned — confirm against upstream
// memory-estimation code before adjusting.
func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
	embedding := llm.KV().EmbeddingLength()
	heads := llm.KV().HeadCount()
	headsKV := llm.KV().HeadCountKV()
	// Panics if tokenizer.ggml.tokens is absent or not *array — assumes
	// a fully decoded vocabulary is present.
	vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)

	embeddingHeads := llm.KV().EmbeddingHeadCount()
	embeddingHeadsK := llm.KV().EmbeddingHeadCountK()

	layers := llm.Tensors().Layers()

	switch llm.KV().Architecture() {
	case "llama":
		fullOffload = max(
			4*batch*(1+4*embedding+context*(1+heads)),
			4*batch*(embedding+vocab),
		)

		partialOffload = 4 * batch * embedding
		partialOffload += max(
			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)

		// Mixture-of-experts variants are detected by their layer-0
		// tensors and get their own estimates.
		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
			// mixtral 8x22b
			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
			partialOffload = max(
				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
			)
		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
			// mixtral 8x7b
			ffnGateWeight1 := ffnGateWeight.Shape[1]
			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
			partialOffload = max(
				4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
			)
		}
	case "gemma", "gemma2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
		)

		partialOffload = max(
			4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
			4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
				4*embeddingHeadsK*context*8+
				embedding*embeddingHeadsK*heads*9/16,
		)
	case "command-r":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+4*embedding+context*(1+heads)),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
		)
	case "qwen2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+2*embedding+context+context*heads),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
		)
	case "phi2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+4*embedding+context+context*heads),
		)

		partialOffload = max(
			4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
			4*batch*(2+3*embedding+context+context*heads),
		)
	case "stablelm":
		fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
		partialOffload = max(
			4*batch*(vocab+2*embedding),
			fullOffload,
		)
	case "deepseek2":
		fullOffload = max(
			4*batch*(3*embedding+vocab),
			4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
		)

		partialOffload = max(
			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
		)
	case "chatglm":
		fullOffload = 4 * batch * (embedding + vocab)
		partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
		// Models with a fused QKV bias need extra graph memory for it.
		if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
			fullOffload = max(
				fullOffload,
				4*batch*(2+
					2*embedding+
					context+
					context*heads+
					embeddingHeadsK*heads+
					qkvBias.Shape[0]),
			)

			partialOffload = max(
				partialOffload,
				4*batch*(1+
					2*embedding+
					embeddingHeadsK*heads+
					context+
					context*heads)+
					4*embeddingHeadsK*context+
					4*context*embeddingHeadsK+
					4*qkvBias.Shape[0],
			)
		}
	}

	return
}
  422. // LoadModel will load a model from disk. The model must be in the GGML format.
  423. //
  424. // It collects array values for arrays with a size less than or equal to
  425. // maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
  426. // the maxArraySize is negative, all arrays are collected.
  427. func LoadModel(model string, maxArraySize int) (*GGML, error) {
  428. if _, err := os.Stat(model); err != nil {
  429. return nil, err
  430. }
  431. f, err := os.Open(model)
  432. if err != nil {
  433. return nil, err
  434. }
  435. defer f.Close()
  436. ggml, _, err := DecodeGGML(f, maxArraySize)
  437. return ggml, err
  438. }