gguf.go 14 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3
  1. package llm
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "io"
  7. "strings"
  8. "github.com/ollama/ollama/format"
  9. )
// containerGGUF holds the fixed-size header fields that precede the GGUF
// key/value metadata: the byte order used for all binary reads and writes,
// the format version, and the tensor and key/value counts.
//
// Only one of V1, V2 and V3 is populated, selected by Version (see Decode).
// The field order inside each of them is significant: they are filled
// directly by binary.Read, which decodes fields in declaration order.
type containerGGUF struct {
	// ByteOrder is passed to every binary.Read/binary.Write on the file.
	ByteOrder binary.ByteOrder

	// Version is the GGUF format version; Decode reads it from the stream.
	Version uint32

	// V1 holds the counts for version 1 files, stored as uint32.
	V1 struct {
		NumTensor uint32
		NumKV     uint32
	}

	// V2 holds the counts for version 2 files, stored as uint64.
	V2 struct {
		NumTensor uint64
		NumKV     uint64
	}

	// V3 holds the counts for version 3 files and, in Decode, for any
	// version other than 1 or 2.
	V3 struct {
		NumTensor uint64
		NumKV     uint64
	}
}
  26. func (c *containerGGUF) Name() string {
  27. return "gguf"
  28. }
  29. func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
  30. if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil {
  31. return nil, err
  32. }
  33. var err error
  34. switch c.Version {
  35. case 1:
  36. err = binary.Read(rs, c.ByteOrder, &c.V1)
  37. case 2:
  38. err = binary.Read(rs, c.ByteOrder, &c.V2)
  39. default:
  40. err = binary.Read(rs, c.ByteOrder, &c.V3)
  41. }
  42. if err != nil {
  43. return nil, err
  44. }
  45. model := newGGUF(c)
  46. if err := model.Decode(rs); err != nil {
  47. return nil, err
  48. }
  49. return model, nil
  50. }
  51. const (
  52. _ uint32 = iota
  53. GGUFTokenNormal
  54. GGUFTokenUnknown
  55. GGUFTokenControl
  56. GGUFTokenUserDefined
  57. GGUFTokenUnused
  58. GGUFTokenByte
  59. )
  60. const (
  61. ggufTypeUint8 uint32 = iota
  62. ggufTypeInt8
  63. ggufTypeUint16
  64. ggufTypeInt16
  65. ggufTypeUint32
  66. ggufTypeInt32
  67. ggufTypeFloat32
  68. ggufTypeBool
  69. ggufTypeString
  70. ggufTypeArray
  71. ggufTypeUint64
  72. ggufTypeInt64
  73. ggufTypeFloat64
  74. )
