gguf.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913
  1. package llm
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "log/slog"
  9. "slices"
  10. "sort"
  11. "strings"
  12. "golang.org/x/exp/maps"
  13. )
  14. type containerGGUF struct {
  15. ByteOrder binary.ByteOrder
  16. Version uint32
  17. V1 struct {
  18. NumTensor uint32
  19. NumKV uint32
  20. }
  21. V2 struct {
  22. NumTensor uint64
  23. NumKV uint64
  24. }
  25. V3 struct {
  26. NumTensor uint64
  27. NumKV uint64
  28. }
  29. maxArraySize int
  30. }
  31. func (c *containerGGUF) canCollectArray(size int) bool {
  32. return c.maxArraySize < 0 || size <= c.maxArraySize
  33. }
  34. func (c *containerGGUF) Name() string {
  35. return "gguf"
  36. }
  37. func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
  38. if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil {
  39. return nil, err
  40. }
  41. var err error
  42. switch c.Version {
  43. case 1:
  44. err = binary.Read(rs, c.ByteOrder, &c.V1)
  45. case 2:
  46. err = binary.Read(rs, c.ByteOrder, &c.V2)
  47. default:
  48. err = binary.Read(rs, c.ByteOrder, &c.V3)
  49. }
  50. if err != nil {
  51. return nil, err
  52. }
  53. model := newGGUF(c)
  54. if err := model.Decode(rs); err != nil {
  55. return nil, err
  56. }
  57. return model, nil
  58. }
  59. const (
  60. ggufTypeUint8 uint32 = iota
  61. ggufTypeInt8
  62. ggufTypeUint16
  63. ggufTypeInt16
  64. ggufTypeUint32
  65. ggufTypeInt32
  66. ggufTypeFloat32
  67. ggufTypeBool
  68. ggufTypeString
  69. ggufTypeArray
  70. ggufTypeUint64
  71. ggufTypeInt64
  72. ggufTypeFloat64
  73. )
  74. type gguf struct {
  75. *containerGGUF
  76. kv KV
  77. tensors []*Tensor
  78. parameters uint64
  79. scratch [16 << 10]byte
  80. }
  81. func newGGUF(container *containerGGUF) *gguf {
  82. return &gguf{
  83. containerGGUF: container,
  84. kv: make(KV),
  85. }
  86. }
  87. func NewGGUFV3(bo binary.ByteOrder) *gguf {
  88. return newGGUF(&containerGGUF{ByteOrder: bo, Version: 3})
  89. }
  90. func (llm *gguf) KV() KV {
  91. return llm.kv
  92. }
  93. func (llm *gguf) Tensors() Tensors {
  94. return llm.tensors
  95. }
  96. func (llm *gguf) numTensor() uint64 {
  97. switch llm.Version {
  98. case 1:
  99. return uint64(llm.V1.NumTensor)
  100. case 2:
  101. return llm.V2.NumTensor
  102. default:
  103. return llm.V3.NumTensor
  104. }
  105. }
  106. func (llm *gguf) numKV() uint64 {
  107. switch llm.Version {
  108. case 1:
  109. return uint64(llm.V1.NumKV)
  110. case 2:
  111. return llm.V2.NumKV
  112. default:
  113. return llm.V3.NumKV
  114. }
  115. }
  116. func (llm *gguf) Decode(rs io.ReadSeeker) error {
  117. // decode key-values
  118. fmt.Println(llm.numKV())
  119. for i := 0; uint64(i) < llm.numKV(); i++ {
  120. k, err := readGGUFString(llm, rs)
  121. if err != nil {
  122. return err
  123. }
  124. fmt.Printf("k: %#v\n", k)
  125. t, err := readGGUF[uint32](llm, rs)
  126. if err != nil {
  127. return err
  128. }
  129. var v any
  130. switch t {
  131. case ggufTypeUint8:
  132. v, err = readGGUF[uint8](llm, rs)
  133. case ggufTypeInt8:
  134. v, err = readGGUF[int8](llm, rs)
  135. case ggufTypeUint16:
  136. v, err = readGGUF[uint16](llm, rs)
  137. case ggufTypeInt16:
  138. v, err = readGGUF[int16](llm, rs)
  139. case ggufTypeUint32:
  140. v, err = readGGUF[uint32](llm, rs)
  141. case ggufTypeInt32:
  142. v, err = readGGUF[int32](llm, rs)
  143. case ggufTypeUint64:
  144. v, err = readGGUF[uint64](llm, rs)
  145. case ggufTypeInt64:
  146. v, err = readGGUF[int64](llm, rs)
  147. case ggufTypeFloat32:
  148. v, err = readGGUF[float32](llm, rs)
  149. case ggufTypeFloat64:
  150. v, err = readGGUF[float64](llm, rs)
  151. case ggufTypeBool:
  152. v, err = readGGUF[bool](llm, rs)
  153. case ggufTypeString:
  154. v, err = readGGUFString(llm, rs)
  155. case ggufTypeArray:
  156. v, err = readGGUFArray(llm, rs)
  157. default:
  158. return fmt.Errorf("invalid type: %d", t)
  159. }
  160. if err != nil {
  161. return err
  162. }
  163. llm.kv[k] = v
  164. }
  165. // decode tensors
  166. for range llm.numTensor() {
  167. name, err := readGGUFString(llm, rs)
  168. if err != nil {
  169. return fmt.Errorf("failed to read tensor name: %w", err)
  170. }
  171. // dims is the number of dimensions in the tensor
  172. dims, err := readGGUF[uint32](llm, rs)
  173. if err != nil {
  174. return fmt.Errorf("failed to read tensor dimensions: %w", err)
  175. }
  176. shape := []uint64{}
  177. for i := 0; uint32(i) < dims; i++ {
  178. shapeVal, err := readGGUF[uint64](llm, rs)
  179. if err != nil {
  180. return fmt.Errorf("failed to read tensor shape: %w", err)
  181. }
  182. shape = append(shape, shapeVal)
  183. }
  184. fmt.Println("tensor ", name, " shape ", shape)
  185. kind, err := readGGUF[uint32](llm, rs)
  186. if err != nil {
  187. return fmt.Errorf("failed to read tensor kind: %w", err)
  188. }
  189. offset, err := readGGUF[uint64](llm, rs)
  190. if err != nil {
  191. return fmt.Errorf("failed to read tensor offset: %w", err)
  192. }
  193. tensor := Tensor{
  194. Name: name,
  195. Kind: kind,
  196. Offset: offset,
  197. Shape: shape,
  198. }
  199. llm.tensors = append(llm.tensors, &tensor)
  200. llm.parameters += tensor.parameters()
  201. }
  202. // patch KV with parameter count
  203. llm.kv["general.parameter_count"] = llm.parameters
  204. alignment, ok := llm.kv["general.alignment"].(uint32)
  205. if !ok {
  206. alignment = 32
  207. }
  208. for _, tensor := range llm.tensors {
  209. offset, err := rs.Seek(0, io.SeekCurrent)
  210. if err != nil {
  211. return fmt.Errorf("failed to get current offset: %w", err)
  212. }
  213. padding := llm.padding(offset, int64(alignment))
  214. if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
  215. return fmt.Errorf("failed to seek to init padding: %w", err)
  216. }
  217. if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
  218. return fmt.Errorf("failed to seek to tensor: %w", err)
  219. }
  220. }
  221. return nil
  222. }
  223. func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
  224. var t T
  225. err := binary.Read(r, llm.ByteOrder, &t)
  226. return t, err
  227. }
  228. func writeGGUF[V any](w io.Writer, t uint32, v V) error {
  229. if err := binary.Write(w, binary.LittleEndian, t); err != nil {
  230. return err
  231. }
  232. return binary.Write(w, binary.LittleEndian, v)
  233. }
  234. func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
  235. var length uint64
  236. if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
  237. return "", err
  238. }
  239. var b bytes.Buffer
  240. if _, err := io.CopyN(&b, r, int64(length)); err != nil {
  241. return "", err
  242. }
  243. // gguf v1 strings are null-terminated
  244. b.Truncate(b.Len() - 1)
  245. return b.String(), nil
  246. }
  247. func discardGGUFString(llm *gguf, r io.Reader) error {
  248. buf := llm.scratch[:8]
  249. _, err := io.ReadFull(r, buf)
  250. if err != nil {
  251. return err
  252. }
  253. size := int(llm.ByteOrder.Uint64(buf))
  254. for size > 0 {
  255. n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))])
  256. if err != nil {
  257. return err
  258. }
  259. size -= n
  260. }
  261. return nil
  262. }
  263. func readGGUFString(llm *gguf, r io.Reader) (string, error) {
  264. if llm.Version == 1 {
  265. return readGGUFV1String(llm, r)
  266. }
  267. buf := llm.scratch[:8]
  268. _, err := io.ReadFull(r, buf)
  269. if err != nil {
  270. return "", err
  271. }
  272. length := int(llm.ByteOrder.Uint64(buf))
  273. if length > len(llm.scratch) {
  274. buf = make([]byte, length)
  275. } else {
  276. buf = llm.scratch[:length]
  277. }
  278. clear(buf)
  279. _, err = io.ReadFull(r, buf)
  280. if err != nil {
  281. return "", err
  282. }
  283. return string(buf), nil
  284. }
  285. func writeGGUFString(w io.Writer, s string) error {
  286. if err := binary.Write(w, binary.LittleEndian, ggufTypeString); err != nil {
  287. return err
  288. }
  289. if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
  290. return err
  291. }
  292. _, err := io.Copy(w, strings.NewReader(s))
  293. return err
  294. }
  295. type array struct {
  296. size int
  297. values []any
  298. datatype uint32
  299. }
  300. func (a *array) MarshalJSON() ([]byte, error) {
  301. return json.Marshal(a.values)
  302. }
  303. func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
  304. t, err := readGGUF[uint32](llm, r)
  305. if err != nil {
  306. return nil, err
  307. }
  308. n, err := readGGUF[uint32](llm, r)
  309. if err != nil {
  310. return nil, err
  311. }
  312. a := &array{size: int(n)}
  313. if llm.canCollectArray(int(n)) {
  314. a.values = make([]any, 0, int(n))
  315. }
  316. for i := range n {
  317. var e any
  318. switch t {
  319. case ggufTypeUint8:
  320. e, err = readGGUF[uint8](llm, r)
  321. case ggufTypeInt8:
  322. e, err = readGGUF[int8](llm, r)
  323. case ggufTypeUint16:
  324. e, err = readGGUF[uint16](llm, r)
  325. case ggufTypeInt16:
  326. e, err = readGGUF[int16](llm, r)
  327. case ggufTypeUint32:
  328. e, err = readGGUF[uint32](llm, r)
  329. case ggufTypeInt32:
  330. e, err = readGGUF[int32](llm, r)
  331. case ggufTypeUint64:
  332. e, err = readGGUF[uint64](llm, r)
  333. case ggufTypeInt64:
  334. e, err = readGGUF[int64](llm, r)
  335. case ggufTypeFloat32:
  336. e, err = readGGUF[float32](llm, r)
  337. case ggufTypeFloat64:
  338. e, err = readGGUF[float64](llm, r)
  339. case ggufTypeBool:
  340. e, err = readGGUF[bool](llm, r)
  341. case ggufTypeString:
  342. e, err = readGGUFV1String(llm, r)
  343. default:
  344. return nil, fmt.Errorf("invalid array type: %d", t)
  345. }
  346. if err != nil {
  347. return nil, err
  348. }
  349. if a.values != nil {
  350. a.values[i] = e
  351. }
  352. }
  353. return a, nil
  354. }
  355. func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
  356. if llm.Version == 1 {
  357. return readGGUFV1Array(llm, r)
  358. }
  359. t, err := readGGUF[uint32](llm, r)
  360. if err != nil {
  361. return nil, err
  362. }
  363. n, err := readGGUF[uint64](llm, r)
  364. if err != nil {
  365. return nil, err
  366. }
  367. a := &array{size: int(n), datatype: t}
  368. if llm.canCollectArray(int(n)) {
  369. a.values = make([]any, int(n))
  370. }
  371. for i := range n {
  372. var e any
  373. switch t {
  374. case ggufTypeUint8:
  375. e, err = readGGUF[uint8](llm, r)
  376. case ggufTypeInt8:
  377. e, err = readGGUF[int8](llm, r)
  378. case ggufTypeUint16:
  379. e, err = readGGUF[uint16](llm, r)
  380. case ggufTypeInt16:
  381. e, err = readGGUF[int16](llm, r)
  382. case ggufTypeUint32:
  383. e, err = readGGUF[uint32](llm, r)
  384. case ggufTypeInt32:
  385. e, err = readGGUF[int32](llm, r)
  386. case ggufTypeUint64:
  387. e, err = readGGUF[uint64](llm, r)
  388. case ggufTypeInt64:
  389. e, err = readGGUF[int64](llm, r)
  390. case ggufTypeFloat32:
  391. e, err = readGGUF[float32](llm, r)
  392. case ggufTypeFloat64:
  393. e, err = readGGUF[float64](llm, r)
  394. case ggufTypeBool:
  395. e, err = readGGUF[bool](llm, r)
  396. case ggufTypeString:
  397. if a.values != nil {
  398. e, err = readGGUFString(llm, r)
  399. } else {
  400. err = discardGGUFString(llm, r)
  401. }
  402. default:
  403. return nil, fmt.Errorf("invalid array type: %d", t)
  404. }
  405. if err != nil {
  406. return nil, err
  407. }
  408. if a.values != nil {
  409. a.values[i] = e
  410. }
  411. }
  412. return a, nil
  413. }
  414. func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
  415. if err := binary.Write(w, binary.LittleEndian, ggufTypeArray); err != nil {
  416. return err
  417. }
  418. if err := binary.Write(w, binary.LittleEndian, t); err != nil {
  419. return err
  420. }
  421. if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
  422. return err
  423. }
  424. for _, e := range s {
  425. if err := binary.Write(w, binary.LittleEndian, e); err != nil {
  426. return err
  427. }
  428. }
  429. return nil
  430. }
  431. var ggufKVOrder = map[string][]string{
  432. "llama": {
  433. "general.architecture",
  434. "general.name",
  435. "llama.vocab_size",
  436. "llama.context_length",
  437. "llama.embedding_length",
  438. "llama.block_count",
  439. "llama.feed_forward_length",
  440. "llama.attention.head_count",
  441. "llama.attention.head_count_kv",
  442. "llama.attention.layer_norm_rms_epsilon",
  443. "llama.rope.freq_base",
  444. "llama.rope.dimension_count",
  445. "llama.expert_count",
  446. "llama.expert_used_count",
  447. "gemma.context_length",
  448. "gemma.embedding_length",
  449. "gemma.block_count",
  450. "gemma.feed_forward_length",
  451. "gemma.attention.head_count",
  452. "gemma.attention.head_count_kv",
  453. "gemma.attention.layer_norm_rms_epsilon",
  454. "gemma.attention.key_length",
  455. "gemma.attention.value_length",
  456. "general.file_type",
  457. "tokenizer.ggml.pre",
  458. "tokenizer.ggml.model",
  459. "tokenizer.ggml.tokens",
  460. "tokenizer.ggml.scores",
  461. "tokenizer.ggml.merges",
  462. "tokenizer.ggml.token_type",
  463. "tokenizer.ggml.bos_token_id",
  464. "tokenizer.ggml.eos_token_id",
  465. "tokenizer.ggml.unknown_token_id",
  466. "tokenizer.ggml.padding_token_id",
  467. "tokenizer.ggml.add_bos_token",
  468. "tokenizer.ggml.add_eos_token",
  469. "tokenizer.chat_template",
  470. },
  471. }
  472. func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
  473. switch llm.Version {
  474. case 3:
  475. llm.V3.NumTensor = uint64(len(tensors))
  476. llm.V3.NumKV = uint64(len(kv))
  477. default:
  478. return fmt.Errorf("not implemented: ggufv%d", llm.Version)
  479. }
  480. if err := binary.Write(ws, llm.ByteOrder, []byte("GGUF")); err != nil {
  481. return err
  482. }
  483. if err := binary.Write(ws, llm.ByteOrder, llm.Version); err != nil {
  484. return err
  485. }
  486. if err := binary.Write(ws, llm.ByteOrder, llm.numTensor()); err != nil {
  487. return err
  488. }
  489. if err := binary.Write(ws, llm.ByteOrder, llm.numKV()); err != nil {
  490. return err
  491. }
  492. kvCheck := make(map[string]bool)
  493. for k := range kv {
  494. kvCheck[k] = false
  495. }
  496. for _, k := range ggufKVOrder["llama"] {
  497. v, ok := kv[k]
  498. if !ok {
  499. continue
  500. }
  501. kvCheck[k] = true
  502. if err := binary.Write(ws, llm.ByteOrder, uint64(len(k))); err != nil {
  503. return err
  504. }
  505. if err := binary.Write(ws, llm.ByteOrder, []byte(k)); err != nil {
  506. return err
  507. }
  508. var err error
  509. switch v := v.(type) {
  510. case uint32:
  511. err = writeGGUF(ws, ggufTypeUint32, v)
  512. case float32:
  513. err = writeGGUF(ws, ggufTypeFloat32, v)
  514. case bool:
  515. err = writeGGUF(ws, ggufTypeBool, v)
  516. case string:
  517. err = writeGGUFString(ws, v)
  518. case []int32:
  519. err = writeGGUFArray(ws, ggufTypeInt32, v)
  520. case []uint32:
  521. err = writeGGUFArray(ws, ggufTypeUint32, v)
  522. case []float32:
  523. err = writeGGUFArray(ws, ggufTypeFloat32, v)
  524. case []string:
  525. if err := binary.Write(ws, llm.ByteOrder, ggufTypeArray); err != nil {
  526. return err
  527. }
  528. if err := binary.Write(ws, llm.ByteOrder, ggufTypeString); err != nil {
  529. return err
  530. }
  531. if err := binary.Write(ws, llm.ByteOrder, uint64(len(v))); err != nil {
  532. return err
  533. }
  534. for _, e := range v {
  535. if err := binary.Write(ws, llm.ByteOrder, uint64(len(e))); err != nil {
  536. return err
  537. }
  538. if err := binary.Write(ws, llm.ByteOrder, []byte(e)); err != nil {
  539. return err
  540. }
  541. }
  542. default:
  543. return fmt.Errorf("improper type for '%s'", k)
  544. }
  545. if err != nil {
  546. return err
  547. }
  548. }
  549. for k, v := range kvCheck {
  550. if !v {
  551. return fmt.Errorf("didn't know how to write kv %s", k)
  552. }
  553. }
  554. for _, tensor := range tensors {
  555. if err := binary.Write(ws, llm.ByteOrder, uint64(len(tensor.Name))); err != nil {
  556. return err
  557. }
  558. if err := binary.Write(ws, llm.ByteOrder, []byte(tensor.Name)); err != nil {
  559. return err
  560. }
  561. var dims int
  562. for cnt := range len(tensor.Shape) {
  563. if tensor.Shape[cnt] > 0 {
  564. dims++
  565. }
  566. }
  567. if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {
  568. return err
  569. }
  570. for i := range dims {
  571. if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil {
  572. return err
  573. }
  574. }
  575. if err := binary.Write(ws, llm.ByteOrder, tensor.Kind); err != nil {
  576. return err
  577. }
  578. if err := binary.Write(ws, llm.ByteOrder, tensor.Offset); err != nil {
  579. return err
  580. }
  581. }
  582. var alignment int64 = 32
  583. for _, tensor := range tensors {
  584. offset, err := ws.Seek(0, io.SeekCurrent)
  585. if err != nil {
  586. return err
  587. }
  588. padding := llm.padding(offset, alignment)
  589. if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
  590. return err
  591. }
  592. if _, err := tensor.WriteTo(ws); err != nil {
  593. return err
  594. }
  595. }
  596. return nil
  597. }
  598. func (gguf) padding(offset, align int64) int64 {
  599. return (align - offset%align) % align
  600. }
  601. // Reader and WriterTo
  602. type GGUFWriter struct {
  603. KV
  604. Tensors
  605. }
  606. type writeOffset struct {
  607. io.Writer
  608. offset int
  609. }
  610. func (wo *writeOffset) Write(p []byte) (int, error) {
  611. n, err := wo.Writer.Write(p)
  612. wo.offset += n
  613. return n, err
  614. }
  615. var _ io.Reader = (*GGUFWriter)(nil)
  616. var _ io.WriterTo = (*GGUFWriter)(nil)
  617. func (GGUFWriter) Read([]byte) (int, error) {
  618. panic("not implemeneted")
  619. }
  620. func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
  621. wo := &writeOffset{Writer: w}
  622. if err := binary.Write(wo, binary.LittleEndian, []byte("GGUF")); err != nil {
  623. return 0, err
  624. }
  625. if err := binary.Write(wo, binary.LittleEndian, uint32(3)); err != nil {
  626. return 0, err
  627. }
  628. if err := binary.Write(wo, binary.LittleEndian, uint64(len(gguf.Tensors))); err != nil {
  629. return 0, err
  630. }
  631. if err := binary.Write(wo, binary.LittleEndian, uint64(len(gguf.KV)-1)); err != nil {
  632. return 0, err
  633. }
  634. keys := maps.Keys(gguf.KV)
  635. slices.Sort(keys)
  636. for _, key := range keys {
  637. fmt.Println(key)
  638. switch key {
  639. case "general.parameter_count":
  640. continue
  641. default:
  642. if err := ggufWriteKV(wo, key, gguf.KV[key]); err != nil {
  643. return 0, err
  644. }
  645. }
  646. }
  647. sort.Sort(gguf.Tensors)
  648. var s uint64
  649. for _, t := range gguf.Tensors {
  650. t.Offset = s
  651. if err := ggufWriteTensorInfo(wo, t); err != nil {
  652. return 0, err
  653. }
  654. s += t.Size()
  655. }
  656. for _, t := range gguf.Tensors {
  657. if err := ggufWriteTensor(wo, t, wo.offset); err != nil {
  658. return 0, err
  659. }
  660. }
  661. return 0, nil
  662. }
  663. func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error {
  664. if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
  665. return err
  666. }
  667. if err := binary.Write(ws, binary.LittleEndian, []byte(t.Name)); err != nil {
  668. return err
  669. }
  670. if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil {
  671. return err
  672. }
  673. fmt.Println("tensor ", t.Name, " shape ", t.Shape)
  674. for i := range len(t.Shape) {
  675. if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
  676. return err
  677. }
  678. }
  679. if err := binary.Write(ws, binary.LittleEndian, t.Kind); err != nil {
  680. return err
  681. }
  682. return binary.Write(ws, binary.LittleEndian, t.Offset)
  683. }
  684. func ggufWriteTensor(ws io.Writer, t *Tensor, offset int) error {
  685. slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
  686. if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(int64(offset), 32)))); err != nil {
  687. return err
  688. }
  689. _, err := t.WriteTo(ws)
  690. return err
  691. }
  692. func ggufWriteKV(ws io.Writer, k string, v any) error {
  693. slog.Debug(k, "type", fmt.Sprintf("%T", v))
  694. if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
  695. return err
  696. }
  697. if err := binary.Write(ws, binary.LittleEndian, []byte(k)); err != nil {
  698. return err
  699. }
  700. var err error
  701. switch v := v.(type) {
  702. case uint32:
  703. err = writeGGUF(ws, ggufTypeUint32, v)
  704. case float32:
  705. err = writeGGUF(ws, ggufTypeFloat32, v)
  706. case bool:
  707. err = writeGGUF(ws, ggufTypeBool, v)
  708. case string:
  709. err = writeGGUFString(ws, v)
  710. case []int32:
  711. err = writeGGUFArray(ws, ggufTypeInt32, v)
  712. case []uint32:
  713. err = writeGGUFArray(ws, ggufTypeUint32, v)
  714. case []float32:
  715. err = writeGGUFArray(ws, ggufTypeFloat32, v)
  716. case []string:
  717. if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
  718. return err
  719. }
  720. if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
  721. return err
  722. }
  723. if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
  724. return err
  725. }
  726. for _, e := range v {
  727. if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
  728. return err
  729. }
  730. if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
  731. return err
  732. }
  733. }
  734. case *array:
  735. if v.size > 0 {
  736. switch v.values[0].(type) {
  737. case string:
  738. if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
  739. return err
  740. }
  741. if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
  742. return err
  743. }
  744. if err := binary.Write(ws, binary.LittleEndian, uint64(v.size)); err != nil {
  745. return err
  746. }
  747. for _, e := range v.values {
  748. if err := binary.Write(ws, binary.LittleEndian, uint64(len(e.(string)))); err != nil {
  749. return err
  750. }
  751. if err := binary.Write(ws, binary.LittleEndian, []byte(e.(string))); err != nil {
  752. return err
  753. }
  754. }
  755. default:
  756. err = writeGGUFArray(ws, v.datatype, v.values)
  757. }
  758. }
  759. default:
  760. fmt.Println("type is", v)
  761. return fmt.Errorf("improper type for '%s'", k)
  762. }
  763. return err
  764. }
  765. func ggufPadding(offset, align int64) int64 {
  766. return align - offset%align
  767. }