gguf.go 19 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922
  1. package llm
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "log/slog"
  9. "slices"
  10. "sort"
  11. "strings"
  12. "golang.org/x/exp/maps"
  13. )
// containerGGUF holds the header of a GGUF file: the byte order used to
// read it, the format version, and the tensor / key-value counts, whose
// width depends on the version.
type containerGGUF struct {
	// ByteOrder is the byte order used for every header and metadata read.
	ByteOrder binary.ByteOrder

	// Version is read from the stream by Decode and selects which of
	// V1, V2 or V3 is populated.
	Version uint32

	// V1 stores its counts as 32-bit values.
	V1 struct {
		NumTensor uint32
		NumKV     uint32
	}

	// V2 stores its counts as 64-bit values.
	V2 struct {
		NumTensor uint64
		NumKV     uint64
	}

	// V3 has the same layout as V2; Decode also uses it for any version
	// other than 1 or 2.
	V3 struct {
		NumTensor uint64
		NumKV     uint64
	}

	// maxArraySize is the largest array whose elements are kept in memory;
	// a negative value means no limit (see canCollectArray).
	maxArraySize int
}
  31. func (c *containerGGUF) canCollectArray(size int) bool {
  32. return c.maxArraySize < 0 || size <= c.maxArraySize
  33. }
  34. func (c *containerGGUF) Name() string {
  35. return "gguf"
  36. }
  37. func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
  38. if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil {
  39. return nil, err
  40. }
  41. var err error
  42. switch c.Version {
  43. case 1:
  44. err = binary.Read(rs, c.ByteOrder, &c.V1)
  45. case 2:
  46. err = binary.Read(rs, c.ByteOrder, &c.V2)
  47. default:
  48. err = binary.Read(rs, c.ByteOrder, &c.V3)
  49. }
  50. if err != nil {
  51. return nil, err
  52. }
  53. model := newGGUF(c)
  54. if err := model.Decode(rs); err != nil {
  55. return nil, err
  56. }
  57. return model, nil
  58. }
