gguf.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923
  1. package llm
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "log/slog"
  9. "slices"
  10. "sort"
  11. "strings"
  12. "golang.org/x/exp/maps"
  13. )
  14. type containerGGUF struct {
  15. ByteOrder binary.ByteOrder
  16. Version uint32
  17. V1 struct {
  18. NumTensor uint32
  19. NumKV uint32
  20. }
  21. V2 struct {
  22. NumTensor uint64
  23. NumKV uint64
  24. }
  25. V3 struct {
  26. NumTensor uint64
  27. NumKV uint64
  28. }
  29. maxArraySize int
  30. }
  31. func (c *containerGGUF) canCollectArray(size int) bool {
  32. return c.maxArraySize < 0 || size <= c.maxArraySize
  33. }
  34. func (c *containerGGUF) Name() string {
  35. return "gguf"
  36. }
  37. func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
  38. if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil {
  39. return nil, err
  40. }
  41. var err error
  42. switch c.Version {
  43. case 1:
  44. err = binary.Read(rs, c.ByteOrder, &c.V1)
  45. case 2:
  46. err = binary.Read(rs, c.ByteOrder, &c.V2)
  47. default:
  48. err = binary.Read(rs, c.ByteOrder, &c.V3)
  49. }
  50. if err != nil {
  51. return nil, err
  52. }
  53. model := newGGUF(c)
  54. if err := model.Decode(rs); err != nil {
  55. return nil, err
  56. }
  57. return model, nil
  58. }
  59. const (
  60. ggufTypeUint8 uint32 = iota
  61. ggufTypeInt8
  62. ggufTypeUint16
  63. ggufTypeInt16
  64. ggufTypeUint32
  65. ggufTypeInt32
  66. ggufTypeFloat32
  67. ggufTypeBool
  68. ggufTypeString
  69. ggufTypeArray
  70. ggufTypeUint64
  71. ggufTypeInt64
  72. ggufTypeFloat64
  73. )
  74. type gguf struct {
  75. *containerGGUF
  76. kv KV
  77. tensors []*Tensor
  78. offset int64
  79. parameters uint64
  80. scratch [16 << 10]byte
  81. }
  82. func newGGUF(container *containerGGUF) *gguf {
  83. return &gguf{
  84. containerGGUF: container,
  85. kv: make(KV),
  86. }
  87. }
  88. func NewGGUFV3(bo binary.ByteOrder) *gguf {
  89. return newGGUF(&containerGGUF{ByteOrder: bo, Version: 3})
  90. }
  91. func (llm *gguf) KV() KV {
  92. return llm.kv
  93. }
  94. func (llm *gguf) Tensors() Tensors {
  95. return Tensors{
  96. Items: llm.tensors,
  97. Offset: llm.offset,
  98. }
  99. }
  100. func (llm *gguf) numTensor() uint64 {
  101. switch llm.Version {
  102. case 1:
  103. return uint64(llm.V1.NumTensor)
  104. case 2:
  105. return llm.V2.NumTensor
  106. default:
  107. return llm.V3.NumTensor
  108. }
  109. }
  110. func (llm *gguf) numKV() uint64 {
  111. switch llm.Version {
  112. case 1:
  113. return uint64(llm.V1.NumKV)
  114. case 2:
  115. return llm.V2.NumKV
  116. default:
  117. return llm.V3.NumKV
  118. }
  119. }
  120. func (llm *gguf) Decode(rs io.ReadSeeker) error {
  121. // decode key-values
  122. for i := 0; uint64(i) < llm.numKV(); i++ {
  123. k, err := readGGUFString(llm, rs)
  124. if err != nil {
  125. return err
  126. }
  127. t, err := readGGUF[uint32](llm, rs)
  128. if err != nil {
  129. return err
  130. }
  131. var v any
  132. switch t {
  133. case ggufTypeUint8:
  134. v, err = readGGUF[uint8](llm, rs)
  135. case ggufTypeInt8:
  136. v, err = readGGUF[int8](llm, rs)
  137. case ggufTypeUint16:
  138. v, err = readGGUF[uint16](llm, rs)
  139. case ggufTypeInt16:
  140. v, err = readGGUF[int16](llm, rs)
  141. case ggufTypeUint32:
  142. v, err = readGGUF[uint32](llm, rs)
  143. case ggufTypeInt32:
  144. v, err = readGGUF[int32](llm, rs)
  145. case ggufTypeUint64:
  146. v, err = readGGUF[uint64](llm, rs)
  147. case ggufTypeInt64:
  148. v, err = readGGUF[int64](llm, rs)
  149. case ggufTypeFloat32:
  150. v, err = readGGUF[float32](llm, rs)
  151. case ggufTypeFloat64:
  152. v, err = readGGUF[float64](llm, rs)
  153. case ggufTypeBool:
  154. v, err = readGGUF[bool](llm, rs)
  155. case ggufTypeString:
  156. v, err = readGGUFString(llm, rs)
  157. case ggufTypeArray:
  158. v, err = readGGUFArray(llm, rs)
  159. default:
  160. return fmt.Errorf("invalid type: %d", t)
  161. }
  162. if err != nil {
  163. return err
  164. }
  165. llm.kv[k] = v
  166. }
  167. // decode tensors
  168. for range llm.numTensor() {
  169. name, err := readGGUFString(llm, rs)
  170. if err != nil {
  171. return fmt.Errorf("failed to read tensor name: %w", err)
  172. }
  173. // dims is the number of dimensions in the tensor
  174. dims, err := readGGUF[uint32](llm, rs)
  175. if err != nil {
  176. return fmt.Errorf("failed to read tensor dimensions: %w", err)
  177. }
  178. shape := []uint64{}
  179. for i := 0; uint32(i) < dims; i++ {
  180. shapeVal, err := readGGUF[uint64](llm, rs)
  181. if err != nil {
  182. return fmt.Errorf("failed to read tensor shape: %w", err)
  183. }
  184. shape = append(shape, shapeVal)
  185. }
  186. kind, err := readGGUF[uint32](llm, rs)
  187. if err != nil {
  188. return fmt.Errorf("failed to read tensor kind: %w", err)
  189. }
  190. offset, err := readGGUF[uint64](llm, rs)
  191. if err != nil {
  192. return fmt.Errorf("failed to read tensor offset: %w", err)
  193. }
  194. fmt.Println("tensor", name, shape, kind, offset)
  195. tensor := Tensor{
  196. Name: name,
  197. Kind: kind,
  198. Offset: offset,
  199. Shape: shape,
  200. }
  201. llm.tensors = append(llm.tensors, &tensor)
  202. llm.parameters += tensor.parameters()
  203. }
  204. // patch KV with parameter count
  205. llm.kv["general.parameter_count"] = llm.parameters
  206. alignment, ok := llm.kv["general.alignment"].(uint32)
  207. if !ok {
  208. alignment = 32
  209. }
  210. offset, err := rs.Seek(0, io.SeekCurrent)
  211. if err != nil {
  212. return fmt.Errorf("failed to get current offset: %w", err)
  213. }
  214. // align to next 32-byte boundary
  215. llm.offset = offset + llm.padding(offset, int64(alignment))
  216. for _, tensor := range llm.tensors {
  217. offset, err := rs.Seek(0, io.SeekCurrent)
  218. if err != nil {
  219. return fmt.Errorf("failed to get current offset: %w", err)
  220. }
  221. padding := llm.padding(offset, int64(alignment))
  222. if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
  223. return fmt.Errorf("failed to seek to init padding: %w", err)
  224. }
  225. if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
  226. return fmt.Errorf("failed to seek to tensor: %w", err)
  227. }
  228. }
  229. return nil
  230. }
  231. func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
  232. var t T
  233. err := binary.Read(r, llm.ByteOrder, &t)
  234. return t, err
  235. }
  236. func writeGGUF[V any](w io.Writer, t uint32, v V) error {
  237. if err := binary.Write(w, binary.LittleEndian, t); err != nil {
  238. return err
  239. }
  240. return binary.Write(w, binary.LittleEndian, v)
  241. }
  242. func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
  243. var length uint64
  244. if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
  245. return "", err
  246. }
  247. var b bytes.Buffer
  248. if _, err := io.CopyN(&b, r, int64(length)); err != nil {
  249. return "", err
  250. }
  251. // gguf v1 strings are null-terminated
  252. b.Truncate(b.Len() - 1)
  253. return b.String(), nil
  254. }
  255. func discardGGUFString(llm *gguf, r io.Reader) error {
  256. buf := llm.scratch[:8]
  257. _, err := io.ReadFull(r, buf)
  258. if err != nil {
  259. return err
  260. }
  261. size := int(llm.ByteOrder.Uint64(buf))
  262. for size > 0 {
  263. n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))])
  264. if err != nil {
  265. return err
  266. }
  267. size -= n
  268. }
  269. return nil
  270. }
  271. func readGGUFString(llm *gguf, r io.Reader) (string, error) {
  272. if llm.Version == 1 {
  273. return readGGUFV1String(llm, r)
  274. }
  275. buf := llm.scratch[:8]
  276. _, err := io.ReadFull(r, buf)
  277. if err != nil {
  278. return "", err
  279. }
  280. length := int(llm.ByteOrder.Uint64(buf))
  281. if length > len(llm.scratch) {
  282. buf = make([]byte, length)
  283. } else {
  284. buf = llm.scratch[:length]
  285. }
  286. clear(buf)
  287. _, err = io.ReadFull(r, buf)
  288. if err != nil {
  289. return "", err
  290. }
  291. return string(buf), nil
  292. }
  293. func writeGGUFString(w io.Writer, s string) error {
  294. if err := binary.Write(w, binary.LittleEndian, ggufTypeString); err != nil {
  295. return err
  296. }
  297. if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
  298. return err
  299. }
  300. _, err := io.Copy(w, strings.NewReader(s))
  301. return err
  302. }
  303. type array struct {
  304. size int
  305. values []any
  306. datatype uint32
  307. }
  308. func (a *array) MarshalJSON() ([]byte, error) {
  309. return json.Marshal(a.values)
  310. }
  311. func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
  312. t, err := readGGUF[uint32](llm, r)
  313. if err != nil {
  314. return nil, err
  315. }
  316. n, err := readGGUF[uint32](llm, r)
  317. if err != nil {
  318. return nil, err
  319. }
  320. a := &array{size: int(n)}
  321. if llm.canCollectArray(int(n)) {
  322. a.values = make([]any, 0, int(n))
  323. }
  324. for i := range n {
  325. var e any
  326. switch t {
  327. case ggufTypeUint8:
  328. e, err = readGGUF[uint8](llm, r)
  329. case ggufTypeInt8:
  330. e, err = readGGUF[int8](llm, r)
  331. case ggufTypeUint16:
  332. e, err = readGGUF[uint16](llm, r)
  333. case ggufTypeInt16:
  334. e, err = readGGUF[int16](llm, r)
  335. case ggufTypeUint32:
  336. e, err = readGGUF[uint32](llm, r)
  337. case ggufTypeInt32:
  338. e, err = readGGUF[int32](llm, r)
  339. case ggufTypeUint64:
  340. e, err = readGGUF[uint64](llm, r)
  341. case ggufTypeInt64:
  342. e, err = readGGUF[int64](llm, r)
  343. case ggufTypeFloat32:
  344. e, err = readGGUF[float32](llm, r)
  345. case ggufTypeFloat64:
  346. e, err = readGGUF[float64](llm, r)
  347. case ggufTypeBool:
  348. e, err = readGGUF[bool](llm, r)
  349. case ggufTypeString:
  350. e, err = readGGUFV1String(llm, r)
  351. default:
  352. return nil, fmt.Errorf("invalid array type: %d", t)
  353. }
  354. if err != nil {
  355. return nil, err
  356. }
  357. if a.values != nil {
  358. a.values[i] = e
  359. }
  360. }
  361. return a, nil
  362. }
  363. func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
  364. if llm.Version == 1 {
  365. return readGGUFV1Array(llm, r)
  366. }
  367. t, err := readGGUF[uint32](llm, r)
  368. if err != nil {
  369. return nil, err
  370. }
  371. n, err := readGGUF[uint64](llm, r)
  372. if err != nil {
  373. return nil, err
  374. }
  375. a := &array{size: int(n), datatype: t}
  376. if llm.canCollectArray(int(n)) {
  377. a.values = make([]any, int(n))
  378. }
  379. for i := range n {
  380. var e any
  381. switch t {
  382. case ggufTypeUint8:
  383. e, err = readGGUF[uint8](llm, r)
  384. case ggufTypeInt8:
  385. e, err = readGGUF[int8](llm, r)
  386. case ggufTypeUint16:
  387. e, err = readGGUF[uint16](llm, r)
  388. case ggufTypeInt16:
  389. e, err = readGGUF[int16](llm, r)
  390. case ggufTypeUint32:
  391. e, err = readGGUF[uint32](llm, r)
  392. case ggufTypeInt32:
  393. e, err = readGGUF[int32](llm, r)
  394. case ggufTypeUint64:
  395. e, err = readGGUF[uint64](llm, r)
  396. case ggufTypeInt64:
  397. e, err = readGGUF[int64](llm, r)
  398. case ggufTypeFloat32:
  399. e, err = readGGUF[float32](llm, r)
  400. case ggufTypeFloat64:
  401. e, err = readGGUF[float64](llm, r)
  402. case ggufTypeBool:
  403. e, err = readGGUF[bool](llm, r)
  404. case ggufTypeString:
  405. if a.values != nil {
  406. e, err = readGGUFString(llm, r)
  407. } else {
  408. err = discardGGUFString(llm, r)
  409. }
  410. default:
  411. return nil, fmt.Errorf("invalid array type: %d", t)
  412. }
  413. if err != nil {
  414. return nil, err
  415. }
  416. if a.values != nil {
  417. a.values[i] = e
  418. }
  419. }
  420. return a, nil
  421. }
  422. func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
  423. if err := binary.Write(w, binary.LittleEndian, ggufTypeArray); err != nil {
  424. return err
  425. }
  426. if err := binary.Write(w, binary.LittleEndian, t); err != nil {
  427. return err
  428. }
  429. if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
  430. return err
  431. }
  432. for _, e := range s {
  433. if err := binary.Write(w, binary.LittleEndian, e); err != nil {
  434. return err
  435. }
  436. }
  437. return nil
  438. }
  439. var ggufKVOrder = map[string][]string{
  440. "llama": {
  441. "general.architecture",
  442. "general.name",
  443. "llama.vocab_size",
  444. "llama.context_length",
  445. "llama.embedding_length",
  446. "llama.block_count",
  447. "llama.feed_forward_length",
  448. "llama.attention.head_count",
  449. "llama.attention.head_count_kv",
  450. "llama.attention.layer_norm_rms_epsilon",
  451. "llama.rope.freq_base",
  452. "llama.rope.dimension_count",
  453. "llama.expert_count",
  454. "llama.expert_used_count",
  455. "gemma.context_length",
  456. "gemma.embedding_length",
  457. "gemma.block_count",
  458. "gemma.feed_forward_length",
  459. "gemma.attention.head_count",
  460. "gemma.attention.head_count_kv",
  461. "gemma.attention.layer_norm_rms_epsilon",
  462. "gemma.attention.key_length",
  463. "gemma.attention.value_length",
  464. "general.file_type",
  465. "tokenizer.ggml.pre",
  466. "tokenizer.ggml.model",
  467. "tokenizer.ggml.tokens",
  468. "tokenizer.ggml.scores",
  469. "tokenizer.ggml.merges",
  470. "tokenizer.ggml.token_type",
  471. "tokenizer.ggml.bos_token_id",
  472. "tokenizer.ggml.eos_token_id",
  473. "tokenizer.ggml.unknown_token_id",
  474. "tokenizer.ggml.padding_token_id",
  475. "tokenizer.ggml.add_bos_token",
  476. "tokenizer.ggml.add_eos_token",
  477. "tokenizer.chat_template",
  478. "bert.pooling_type",
  479. },
  480. }
  481. func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
  482. switch llm.Version {
  483. case 3:
  484. llm.V3.NumTensor = uint64(len(tensors))
  485. llm.V3.NumKV = uint64(len(kv))
  486. default:
  487. return fmt.Errorf("not implemented: ggufv%d", llm.Version)
  488. }
  489. if err := binary.Write(ws, llm.ByteOrder, []byte("GGUF")); err != nil {
  490. return err
  491. }
  492. if err := binary.Write(ws, llm.ByteOrder, llm.Version); err != nil {
  493. return err
  494. }
  495. if err := binary.Write(ws, llm.ByteOrder, llm.numTensor()); err != nil {
  496. return err
  497. }
  498. if err := binary.Write(ws, llm.ByteOrder, llm.numKV()); err != nil {
  499. return err
  500. }
  501. kvCheck := make(map[string]bool)
  502. for k := range kv {
  503. kvCheck[k] = false
  504. }
  505. for _, k := range ggufKVOrder["llama"] {
  506. v, ok := kv[k]
  507. if !ok {
  508. continue
  509. }
  510. kvCheck[k] = true
  511. if err := binary.Write(ws, llm.ByteOrder, uint64(len(k))); err != nil {
  512. return err
  513. }
  514. if err := binary.Write(ws, llm.ByteOrder, []byte(k)); err != nil {
  515. return err
  516. }
  517. var err error
  518. switch v := v.(type) {
  519. case uint32:
  520. err = writeGGUF(ws, ggufTypeUint32, v)
  521. case float32:
  522. err = writeGGUF(ws, ggufTypeFloat32, v)
  523. case bool:
  524. err = writeGGUF(ws, ggufTypeBool, v)
  525. case string:
  526. err = writeGGUFString(ws, v)
  527. case []int32:
  528. err = writeGGUFArray(ws, ggufTypeInt32, v)
  529. case []uint32:
  530. err = writeGGUFArray(ws, ggufTypeUint32, v)
  531. case []float32:
  532. err = writeGGUFArray(ws, ggufTypeFloat32, v)
  533. case []string:
  534. if err := binary.Write(ws, llm.ByteOrder, ggufTypeArray); err != nil {
  535. return err
  536. }
  537. if err := binary.Write(ws, llm.ByteOrder, ggufTypeString); err != nil {
  538. return err
  539. }
  540. if err := binary.Write(ws, llm.ByteOrder, uint64(len(v))); err != nil {
  541. return err
  542. }
  543. for _, e := range v {
  544. if err := binary.Write(ws, llm.ByteOrder, uint64(len(e))); err != nil {
  545. return err
  546. }
  547. if err := binary.Write(ws, llm.ByteOrder, []byte(e)); err != nil {
  548. return err
  549. }
  550. }
  551. default:
  552. return fmt.Errorf("improper type for '%s'", k)
  553. }
  554. if err != nil {
  555. return err
  556. }
  557. }
  558. for k, v := range kvCheck {
  559. if !v {
  560. return fmt.Errorf("didn't know how to write kv %s", k)
  561. }
  562. }
  563. for _, tensor := range tensors {
  564. if err := binary.Write(ws, llm.ByteOrder, uint64(len(tensor.Name))); err != nil {
  565. return err
  566. }
  567. if err := binary.Write(ws, llm.ByteOrder, []byte(tensor.Name)); err != nil {
  568. return err
  569. }
  570. var dims int
  571. for cnt := range len(tensor.Shape) {
  572. if tensor.Shape[cnt] > 0 {
  573. dims++
  574. }
  575. }
  576. if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {
  577. return err
  578. }
  579. for i := range dims {
  580. if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil {
  581. return err
  582. }
  583. }
  584. if err := binary.Write(ws, llm.ByteOrder, tensor.Kind); err != nil {
  585. return err
  586. }
  587. if err := binary.Write(ws, llm.ByteOrder, tensor.Offset); err != nil {
  588. return err
  589. }
  590. }
  591. var alignment int64 = 32
  592. for _, tensor := range tensors {
  593. offset, err := ws.Seek(0, io.SeekCurrent)
  594. if err != nil {
  595. return err
  596. }
  597. padding := llm.padding(offset, alignment)
  598. if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
  599. return err
  600. }
  601. if _, err := tensor.WriteTo(ws); err != nil {
  602. return err
  603. }
  604. }
  605. return nil
  606. }
  607. func (gguf) padding(offset, align int64) int64 {
  608. return (align - offset%align) % align
  609. }
  610. // Reader and WriterTof
  611. type GGUFWriter struct {
  612. KV
  613. Tensors
  614. }
  615. type writeOffset struct {
  616. io.Writer
  617. offset int
  618. }
  619. func (wo *writeOffset) Write(p []byte) (int, error) {
  620. n, err := wo.Writer.Write(p)
  621. wo.offset += n
  622. return n, err
  623. }
  624. var _ io.Reader = (*GGUFWriter)(nil)
  625. var _ io.WriterTo = (*GGUFWriter)(nil)
  626. func (GGUFWriter) Read([]byte) (int, error) {
  627. panic("not implemeneted")
  628. }
  629. func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
  630. wo := &writeOffset{Writer: w}
  631. if err := binary.Write(wo, binary.LittleEndian, []byte("GGUF")); err != nil {
  632. return 0, err
  633. }
  634. if err := binary.Write(wo, binary.LittleEndian, uint32(3)); err != nil {
  635. return 0, err
  636. }
  637. if err := binary.Write(wo, binary.LittleEndian, uint64(len(gguf.Tensors.Items))); err != nil {
  638. return 0, err
  639. }
  640. if err := binary.Write(wo, binary.LittleEndian, uint64(len(gguf.KV)-1)); err != nil {
  641. return 0, err
  642. }
  643. keys := maps.Keys(gguf.KV)
  644. slices.Sort(keys)
  645. for _, key := range keys {
  646. fmt.Println(key)
  647. switch key {
  648. case "general.parameter_count":
  649. continue
  650. default:
  651. if err := ggufWriteKV(wo, key, gguf.KV[key]); err != nil {
  652. return 0, err
  653. }
  654. }
  655. }
  656. sort.Sort(gguf.Tensors)
  657. var s uint64
  658. for _, t := range gguf.Tensors.Items {
  659. t.Offset = s
  660. if err := ggufWriteTensorInfo(wo, t); err != nil {
  661. return 0, err
  662. }
  663. s += t.Size()
  664. }
  665. tensorOffset := wo.offset
  666. for _, t := range gguf.Tensors.Items {
  667. if err := ggufWriteTensor(wo, t, wo.offset); err != nil {
  668. return 0, err
  669. }
  670. }
  671. return int64(tensorOffset), nil
  672. }
  673. func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error {
  674. if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
  675. return err
  676. }
  677. if err := binary.Write(ws, binary.LittleEndian, []byte(t.Name)); err != nil {
  678. return err
  679. }
  680. if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil {
  681. return err
  682. }
  683. for i := range len(t.Shape) {
  684. if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
  685. return err
  686. }
  687. }
  688. if err := binary.Write(ws, binary.LittleEndian, t.Kind); err != nil {
  689. return err
  690. }
  691. return binary.Write(ws, binary.LittleEndian, t.Offset)
  692. }
  693. func ggufWriteTensor(ws io.Writer, t *Tensor, offset int) error {
  694. slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
  695. if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(int64(offset), 32)))); err != nil {
  696. return err
  697. }
  698. _, err := t.WriteTo(ws)
  699. return err
  700. }
  701. func ggufWriteKV(ws io.Writer, k string, v any) error {
  702. slog.Debug(k, "type", fmt.Sprintf("%T", v))
  703. if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
  704. return err
  705. }
  706. if err := binary.Write(ws, binary.LittleEndian, []byte(k)); err != nil {
  707. return err
  708. }
  709. var err error
  710. switch v := v.(type) {
  711. case uint32:
  712. err = writeGGUF(ws, ggufTypeUint32, v)
  713. case float32:
  714. err = writeGGUF(ws, ggufTypeFloat32, v)
  715. case bool:
  716. err = writeGGUF(ws, ggufTypeBool, v)
  717. case string:
  718. err = writeGGUFString(ws, v)
  719. case []int32:
  720. err = writeGGUFArray(ws, ggufTypeInt32, v)
  721. case []uint32:
  722. err = writeGGUFArray(ws, ggufTypeUint32, v)
  723. case []float32:
  724. err = writeGGUFArray(ws, ggufTypeFloat32, v)
  725. case []string:
  726. if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
  727. return err
  728. }
  729. if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
  730. return err
  731. }
  732. if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
  733. return err
  734. }
  735. for _, e := range v {
  736. if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
  737. return err
  738. }
  739. if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
  740. return err
  741. }
  742. }
  743. case *array:
  744. if v.size > 0 {
  745. switch v.values[0].(type) {
  746. case string:
  747. if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
  748. return err
  749. }
  750. if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
  751. return err
  752. }
  753. if err := binary.Write(ws, binary.LittleEndian, uint64(v.size)); err != nil {
  754. return err
  755. }
  756. for _, e := range v.values {
  757. if err := binary.Write(ws, binary.LittleEndian, uint64(len(e.(string)))); err != nil {
  758. return err
  759. }
  760. if err := binary.Write(ws, binary.LittleEndian, []byte(e.(string))); err != nil {
  761. return err
  762. }
  763. }
  764. default:
  765. err = writeGGUFArray(ws, v.datatype, v.values)
  766. }
  767. }
  768. default:
  769. return fmt.Errorf("improper type for '%s'", k)
  770. }
  771. return err
  772. }
  773. func ggufPadding(offset, align int64) int64 {
  774. // we mod twice in the case offset%align = 0
  775. return (align - offset%align) % align
  776. }