// gguf is a GGUF model: decoded by Decode or assembled for Encode. The
// embedded containerGGUF supplies the byte order, version and header counts.
type gguf struct {
	*containerGGUF

	// KV holds the metadata key/value pairs decoded by Decode.
	KV
	// Tensors holds the tensor descriptors decoded by Decode, in file order.
	Tensors []Tensor
	// parameters is the running sum of tensor.parameters() over Tensors;
	// ModelType reports it.
	parameters uint64
}
  81. func newGGUF(container *containerGGUF) *gguf {
  82. return &gguf{
  83. containerGGUF: container,
  84. KV: make(KV),
  85. }
  86. }
  87. func NewGGUFV3(bo binary.ByteOrder) *gguf {
  88. return newGGUF(&containerGGUF{ByteOrder: bo, Version: 3})
  89. }
  90. func (llm *gguf) numTensor() uint64 {
  91. switch llm.Version {
  92. case 1:
  93. return uint64(llm.V1.NumTensor)
  94. case 2:
  95. return llm.V2.NumTensor
  96. default:
  97. return llm.V3.NumTensor
  98. }
  99. }
  100. func (llm *gguf) numKV() uint64 {
  101. switch llm.Version {
  102. case 1:
  103. return uint64(llm.V1.NumKV)
  104. case 2:
  105. return llm.V2.NumKV
  106. default:
  107. return llm.V3.NumKV
  108. }
  109. }
  110. func (llm *gguf) ModelFamily() string {
  111. if t, ok := llm.KV["general.architecture"].(string); ok {
  112. return t
  113. }
  114. return "unknown"
  115. }
  116. func (llm *gguf) ModelType() string {
  117. if llm.parameters > 0 {
  118. return format.HumanNumber(llm.parameters)
  119. }
  120. return "unknown"
  121. }
  122. func (llm *gguf) FileType() string {
  123. if t, ok := llm.KV["general.file_type"].(uint32); ok {
  124. return fileType(t)
  125. }
  126. return "unknown"
  127. }
  128. func (llm *gguf) Decode(rs io.ReadSeeker) error {
  129. // decode key-values
  130. for i := 0; uint64(i) < llm.numKV(); i++ {
  131. k, err := readGGUFString(llm, rs)
  132. if err != nil {
  133. return err
  134. }
  135. t, err := readGGUF[uint32](llm, rs)
  136. if err != nil {
  137. return err
  138. }
  139. var v any
  140. switch t {
  141. case ggufTypeUint8:
  142. v, err = readGGUF[uint8](llm, rs)
  143. case ggufTypeInt8:
  144. v, err = readGGUF[int8](llm, rs)
  145. case ggufTypeUint16:
  146. v, err = readGGUF[uint16](llm, rs)
  147. case ggufTypeInt16:
  148. v, err = readGGUF[int16](llm, rs)
  149. case ggufTypeUint32:
  150. v, err = readGGUF[uint32](llm, rs)
  151. case ggufTypeInt32:
  152. v, err = readGGUF[int32](llm, rs)
  153. case ggufTypeUint64:
  154. v, err = readGGUF[uint64](llm, rs)
  155. case ggufTypeInt64:
  156. v, err = readGGUF[int64](llm, rs)
  157. case ggufTypeFloat32:
  158. v, err = readGGUF[float32](llm, rs)
  159. case ggufTypeFloat64:
  160. v, err = readGGUF[float64](llm, rs)
  161. case ggufTypeBool:
  162. v, err = readGGUF[bool](llm, rs)
  163. case ggufTypeString:
  164. v, err = readGGUFString(llm, rs)
  165. case ggufTypeArray:
  166. v, err = readGGUFArray(llm, rs)
  167. default:
  168. return fmt.Errorf("invalid type: %d", t)
  169. }
  170. if err != nil {
  171. return err
  172. }
  173. llm.KV[k] = v
  174. }
  175. // decode tensors
  176. for i := 0; uint64(i) < llm.numTensor(); i++ {
  177. name, err := readGGUFString(llm, rs)
  178. if err != nil {
  179. return err
  180. }
  181. // dims is the number of dimensions in the tensor
  182. dims, err := readGGUF[uint32](llm, rs)
  183. if err != nil {
  184. return err
  185. }
  186. shape := [4]uint64{1, 1, 1, 1}
  187. for i := 0; uint32(i) < dims; i++ {
  188. shape[i], err = readGGUF[uint64](llm, rs)
  189. if err != nil {
  190. return err
  191. }
  192. }
  193. kind, err := readGGUF[uint32](llm, rs)
  194. if err != nil {
  195. return err
  196. }
  197. offset, err := readGGUF[uint64](llm, rs)
  198. if err != nil {
  199. return err
  200. }
  201. tensor := Tensor{
  202. Name: name,
  203. Kind: kind,
  204. Offset: offset,
  205. Shape: shape[:],
  206. }
  207. llm.Tensors = append(llm.Tensors, tensor)
  208. llm.parameters += tensor.parameters()
  209. }
  210. alignment, ok := llm.KV["general.alignment"].(uint32)
  211. if !ok {
  212. alignment = 32
  213. }
  214. offset, err := rs.Seek(0, io.SeekCurrent)
  215. if err != nil {
  216. return err
  217. }
  218. padding := llm.padding(offset, int64(alignment))
  219. if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
  220. return err
  221. }
  222. for _, tensor := range llm.Tensors {
  223. padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
  224. if _, err := rs.Seek(padded, io.SeekCurrent); err != nil {
  225. return err
  226. }
  227. }
  228. return nil
  229. }
  230. func (llm *gguf) NumLayers() uint32 {
  231. value, exists := llm.KV[fmt.Sprintf("%s.block_count", llm.ModelFamily())]
  232. if !exists {
  233. return 0
  234. }
  235. return value.(uint32)
  236. }
  237. func (llm *gguf) NumHead() uint32 {
  238. value, exists := llm.KV[fmt.Sprintf("%s.attention.head_count", llm.ModelFamily())]
  239. if !exists {
  240. return 0
  241. }
  242. return value.(uint32)
  243. }
  244. func (llm *gguf) NumEmbed() uint32 {
  245. value, exists := llm.KV[fmt.Sprintf("%s.embedding_length", llm.ModelFamily())]
  246. if !exists {
  247. return 0
  248. }
  249. return value.(uint32)
  250. }
  251. func (llm *gguf) NumHeadKv() uint32 {
  252. value, exists := llm.KV[fmt.Sprintf("%s.attention.head_count_kv", llm.ModelFamily())]
  253. if !exists {
  254. return 0
  255. }
  256. return value.(uint32)
  257. }
  258. func (llm *gguf) NumCtx() uint32 {
  259. value, exists := llm.KV[fmt.Sprintf("%s.context_length", llm.ModelFamily())]
  260. if !exists {
  261. return 0
  262. }
  263. return value.(uint32)
  264. }
  265. func (llm *gguf) NumGQA() uint32 {
  266. numHeadKv := llm.NumHeadKv()
  267. if numHeadKv == 0 {
  268. return 0
  269. }
  270. return llm.NumHead() / numHeadKv
  271. }
  272. func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
  273. var t T
  274. err := binary.Read(r, llm.ByteOrder, &t)
  275. return t, err
  276. }
  277. func writeGGUF[V any](llm *gguf, w io.Writer, t uint32, v V) error {
  278. if err := binary.Write(w, llm.ByteOrder, t); err != nil {
  279. return err
  280. }
  281. return binary.Write(w, llm.ByteOrder, v)
  282. }
  283. func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
  284. var length uint64
  285. if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
  286. return "", err
  287. }
  288. var b bytes.Buffer
  289. if _, err := io.CopyN(&b, r, int64(length)); err != nil {
  290. return "", err
  291. }
  292. // gguf v1 strings are null-terminated
  293. b.Truncate(b.Len() - 1)
  294. return b.String(), nil
  295. }
  296. func readGGUFString(llm *gguf, r io.Reader) (string, error) {
  297. if llm.Version == 1 {
  298. return readGGUFV1String(llm, r)
  299. }
  300. var length uint64
  301. if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
  302. return "", err
  303. }
  304. var b bytes.Buffer
  305. if _, err := io.CopyN(&b, r, int64(length)); err != nil {
  306. return "", err
  307. }
  308. return b.String(), nil
  309. }
  310. func writeGGUFString(llm *gguf, w io.Writer, s string) error {
  311. if err := binary.Write(w, llm.ByteOrder, ggufTypeString); err != nil {
  312. return err
  313. }
  314. if err := binary.Write(w, llm.ByteOrder, uint64(len(s))); err != nil {
  315. return err
  316. }
  317. _, err := io.Copy(w, strings.NewReader(s))
  318. return err
  319. }
  320. func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
  321. t, err := readGGUF[uint32](llm, r)
  322. if err != nil {
  323. return nil, err
  324. }
  325. n, err := readGGUF[uint32](llm, r)
  326. if err != nil {
  327. return nil, err
  328. }
  329. for i := 0; uint32(i) < n; i++ {
  330. var e any
  331. switch t {
  332. case ggufTypeUint8:
  333. e, err = readGGUF[uint8](llm, r)
  334. case ggufTypeInt8:
  335. e, err = readGGUF[int8](llm, r)
  336. case ggufTypeUint16:
  337. e, err = readGGUF[uint16](llm, r)
  338. case ggufTypeInt16:
  339. e, err = readGGUF[int16](llm, r)
  340. case ggufTypeUint32:
  341. e, err = readGGUF[uint32](llm, r)
  342. case ggufTypeInt32:
  343. e, err = readGGUF[int32](llm, r)
  344. case ggufTypeUint64:
  345. e, err = readGGUF[uint64](llm, r)
  346. case ggufTypeInt64:
  347. e, err = readGGUF[int64](llm, r)
  348. case ggufTypeFloat32:
  349. e, err = readGGUF[float32](llm, r)
  350. case ggufTypeFloat64:
  351. e, err = readGGUF[float64](llm, r)
  352. case ggufTypeBool:
  353. e, err = readGGUF[bool](llm, r)
  354. case ggufTypeString:
  355. e, err = readGGUFV1String(llm, r)
  356. default:
  357. return nil, fmt.Errorf("invalid array type: %d", t)
  358. }
  359. if err != nil {
  360. return nil, err
  361. }
  362. a = append(a, e)
  363. }
  364. return
  365. }
  366. func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
  367. if llm.Version == 1 {
  368. return readGGUFV1Array(llm, r)
  369. }
  370. t, err := readGGUF[uint32](llm, r)
  371. if err != nil {
  372. return nil, err
  373. }
  374. n, err := readGGUF[uint64](llm, r)
  375. if err != nil {
  376. return nil, err
  377. }
  378. for i := 0; uint64(i) < n; i++ {
  379. var e any
  380. switch t {
  381. case ggufTypeUint8:
  382. e, err = readGGUF[uint8](llm, r)
  383. case ggufTypeInt8:
  384. e, err = readGGUF[int8](llm, r)
  385. case ggufTypeUint16:
  386. e, err = readGGUF[uint16](llm, r)
  387. case ggufTypeInt16:
  388. e, err = readGGUF[int16](llm, r)
  389. case ggufTypeUint32:
  390. e, err = readGGUF[uint32](llm, r)
  391. case ggufTypeInt32:
  392. e, err = readGGUF[int32](llm, r)
  393. case ggufTypeUint64:
  394. e, err = readGGUF[uint64](llm, r)
  395. case ggufTypeInt64:
  396. e, err = readGGUF[int64](llm, r)
  397. case ggufTypeFloat32:
  398. e, err = readGGUF[float32](llm, r)
  399. case ggufTypeFloat64:
  400. e, err = readGGUF[float64](llm, r)
  401. case ggufTypeBool:
  402. e, err = readGGUF[bool](llm, r)
  403. case ggufTypeString:
  404. e, err = readGGUFString(llm, r)
  405. default:
  406. return nil, fmt.Errorf("invalid array type: %d", t)
  407. }
  408. if err != nil {
  409. return nil, err
  410. }
  411. a = append(a, e)
  412. }
  413. return
  414. }
  415. func writeGGUFArray[S ~[]E, E any](llm *gguf, w io.Writer, t uint32, s S) error {
  416. if err := binary.Write(w, llm.ByteOrder, ggufTypeArray); err != nil {
  417. return err
  418. }
  419. if err := binary.Write(w, llm.ByteOrder, t); err != nil {
  420. return err
  421. }
  422. if err := binary.Write(w, llm.ByteOrder, uint64(len(s))); err != nil {
  423. return err
  424. }
  425. for _, e := range s {
  426. if err := binary.Write(w, llm.ByteOrder, e); err != nil {
  427. return err
  428. }
  429. }
  430. return nil
  431. }