// Type tags that precede each metadata value (and each array's element
// type) in the file. The numeric values are assigned by iota starting at 0,
// so the order of this list must not change.
const (
	ggufTypeUint8 uint32 = iota
	ggufTypeInt8
	ggufTypeUint16
	ggufTypeInt16
	ggufTypeUint32
	ggufTypeInt32
	ggufTypeFloat32
	ggufTypeBool
	ggufTypeString
	ggufTypeArray
	ggufTypeUint64
	ggufTypeInt64
	ggufTypeFloat64
)
// gguf is a decoded GGUF model: the container header plus the metadata and
// tensor descriptors read by Decode.
type gguf struct {
	*containerGGUF

	// kv holds the decoded metadata, keyed by name.
	kv KV

	// tensors holds one descriptor per tensor, in file order.
	tensors []*Tensor

	// offset is set by Decode to the aligned position following the tensor
	// descriptors.
	offset int64

	// parameters accumulates tensor.parameters() over all decoded tensors.
	parameters uint64

	// scratch is a reusable buffer for reading length prefixes and short
	// strings without allocating.
	scratch [16 << 10]byte
}
  82. func newGGUF(container *containerGGUF) *gguf {
  83. return &gguf{
  84. containerGGUF: container,
  85. kv: make(KV),
  86. }
  87. }
  88. func NewGGUFV3(bo binary.ByteOrder) *gguf {
  89. return newGGUF(&containerGGUF{ByteOrder: bo, Version: 3})
  90. }
  91. func (llm *gguf) KV() KV {
  92. return llm.kv
  93. }
  94. func (llm *gguf) Tensors() Tensors {
  95. return Tensors{
  96. Items: llm.tensors,
  97. Offset: llm.offset,
  98. }
  99. }
  100. func (llm *gguf) numTensor() uint64 {
  101. switch llm.Version {
  102. case 1:
  103. return uint64(llm.V1.NumTensor)
  104. case 2:
  105. return llm.V2.NumTensor
  106. default:
  107. return llm.V3.NumTensor
  108. }
  109. }
  110. func (llm *gguf) numKV() uint64 {
  111. switch llm.Version {
  112. case 1:
  113. return uint64(llm.V1.NumKV)
  114. case 2:
  115. return llm.V2.NumKV
  116. default:
  117. return llm.V3.NumKV
  118. }
  119. }
  120. func (llm *gguf) Decode(rs io.ReadSeeker) error {
  121. // decode key-values
  122. for i := 0; uint64(i) < llm.numKV(); i++ {
  123. k, err := readGGUFString(llm, rs)
  124. if err != nil {
  125. return err
  126. }
  127. t, err := readGGUF[uint32](llm, rs)
  128. if err != nil {
  129. return err
  130. }
  131. var v any
  132. switch t {
  133. case ggufTypeUint8:
  134. v, err = readGGUF[uint8](llm, rs)
  135. case ggufTypeInt8:
  136. v, err = readGGUF[int8](llm, rs)
  137. case ggufTypeUint16:
  138. v, err = readGGUF[uint16](llm, rs)
  139. case ggufTypeInt16:
  140. v, err = readGGUF[int16](llm, rs)
  141. case ggufTypeUint32:
  142. v, err = readGGUF[uint32](llm, rs)
  143. case ggufTypeInt32:
  144. v, err = readGGUF[int32](llm, rs)
  145. case ggufTypeUint64:
  146. v, err = readGGUF[uint64](llm, rs)
  147. case ggufTypeInt64:
  148. v, err = readGGUF[int64](llm, rs)
  149. case ggufTypeFloat32:
  150. v, err = readGGUF[float32](llm, rs)
  151. case ggufTypeFloat64:
  152. v, err = readGGUF[float64](llm, rs)
  153. case ggufTypeBool:
  154. v, err = readGGUF[bool](llm, rs)
  155. case ggufTypeString:
  156. v, err = readGGUFString(llm, rs)
  157. case ggufTypeArray:
  158. v, err = readGGUFArray(llm, rs)
  159. default:
  160. return fmt.Errorf("invalid type: %d", t)
  161. }
  162. if err != nil {
  163. return err
  164. }
  165. llm.kv[k] = v
  166. }
  167. // decode tensors
  168. for range llm.numTensor() {
  169. name, err := readGGUFString(llm, rs)
  170. if err != nil {
  171. return fmt.Errorf("failed to read tensor name: %w", err)
  172. }
  173. // dims is the number of dimensions in the tensor
  174. dims, err := readGGUF[uint32](llm, rs)
  175. if err != nil {
  176. return fmt.Errorf("failed to read tensor dimensions: %w", err)
  177. }
  178. shape := []uint64{}
  179. for i := 0; uint32(i) < dims; i++ {
  180. shapeVal, err := readGGUF[uint64](llm, rs)
  181. if err != nil {
  182. return fmt.Errorf("failed to read tensor shape: %w", err)
  183. }
  184. shape = append(shape, shapeVal)
  185. }
  186. kind, err := readGGUF[uint32](llm, rs)
  187. if err != nil {
  188. return fmt.Errorf("failed to read tensor kind: %w", err)
  189. }
  190. offset, err := readGGUF[uint64](llm, rs)
  191. if err != nil {
  192. return fmt.Errorf("failed to read tensor offset: %w", err)
  193. }
  194. tensor := Tensor{
  195. Name: name,
  196. Kind: kind,
  197. Offset: offset,
  198. Shape: shape,
  199. }
  200. llm.tensors = append(llm.tensors, &tensor)
  201. llm.parameters += tensor.parameters()
  202. }
  203. // patch KV with parameter count
  204. llm.kv["general.parameter_count"] = llm.parameters
  205. alignment, ok := llm.kv["general.alignment"].(uint32)
  206. if !ok {
  207. alignment = 32
  208. }
  209. offset, err := rs.Seek(0, io.SeekCurrent)
  210. if err != nil {
  211. return fmt.Errorf("failed to get current offset: %w", err)
  212. }
  213. // align to next 32-byte boundary
  214. llm.offset = offset + llm.padding(offset, int64(alignment))
  215. for _, tensor := range llm.tensors {
  216. offset, err := rs.Seek(0, io.SeekCurrent)
  217. if err != nil {
  218. return fmt.Errorf("failed to get current offset: %w", err)
  219. }
  220. padding := llm.padding(offset, int64(alignment))
  221. if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
  222. return fmt.Errorf("failed to seek to init padding: %w", err)
  223. }
  224. if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
  225. return fmt.Errorf("failed to seek to tensor: %w", err)
  226. }
  227. }
  228. return nil
  229. }
  230. func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
  231. var t T
  232. err := binary.Read(r, llm.ByteOrder, &t)
  233. return t, err
  234. }
  235. func writeGGUF[V any](w io.Writer, t uint32, v V) error {
  236. if err := binary.Write(w, binary.LittleEndian, t); err != nil {
  237. return err
  238. }
  239. return binary.Write(w, binary.LittleEndian, v)
  240. }
  241. func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
  242. var length uint64
  243. if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
  244. return "", err
  245. }
  246. var b bytes.Buffer
  247. if _, err := io.CopyN(&b, r, int64(length)); err != nil {
  248. return "", err
  249. }
  250. // gguf v1 strings are null-terminated
  251. b.Truncate(b.Len() - 1)
  252. return b.String(), nil
  253. }
  254. func discardGGUFString(llm *gguf, r io.Reader) error {
  255. buf := llm.scratch[:8]
  256. _, err := io.ReadFull(r, buf)
  257. if err != nil {
  258. return err
  259. }
  260. size := int(llm.ByteOrder.Uint64(buf))
  261. for size > 0 {
  262. n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))])
  263. if err != nil {
  264. return err
  265. }
  266. size -= n
  267. }
  268. return nil
  269. }
  270. func readGGUFString(llm *gguf, r io.Reader) (string, error) {
  271. if llm.Version == 1 {
  272. return readGGUFV1String(llm, r)
  273. }
  274. buf := llm.scratch[:8]
  275. _, err := io.ReadFull(r, buf)
  276. if err != nil {
  277. return "", err
  278. }
  279. length := int(llm.ByteOrder.Uint64(buf))
  280. if length > len(llm.scratch) {
  281. buf = make([]byte, length)
  282. } else {
  283. buf = llm.scratch[:length]
  284. }
  285. clear(buf)
  286. _, err = io.ReadFull(r, buf)
  287. if err != nil {
  288. return "", err
  289. }
  290. return string(buf), nil
  291. }
  292. func writeGGUFString(w io.Writer, s string) error {
  293. if err := binary.Write(w, binary.LittleEndian, ggufTypeString); err != nil {
  294. return err
  295. }
  296. if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
  297. return err
  298. }
  299. _, err := io.Copy(w, strings.NewReader(s))
  300. return err
  301. }
// array is a decoded metadata array value.
type array struct {
	// size is the element count read from the file.
	size int

	// values holds the decoded elements; it is nil when the array was too
	// large to collect (see canCollectArray).
	values []any

	// datatype is the element type tag (one of the ggufType* constants).
	datatype uint32
}
  307. func (a *array) MarshalJSON() ([]byte, error) {
  308. return json.Marshal(a.values)
  309. }
  310. func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
  311. t, err := readGGUF[uint32](llm, r)
  312. if err != nil {
  313. return nil, err
  314. }
  315. n, err := readGGUF[uint32](llm, r)
  316. if err != nil {
  317. return nil, err
  318. }
  319. a := &array{size: int(n)}
  320. if llm.canCollectArray(int(n)) {
  321. a.values = make([]any, 0, int(n))
  322. }
  323. for i := range n {
  324. var e any
  325. switch t {
  326. case ggufTypeUint8:
  327. e, err = readGGUF[uint8](llm, r)
  328. case ggufTypeInt8:
  329. e, err = readGGUF[int8](llm, r)
  330. case ggufTypeUint16:
  331. e, err = readGGUF[uint16](llm, r)
  332. case ggufTypeInt16:
  333. e, err = readGGUF[int16](llm, r)
  334. case ggufTypeUint32:
  335. e, err = readGGUF[uint32](llm, r)
  336. case ggufTypeInt32:
  337. e, err = readGGUF[int32](llm, r)
  338. case ggufTypeUint64:
  339. e, err = readGGUF[uint64](llm, r)
  340. case ggufTypeInt64:
  341. e, err = readGGUF[int64](llm, r)
  342. case ggufTypeFloat32:
  343. e, err = readGGUF[float32](llm, r)
  344. case ggufTypeFloat64:
  345. e, err = readGGUF[float64](llm, r)
  346. case ggufTypeBool:
  347. e, err = readGGUF[bool](llm, r)
  348. case ggufTypeString:
  349. e, err = readGGUFV1String(llm, r)
  350. default:
  351. return nil, fmt.Errorf("invalid array type: %d", t)
  352. }
  353. if err != nil {
  354. return nil, err
  355. }
  356. if a.values != nil {
  357. a.values[i] = e
  358. }
  359. }
  360. return a, nil
  361. }
  362. func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
  363. if llm.Version == 1 {
  364. return readGGUFV1Array(llm, r)
  365. }
  366. t, err := readGGUF[uint32](llm, r)
  367. if err != nil {
  368. return nil, err
  369. }
  370. n, err := readGGUF[uint64](llm, r)
  371. if err != nil {
  372. return nil, err
  373. }
  374. a := &array{size: int(n), datatype: t}
  375. if llm.canCollectArray(int(n)) {
  376. a.values = make([]any, int(n))
  377. }
  378. for i := range n {
  379. var e any
  380. switch t {
  381. case ggufTypeUint8:
  382. e, err = readGGUF[uint8](llm, r)
  383. case ggufTypeInt8:
  384. e, err = readGGUF[int8](llm, r)
  385. case ggufTypeUint16:
  386. e, err = readGGUF[uint16](llm, r)
  387. case ggufTypeInt16:
  388. e, err = readGGUF[int16](llm, r)
  389. case ggufTypeUint32:
  390. e, err = readGGUF[uint32](llm, r)
  391. case ggufTypeInt32:
  392. e, err = readGGUF[int32](llm, r)
  393. case ggufTypeUint64:
  394. e, err = readGGUF[uint64](llm, r)
  395. case ggufTypeInt64:
  396. e, err = readGGUF[int64](llm, r)
  397. case ggufTypeFloat32:
  398. e, err = readGGUF[float32](llm, r)
  399. case ggufTypeFloat64:
  400. e, err = readGGUF[float64](llm, r)
  401. case ggufTypeBool:
  402. e, err = readGGUF[bool](llm, r)
  403. case ggufTypeString:
  404. if a.values != nil {
  405. e, err = readGGUFString(llm, r)
  406. } else {
  407. err = discardGGUFString(llm, r)
  408. }
  409. default:
  410. return nil, fmt.Errorf("invalid array type: %d", t)
  411. }
  412. if err != nil {
  413. return nil, err
  414. }
  415. if a.values != nil {
  416. a.values[i] = e
  417. }
  418. }
  419. return a, nil
  420. }
  421. func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
  422. if err := binary.Write(w, binary.LittleEndian, ggufTypeArray); err != nil {
  423. return err
  424. }
  425. if err := binary.Write(w, binary.LittleEndian, t); err != nil {
  426. return err
  427. }
  428. if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
  429. return err
  430. }
  431. for _, e := range s {
  432. if err := binary.Write(w, binary.LittleEndian, e); err != nil {
  433. return err
  434. }
  435. }
  436. return nil
  437. }