// ggufKVOrder lists, per architecture key, the metadata keys that Encode
// writes and the order in which it writes them. Encode consults only the
// "llama" entry. It skips listed keys that are absent from its kv argument.
// Keys that are not listed here are not written at all.
//
// NOTE(review): the "llama" list also contains gemma.* keys, presumably so
// the same ordering can serve gemma conversions — confirm.
var ggufKVOrder = map[string][]string{
	"llama": {
		"general.architecture",
		"general.name",
		"llama.context_length",
		"llama.embedding_length",
		"llama.block_count",
		"llama.feed_forward_length",
		"llama.rope.dimension_count",
		"llama.attention.head_count",
		"llama.attention.head_count_kv",
		"llama.attention.layer_norm_rms_epsilon",
		"llama.rope.freq_base",
		"gemma.context_length",
		"gemma.embedding_length",
		"gemma.block_count",
		"gemma.feed_forward_length",
		"gemma.attention.head_count",
		"gemma.attention.head_count_kv",
		"gemma.attention.layer_norm_rms_epsilon",
		"gemma.attention.key_length",
		"gemma.attention.value_length",
		"general.file_type",
		"tokenizer.ggml.model",
		"tokenizer.ggml.tokens",
		"tokenizer.ggml.scores",
		"tokenizer.ggml.token_type",
		"tokenizer.ggml.bos_token_id",
		"tokenizer.ggml.eos_token_id",
		"tokenizer.ggml.unknown_token_id",
		"tokenizer.ggml.padding_token_id",
		"tokenizer.ggml.add_bos_token",
		"tokenizer.ggml.add_eos_token",
		"tokenizer.chat_template",
	},
}
  468. func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
  469. switch llm.Version {
  470. case 3:
  471. llm.V3.NumTensor = uint64(len(tensors))
  472. llm.V3.NumKV = uint64(len(kv))
  473. default:
  474. return fmt.Errorf("not implemented: ggufv%d", llm.Version)
  475. }
  476. if err := binary.Write(ws, llm.ByteOrder, []byte("GGUF")); err != nil {
  477. return err
  478. }
  479. if err := binary.Write(ws, llm.ByteOrder, llm.Version); err != nil {
  480. return err
  481. }
  482. if err := binary.Write(ws, llm.ByteOrder, llm.numTensor()); err != nil {
  483. return err
  484. }
  485. if err := binary.Write(ws, llm.ByteOrder, llm.numKV()); err != nil {
  486. return err
  487. }
  488. for _, k := range ggufKVOrder["llama"] {
  489. v, ok := kv[k]
  490. if !ok {
  491. continue
  492. }
  493. if err := binary.Write(ws, llm.ByteOrder, uint64(len(k))); err != nil {
  494. return err
  495. }
  496. if err := binary.Write(ws, llm.ByteOrder, []byte(k)); err != nil {
  497. return err
  498. }
  499. var err error
  500. switch v := v.(type) {
  501. case uint32:
  502. err = writeGGUF(llm, ws, ggufTypeUint32, v)
  503. case float32:
  504. err = writeGGUF(llm, ws, ggufTypeFloat32, v)
  505. case bool:
  506. err = writeGGUF(llm, ws, ggufTypeBool, v)
  507. case string:
  508. err = writeGGUFString(llm, ws, v)
  509. case []int32:
  510. err = writeGGUFArray(llm, ws, ggufTypeInt32, v)
  511. case []uint32:
  512. err = writeGGUFArray(llm, ws, ggufTypeUint32, v)
  513. case []float32:
  514. err = writeGGUFArray(llm, ws, ggufTypeFloat32, v)
  515. case []string:
  516. if err := binary.Write(ws, llm.ByteOrder, ggufTypeArray); err != nil {
  517. return err
  518. }
  519. if err := binary.Write(ws, llm.ByteOrder, ggufTypeString); err != nil {
  520. return err
  521. }
  522. if err := binary.Write(ws, llm.ByteOrder, uint64(len(v))); err != nil {
  523. return err
  524. }
  525. for _, e := range v {
  526. if err := binary.Write(ws, llm.ByteOrder, uint64(len(e))); err != nil {
  527. return err
  528. }
  529. if err := binary.Write(ws, llm.ByteOrder, []byte(e)); err != nil {
  530. return err
  531. }
  532. }
  533. }
  534. if err != nil {
  535. return err
  536. }
  537. }
  538. for _, tensor := range tensors {
  539. if err := binary.Write(ws, llm.ByteOrder, uint64(len(tensor.Name))); err != nil {
  540. return err
  541. }
  542. if err := binary.Write(ws, llm.ByteOrder, []byte(tensor.Name)); err != nil {
  543. return err
  544. }
  545. dims := 1
  546. if tensor.Shape[1] > 0 {
  547. dims = 2
  548. }
  549. if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {
  550. return err
  551. }
  552. for i := 0; i < dims; i++ {
  553. if err := binary.Write(ws, llm.ByteOrder, uint64(tensor.Shape[dims-1-i])); err != nil {
  554. return err
  555. }
  556. }
  557. if err := binary.Write(ws, llm.ByteOrder, tensor.Kind); err != nil {
  558. return err
  559. }
  560. if err := binary.Write(ws, llm.ByteOrder, tensor.Offset); err != nil {
  561. return err
  562. }
  563. }
  564. offset, err := ws.Seek(0, io.SeekCurrent)
  565. if err != nil {
  566. return err
  567. }
  568. padding := llm.padding(offset, 32)
  569. if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding-offset))); err != nil {
  570. return err
  571. }
  572. for _, tensor := range tensors {
  573. if _, err := tensor.WriteTo(ws); err != nil {
  574. return err
  575. }
  576. offset, err := ws.Seek(0, io.SeekCurrent)
  577. if err != nil {
  578. return err
  579. }
  580. padding := llm.padding(offset, 32)
  581. if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding-offset))); err != nil {
  582. return err
  583. }
  584. }
  585. return nil
  586. }
  587. func (gguf) padding(offset, align int64) int64 {
  588. return (offset + align - 1) / align * align
  589. }