// ggufKVOrder lists, per group name, the metadata keys in the order Encode
// writes them. Only the "llama" group is defined here, and Encode always
// uses it; note that it also contains gemma.*, tokenizer.* and bert.* keys.
// A key passed to Encode that is not in this list causes an error.
var ggufKVOrder = map[string][]string{
	"llama": {
		"general.architecture",
		"general.name",
		"llama.vocab_size",
		"llama.context_length",
		"llama.embedding_length",
		"llama.block_count",
		"llama.feed_forward_length",
		"llama.attention.head_count",
		"llama.attention.head_count_kv",
		"llama.attention.layer_norm_rms_epsilon",
		"llama.rope.freq_base",
		"llama.rope.dimension_count",
		"llama.expert_count",
		"llama.expert_used_count",
		"gemma.context_length",
		"gemma.embedding_length",
		"gemma.block_count",
		"gemma.feed_forward_length",
		"gemma.attention.head_count",
		"gemma.attention.head_count_kv",
		"gemma.attention.layer_norm_rms_epsilon",
		"gemma.attention.key_length",
		"gemma.attention.value_length",
		"general.file_type",
		"tokenizer.ggml.pre",
		"tokenizer.ggml.model",
		"tokenizer.ggml.tokens",
		"tokenizer.ggml.scores",
		"tokenizer.ggml.merges",
		"tokenizer.ggml.token_type",
		"tokenizer.ggml.bos_token_id",
		"tokenizer.ggml.eos_token_id",
		"tokenizer.ggml.unknown_token_id",
		"tokenizer.ggml.padding_token_id",
		"tokenizer.ggml.add_bos_token",
		"tokenizer.ggml.add_eos_token",
		"tokenizer.chat_template",
		"bert.pooling_type",
	},
}
// Encode writes a complete GGUF v3 file to ws: the magic and header, the
// key-values from kv in ggufKVOrder["llama"] order, one descriptor per
// tensor, and finally each tensor's data preceded by padding to a 32-byte
// boundary.
//
// It returns an error for any version other than 3, for a kv value of an
// unsupported type, and for any kv key that is not listed in ggufKVOrder.
//
// NOTE(review): the header, keys and string arrays are written with
// llm.ByteOrder, while writeGGUF, writeGGUFString and writeGGUFArray always
// write little-endian — confirm whether big-endian output is intended to
// work.
func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
	switch llm.Version {
	case 3:
		llm.V3.NumTensor = uint64(len(tensors))
		llm.V3.NumKV = uint64(len(kv))
	default:
		return fmt.Errorf("not implemented: ggufv%d", llm.Version)
	}

	// File header: magic, version, tensor count, KV count.
	if err := binary.Write(ws, llm.ByteOrder, []byte("GGUF")); err != nil {
		return err
	}

	if err := binary.Write(ws, llm.ByteOrder, llm.Version); err != nil {
		return err
	}

	if err := binary.Write(ws, llm.ByteOrder, llm.numTensor()); err != nil {
		return err
	}

	if err := binary.Write(ws, llm.ByteOrder, llm.numKV()); err != nil {
		return err
	}

	// kvCheck tracks which of the caller's keys actually get written, so
	// that unknown keys can be reported below.
	kvCheck := make(map[string]bool)
	for k := range kv {
		kvCheck[k] = false
	}

	for _, k := range ggufKVOrder["llama"] {
		v, ok := kv[k]
		if !ok {
			continue
		}
		kvCheck[k] = true

		// Key: uint64 length followed by the raw bytes.
		if err := binary.Write(ws, llm.ByteOrder, uint64(len(k))); err != nil {
			return err
		}

		if err := binary.Write(ws, llm.ByteOrder, []byte(k)); err != nil {
			return err
		}

		// Value: type tag followed by the payload.
		var err error
		switch v := v.(type) {
		case uint32:
			err = writeGGUF(ws, ggufTypeUint32, v)
		case float32:
			err = writeGGUF(ws, ggufTypeFloat32, v)
		case bool:
			err = writeGGUF(ws, ggufTypeBool, v)
		case string:
			err = writeGGUFString(ws, v)
		case []int32:
			err = writeGGUFArray(ws, ggufTypeInt32, v)
		case []uint32:
			err = writeGGUFArray(ws, ggufTypeUint32, v)
		case []float32:
			err = writeGGUFArray(ws, ggufTypeFloat32, v)
		case []string:
			// String arrays are written by hand because each element needs
			// its own length prefix.
			if err := binary.Write(ws, llm.ByteOrder, ggufTypeArray); err != nil {
				return err
			}

			if err := binary.Write(ws, llm.ByteOrder, ggufTypeString); err != nil {
				return err
			}

			if err := binary.Write(ws, llm.ByteOrder, uint64(len(v))); err != nil {
				return err
			}

			for _, e := range v {
				if err := binary.Write(ws, llm.ByteOrder, uint64(len(e))); err != nil {
					return err
				}

				if err := binary.Write(ws, llm.ByteOrder, []byte(e)); err != nil {
					return err
				}
			}
		default:
			return fmt.Errorf("improper type for '%s'", k)
		}
		if err != nil {
			return err
		}
	}

	// Any key still false was not in ggufKVOrder and was never written.
	// Map iteration order is random, so which key is reported may vary.
	for k, v := range kvCheck {
		if !v {
			return fmt.Errorf("didn't know how to write kv %s", k)
		}
	}

	// Tensor descriptors: name, dimension count, dimensions, kind, offset.
	for _, tensor := range tensors {
		if err := binary.Write(ws, llm.ByteOrder, uint64(len(tensor.Name))); err != nil {
			return err
		}

		if err := binary.Write(ws, llm.ByteOrder, []byte(tensor.Name)); err != nil {
			return err
		}

		// Only non-zero shape entries count as dimensions.
		var dims int
		for cnt := range len(tensor.Shape) {
			if tensor.Shape[cnt] > 0 {
				dims++
			}
		}

		if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {
			return err
		}

		// Dimensions are written in reverse order of tensor.Shape.
		// NOTE(review): this indexes the first `dims` entries, which assumes
		// any zero entries are trailing — confirm against callers.
		for i := range dims {
			if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil {
				return err
			}
		}

		if err := binary.Write(ws, llm.ByteOrder, tensor.Kind); err != nil {
			return err
		}

		if err := binary.Write(ws, llm.ByteOrder, tensor.Offset); err != nil {
			return err
		}
	}

	// Tensor data: zero padding up to the alignment boundary, then the
	// tensor's own bytes.
	var alignment int64 = 32
	for _, tensor := range tensors {
		offset, err := ws.Seek(0, io.SeekCurrent)
		if err != nil {
			return err
		}

		padding := llm.padding(offset, alignment)
		if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
			return err
		}

		if _, err := tensor.WriteTo(ws); err != nil {
			return err
		}
	}

	return nil
}
  606. func (gguf) padding(offset, align int64) int64 {
  607. return (align - offset%align) % align
  608. }
// GGUFWriter bundles the metadata and tensors of a model so that they can
// be serialised as a GGUF v3 file. It is declared below to satisfy both
// io.Reader and io.WriterTo; only WriteTo is implemented.
type GGUFWriter struct {
	KV
	Tensors
}
  614. type writeOffset struct {
  615. io.Writer
  616. offset int
  617. }
  618. func (wo *writeOffset) Write(p []byte) (int, error) {
  619. n, err := wo.Writer.Write(p)
  620. wo.offset += n
  621. return n, err
  622. }
// Compile-time checks that *GGUFWriter satisfies io.Reader and io.WriterTo.
var _ io.Reader = (*GGUFWriter)(nil)

var _ io.WriterTo = (*GGUFWriter)(nil)
  625. func (GGUFWriter) Read([]byte) (int, error) {
  626. panic("not implemeneted")
  627. }
  628. func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
  629. wo := &writeOffset{Writer: w}
  630. if err := binary.Write(wo, binary.LittleEndian, []byte("GGUF")); err != nil {
  631. return 0, err
  632. }
  633. if err := binary.Write(wo, binary.LittleEndian, uint32(3)); err != nil {
  634. return 0, err
  635. }
  636. if err := binary.Write(wo, binary.LittleEndian, uint64(len(gguf.Tensors.Items))); err != nil {
  637. return 0, err
  638. }
  639. if err := binary.Write(wo, binary.LittleEndian, uint64(len(gguf.KV)-1)); err != nil {
  640. return 0, err
  641. }
  642. keys := maps.Keys(gguf.KV)
  643. slices.Sort(keys)
  644. for _, key := range keys {
  645. switch key {
  646. case "general.parameter_count":
  647. // don't write general param count as its added in by us
  648. continue
  649. default:
  650. if err := ggufWriteKV(wo, key, gguf.KV[key]); err != nil {
  651. return 0, err
  652. }
  653. }
  654. }
  655. sort.Sort(gguf.Tensors)
  656. var s uint64
  657. for _, t := range gguf.Tensors.Items {
  658. t.Offset = s
  659. if err := ggufWriteTensorInfo(wo, t); err != nil {
  660. return 0, err
  661. }
  662. s += t.Size()
  663. }
  664. tensorOffset := wo.offset
  665. for _, t := range gguf.Tensors.Items {
  666. if err := ggufWriteTensor(wo, t, wo.offset); err != nil {
  667. return 0, err
  668. }
  669. }
  670. return int64(tensorOffset), nil
  671. }
  672. func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error {
  673. if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
  674. return err
  675. }
  676. if err := binary.Write(ws, binary.LittleEndian, []byte(t.Name)); err != nil {
  677. return err
  678. }
  679. if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil {
  680. return err
  681. }
  682. for i := range len(t.Shape) {
  683. if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
  684. return err
  685. }
  686. }
  687. if err := binary.Write(ws, binary.LittleEndian, t.Kind); err != nil {
  688. return err
  689. }
  690. return binary.Write(ws, binary.LittleEndian, t.Offset)
  691. }
  692. func ggufWriteTensor(ws io.Writer, t *Tensor, offset int) error {
  693. slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
  694. if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(int64(offset), 32)))); err != nil {
  695. return err
  696. }
  697. _, err := t.WriteTo(ws)
  698. return err
  699. }
  700. func ggufWriteKV(ws io.Writer, k string, v any) error {
  701. slog.Debug(k, "type", fmt.Sprintf("%T", v))
  702. if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
  703. return err
  704. }
  705. if err := binary.Write(ws, binary.LittleEndian, []byte(k)); err != nil {
  706. return err
  707. }
  708. var err error
  709. switch v := v.(type) {
  710. case uint32:
  711. err = writeGGUF(ws, ggufTypeUint32, v)
  712. case float32:
  713. err = writeGGUF(ws, ggufTypeFloat32, v)
  714. case bool:
  715. err = writeGGUF(ws, ggufTypeBool, v)
  716. case string:
  717. err = writeGGUFString(ws, v)
  718. case []int32:
  719. err = writeGGUFArray(ws, ggufTypeInt32, v)
  720. case []uint32:
  721. err = writeGGUFArray(ws, ggufTypeUint32, v)
  722. case []float32:
  723. err = writeGGUFArray(ws, ggufTypeFloat32, v)
  724. case []string:
  725. if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
  726. return err
  727. }
  728. if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
  729. return err
  730. }
  731. if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
  732. return err
  733. }
  734. for _, e := range v {
  735. if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
  736. return err
  737. }
  738. if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
  739. return err
  740. }
  741. }
  742. case *array:
  743. if v.size > 0 {
  744. switch v.values[0].(type) {
  745. case string:
  746. if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
  747. return err
  748. }
  749. if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
  750. return err
  751. }
  752. if err := binary.Write(ws, binary.LittleEndian, uint64(v.size)); err != nil {
  753. return err
  754. }
  755. for _, e := range v.values {
  756. if err := binary.Write(ws, binary.LittleEndian, uint64(len(e.(string)))); err != nil {
  757. return err
  758. }
  759. if err := binary.Write(ws, binary.LittleEndian, []byte(e.(string))); err != nil {
  760. return err
  761. }
  762. }
  763. default:
  764. err = writeGGUFArray(ws, v.datatype, v.values)
  765. }
  766. }
  767. default:
  768. return fmt.Errorf("improper type for '%s'", k)
  769. }
  770. return err
  771. }
  772. func ggufPadding(offset, align int64) int64 {
  773. // we mod twice in the case offset%align = 0
  774. return (align - offset%align) % align
  775. }