gguf.go 19 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910
  1. package llm
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "log/slog"
  9. "slices"
  10. "strings"
  11. "golang.org/x/exp/maps"
  12. )
  13. type containerGGUF struct {
  14. ByteOrder binary.ByteOrder
  15. Version uint32
  16. V1 struct {
  17. NumTensor uint32
  18. NumKV uint32
  19. }
  20. V2 struct {
  21. NumTensor uint64
  22. NumKV uint64
  23. }
  24. V3 struct {
  25. NumTensor uint64
  26. NumKV uint64
  27. }
  28. maxArraySize int
  29. }
  30. func (c *containerGGUF) canCollectArray(size int) bool {
  31. return c.maxArraySize < 0 || size <= c.maxArraySize
  32. }
  33. func (c *containerGGUF) Name() string {
  34. return "gguf"
  35. }
  36. func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
  37. if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil {
  38. return nil, err
  39. }
  40. var err error
  41. switch c.Version {
  42. case 1:
  43. err = binary.Read(rs, c.ByteOrder, &c.V1)
  44. case 2:
  45. err = binary.Read(rs, c.ByteOrder, &c.V2)
  46. default:
  47. err = binary.Read(rs, c.ByteOrder, &c.V3)
  48. }
  49. if err != nil {
  50. return nil, err
  51. }
  52. model := newGGUF(c)
  53. if err := model.Decode(rs); err != nil {
  54. return nil, err
  55. }
  56. return model, nil
  57. }
  58. const (
  59. ggufTypeUint8 uint32 = iota
  60. ggufTypeInt8
  61. ggufTypeUint16
  62. ggufTypeInt16
  63. ggufTypeUint32
  64. ggufTypeInt32
  65. ggufTypeFloat32
  66. ggufTypeBool
  67. ggufTypeString
  68. ggufTypeArray
  69. ggufTypeUint64
  70. ggufTypeInt64
  71. ggufTypeFloat64
  72. )
// gguf is the in-memory result of decoding a GGUF file: the container header
// it embeds, the metadata key-values, and the tensor descriptors.
type gguf struct {
	*containerGGUF

	// kv holds the decoded metadata, keyed by name.
	kv KV
	// tensors lists the tensor descriptors in the order they were read.
	tensors []*Tensor

	// parameters accumulates tensor.parameters() over every decoded tensor.
	parameters uint64

	// scratch is a reusable buffer for reading string lengths and string
	// bodies without allocating.
	scratch [16 << 10]byte
}
  80. func newGGUF(container *containerGGUF) *gguf {
  81. return &gguf{
  82. containerGGUF: container,
  83. kv: make(KV),
  84. }
  85. }
  86. func NewGGUFV3(bo binary.ByteOrder) *gguf {
  87. return newGGUF(&containerGGUF{ByteOrder: bo, Version: 3})
  88. }
// KV returns the decoded metadata key-values. The returned value is the
// model's own kv field, not a copy.
func (llm *gguf) KV() KV {
	return llm.kv
}
// Tensors returns the decoded tensor descriptors in the order they were read.
func (llm *gguf) Tensors() Tensors {
	return llm.tensors
}
  95. func (llm *gguf) numTensor() uint64 {
  96. switch llm.Version {
  97. case 1:
  98. return uint64(llm.V1.NumTensor)
  99. case 2:
  100. return llm.V2.NumTensor
  101. default:
  102. return llm.V3.NumTensor
  103. }
  104. }
  105. func (llm *gguf) numKV() uint64 {
  106. switch llm.Version {
  107. case 1:
  108. return uint64(llm.V1.NumKV)
  109. case 2:
  110. return llm.V2.NumKV
  111. default:
  112. return llm.V3.NumKV
  113. }
  114. }
  115. func (llm *gguf) Decode(rs io.ReadSeeker) error {
  116. // decode key-values
  117. for i := 0; uint64(i) < llm.numKV(); i++ {
  118. k, err := readGGUFString(llm, rs)
  119. if err != nil {
  120. return err
  121. }
  122. t, err := readGGUF[uint32](llm, rs)
  123. if err != nil {
  124. return err
  125. }
  126. var v any
  127. switch t {
  128. case ggufTypeUint8:
  129. v, err = readGGUF[uint8](llm, rs)
  130. case ggufTypeInt8:
  131. v, err = readGGUF[int8](llm, rs)
  132. case ggufTypeUint16:
  133. v, err = readGGUF[uint16](llm, rs)
  134. case ggufTypeInt16:
  135. v, err = readGGUF[int16](llm, rs)
  136. case ggufTypeUint32:
  137. v, err = readGGUF[uint32](llm, rs)
  138. case ggufTypeInt32:
  139. v, err = readGGUF[int32](llm, rs)
  140. case ggufTypeUint64:
  141. v, err = readGGUF[uint64](llm, rs)
  142. case ggufTypeInt64:
  143. v, err = readGGUF[int64](llm, rs)
  144. case ggufTypeFloat32:
  145. v, err = readGGUF[float32](llm, rs)
  146. case ggufTypeFloat64:
  147. v, err = readGGUF[float64](llm, rs)
  148. case ggufTypeBool:
  149. v, err = readGGUF[bool](llm, rs)
  150. case ggufTypeString:
  151. v, err = readGGUFString(llm, rs)
  152. case ggufTypeArray:
  153. v, err = readGGUFArray(llm, rs)
  154. default:
  155. return fmt.Errorf("invalid type: %d", t)
  156. }
  157. if err != nil {
  158. return err
  159. }
  160. llm.kv[k] = v
  161. }
  162. // decode tensors
  163. for range llm.numTensor() {
  164. name, err := readGGUFString(llm, rs)
  165. if err != nil {
  166. return fmt.Errorf("failed to read tensor name: %w", err)
  167. }
  168. // dims is the number of dimensions in the tensor
  169. dims, err := readGGUF[uint32](llm, rs)
  170. if err != nil {
  171. return fmt.Errorf("failed to read tensor dimensions: %w", err)
  172. }
  173. shape := []uint64{}
  174. for i := 0; uint32(i) < dims; i++ {
  175. shapeVal, err := readGGUF[uint64](llm, rs)
  176. if err != nil {
  177. return fmt.Errorf("failed to read tensor shape: %w", err)
  178. }
  179. shape = append(shape, shapeVal)
  180. }
  181. kind, err := readGGUF[uint32](llm, rs)
  182. if err != nil {
  183. return fmt.Errorf("failed to read tensor kind: %w", err)
  184. }
  185. offset, err := readGGUF[uint64](llm, rs)
  186. if err != nil {
  187. return fmt.Errorf("failed to read tensor offset: %w", err)
  188. }
  189. fmt.Println("tensor", name, shape, kind, offset)
  190. tensor := Tensor{
  191. Name: name,
  192. Kind: kind,
  193. Offset: offset,
  194. Shape: shape,
  195. }
  196. llm.tensors = append(llm.tensors, &tensor)
  197. llm.parameters += tensor.parameters()
  198. }
  199. // patch KV with parameter count
  200. llm.kv["general.parameter_count"] = llm.parameters
  201. alignment, ok := llm.kv["general.alignment"].(uint32)
  202. if !ok {
  203. alignment = 32
  204. }
  205. for _, tensor := range llm.tensors {
  206. offset, err := rs.Seek(0, io.SeekCurrent)
  207. if err != nil {
  208. return fmt.Errorf("failed to get current offset: %w", err)
  209. }
  210. padding := llm.padding(offset, int64(alignment))
  211. if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
  212. return fmt.Errorf("failed to seek to init padding: %w", err)
  213. }
  214. if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
  215. return fmt.Errorf("failed to seek to tensor: %w", err)
  216. }
  217. }
  218. return nil
  219. }
  220. func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
  221. var t T
  222. err := binary.Read(r, llm.ByteOrder, &t)
  223. return t, err
  224. }
  225. func writeGGUF[V any](w io.Writer, t uint32, v V) error {
  226. if err := binary.Write(w, binary.LittleEndian, t); err != nil {
  227. return err
  228. }
  229. return binary.Write(w, binary.LittleEndian, v)
  230. }
  231. func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
  232. var length uint64
  233. if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
  234. return "", err
  235. }
  236. var b bytes.Buffer
  237. if _, err := io.CopyN(&b, r, int64(length)); err != nil {
  238. return "", err
  239. }
  240. // gguf v1 strings are null-terminated
  241. b.Truncate(b.Len() - 1)
  242. return b.String(), nil
  243. }
  244. func discardGGUFString(llm *gguf, r io.Reader) error {
  245. buf := llm.scratch[:8]
  246. _, err := io.ReadFull(r, buf)
  247. if err != nil {
  248. return err
  249. }
  250. size := int(llm.ByteOrder.Uint64(buf))
  251. for size > 0 {
  252. n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))])
  253. if err != nil {
  254. return err
  255. }
  256. size -= n
  257. }
  258. return nil
  259. }
  260. func readGGUFString(llm *gguf, r io.Reader) (string, error) {
  261. if llm.Version == 1 {
  262. return readGGUFV1String(llm, r)
  263. }
  264. buf := llm.scratch[:8]
  265. _, err := io.ReadFull(r, buf)
  266. if err != nil {
  267. return "", err
  268. }
  269. length := int(llm.ByteOrder.Uint64(buf))
  270. if length > len(llm.scratch) {
  271. buf = make([]byte, length)
  272. } else {
  273. buf = llm.scratch[:length]
  274. }
  275. clear(buf)
  276. _, err = io.ReadFull(r, buf)
  277. if err != nil {
  278. return "", err
  279. }
  280. return string(buf), nil
  281. }
  282. func writeGGUFString(w io.Writer, s string) error {
  283. if err := binary.Write(w, binary.LittleEndian, ggufTypeString); err != nil {
  284. return err
  285. }
  286. if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
  287. return err
  288. }
  289. _, err := io.Copy(w, strings.NewReader(s))
  290. return err
  291. }
  292. type array struct {
  293. size int
  294. values []any
  295. datatype uint32
  296. }
  297. func (a *array) MarshalJSON() ([]byte, error) {
  298. return json.Marshal(a.values)
  299. }
  300. func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
  301. t, err := readGGUF[uint32](llm, r)
  302. if err != nil {
  303. return nil, err
  304. }
  305. n, err := readGGUF[uint32](llm, r)
  306. if err != nil {
  307. return nil, err
  308. }
  309. a := &array{size: int(n)}
  310. if llm.canCollectArray(int(n)) {
  311. a.values = make([]any, 0, int(n))
  312. }
  313. for i := range n {
  314. var e any
  315. switch t {
  316. case ggufTypeUint8:
  317. e, err = readGGUF[uint8](llm, r)
  318. case ggufTypeInt8:
  319. e, err = readGGUF[int8](llm, r)
  320. case ggufTypeUint16:
  321. e, err = readGGUF[uint16](llm, r)
  322. case ggufTypeInt16:
  323. e, err = readGGUF[int16](llm, r)
  324. case ggufTypeUint32:
  325. e, err = readGGUF[uint32](llm, r)
  326. case ggufTypeInt32:
  327. e, err = readGGUF[int32](llm, r)
  328. case ggufTypeUint64:
  329. e, err = readGGUF[uint64](llm, r)
  330. case ggufTypeInt64:
  331. e, err = readGGUF[int64](llm, r)
  332. case ggufTypeFloat32:
  333. e, err = readGGUF[float32](llm, r)
  334. case ggufTypeFloat64:
  335. e, err = readGGUF[float64](llm, r)
  336. case ggufTypeBool:
  337. e, err = readGGUF[bool](llm, r)
  338. case ggufTypeString:
  339. e, err = readGGUFV1String(llm, r)
  340. default:
  341. return nil, fmt.Errorf("invalid array type: %d", t)
  342. }
  343. if err != nil {
  344. return nil, err
  345. }
  346. if a.values != nil {
  347. a.values[i] = e
  348. }
  349. }
  350. return a, nil
  351. }
  352. func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
  353. if llm.Version == 1 {
  354. return readGGUFV1Array(llm, r)
  355. }
  356. t, err := readGGUF[uint32](llm, r)
  357. if err != nil {
  358. return nil, err
  359. }
  360. n, err := readGGUF[uint64](llm, r)
  361. if err != nil {
  362. return nil, err
  363. }
  364. a := &array{size: int(n), datatype: t}
  365. if llm.canCollectArray(int(n)) {
  366. a.values = make([]any, int(n))
  367. }
  368. for i := range n {
  369. var e any
  370. switch t {
  371. case ggufTypeUint8:
  372. e, err = readGGUF[uint8](llm, r)
  373. case ggufTypeInt8:
  374. e, err = readGGUF[int8](llm, r)
  375. case ggufTypeUint16:
  376. e, err = readGGUF[uint16](llm, r)
  377. case ggufTypeInt16:
  378. e, err = readGGUF[int16](llm, r)
  379. case ggufTypeUint32:
  380. e, err = readGGUF[uint32](llm, r)
  381. case ggufTypeInt32:
  382. e, err = readGGUF[int32](llm, r)
  383. case ggufTypeUint64:
  384. e, err = readGGUF[uint64](llm, r)
  385. case ggufTypeInt64:
  386. e, err = readGGUF[int64](llm, r)
  387. case ggufTypeFloat32:
  388. e, err = readGGUF[float32](llm, r)
  389. case ggufTypeFloat64:
  390. e, err = readGGUF[float64](llm, r)
  391. case ggufTypeBool:
  392. e, err = readGGUF[bool](llm, r)
  393. case ggufTypeString:
  394. if a.values != nil {
  395. e, err = readGGUFString(llm, r)
  396. } else {
  397. err = discardGGUFString(llm, r)
  398. }
  399. default:
  400. return nil, fmt.Errorf("invalid array type: %d", t)
  401. }
  402. if err != nil {
  403. return nil, err
  404. }
  405. if a.values != nil {
  406. a.values[i] = e
  407. }
  408. }
  409. return a, nil
  410. }
  411. func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
  412. if err := binary.Write(w, binary.LittleEndian, ggufTypeArray); err != nil {
  413. return err
  414. }
  415. if err := binary.Write(w, binary.LittleEndian, t); err != nil {
  416. return err
  417. }
  418. if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
  419. return err
  420. }
  421. for _, e := range s {
  422. if err := binary.Write(w, binary.LittleEndian, e); err != nil {
  423. return err
  424. }
  425. }
  426. return nil
  427. }
// ggufKVOrder lists, per architecture name, the metadata keys in the order
// they are written. Only the "llama" entry exists; it also carries the
// "gemma." keys alongside the "llama." ones.
var ggufKVOrder = map[string][]string{
	"llama": {
		"general.architecture",
		"general.name",
		"llama.vocab_size",
		"llama.context_length",
		"llama.embedding_length",
		"llama.block_count",
		"llama.feed_forward_length",
		"llama.attention.head_count",
		"llama.attention.head_count_kv",
		"llama.attention.layer_norm_rms_epsilon",
		"llama.rope.freq_base",
		"llama.rope.dimension_count",
		"llama.expert_count",
		"llama.expert_used_count",
		"gemma.context_length",
		"gemma.embedding_length",
		"gemma.block_count",
		"gemma.feed_forward_length",
		"gemma.attention.head_count",
		"gemma.attention.head_count_kv",
		"gemma.attention.layer_norm_rms_epsilon",
		"gemma.attention.key_length",
		"gemma.attention.value_length",
		"general.file_type",
		"tokenizer.ggml.pre",
		"tokenizer.ggml.model",
		"tokenizer.ggml.tokens",
		"tokenizer.ggml.scores",
		"tokenizer.ggml.merges",
		"tokenizer.ggml.token_type",
		"tokenizer.ggml.bos_token_id",
		"tokenizer.ggml.eos_token_id",
		"tokenizer.ggml.unknown_token_id",
		"tokenizer.ggml.padding_token_id",
		"tokenizer.ggml.add_bos_token",
		"tokenizer.ggml.add_eos_token",
		"tokenizer.chat_template",
	},
}
  469. func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
  470. switch llm.Version {
  471. case 3:
  472. llm.V3.NumTensor = uint64(len(tensors))
  473. llm.V3.NumKV = uint64(len(kv))
  474. default:
  475. return fmt.Errorf("not implemented: ggufv%d", llm.Version)
  476. }
  477. if err := binary.Write(ws, llm.ByteOrder, []byte("GGUF")); err != nil {
  478. return err
  479. }
  480. if err := binary.Write(ws, llm.ByteOrder, llm.Version); err != nil {
  481. return err
  482. }
  483. if err := binary.Write(ws, llm.ByteOrder, llm.numTensor()); err != nil {
  484. return err
  485. }
  486. if err := binary.Write(ws, llm.ByteOrder, llm.numKV()); err != nil {
  487. return err
  488. }
  489. kvCheck := make(map[string]bool)
  490. for k := range kv {
  491. kvCheck[k] = false
  492. }
  493. for _, k := range ggufKVOrder["llama"] {
  494. v, ok := kv[k]
  495. if !ok {
  496. continue
  497. }
  498. kvCheck[k] = true
  499. if err := binary.Write(ws, llm.ByteOrder, uint64(len(k))); err != nil {
  500. return err
  501. }
  502. if err := binary.Write(ws, llm.ByteOrder, []byte(k)); err != nil {
  503. return err
  504. }
  505. var err error
  506. switch v := v.(type) {
  507. case uint32:
  508. err = writeGGUF(ws, ggufTypeUint32, v)
  509. case float32:
  510. err = writeGGUF(ws, ggufTypeFloat32, v)
  511. case bool:
  512. err = writeGGUF(ws, ggufTypeBool, v)
  513. case string:
  514. err = writeGGUFString(ws, v)
  515. case []int32:
  516. err = writeGGUFArray(ws, ggufTypeInt32, v)
  517. case []uint32:
  518. err = writeGGUFArray(ws, ggufTypeUint32, v)
  519. case []float32:
  520. err = writeGGUFArray(ws, ggufTypeFloat32, v)
  521. case []string:
  522. if err := binary.Write(ws, llm.ByteOrder, ggufTypeArray); err != nil {
  523. return err
  524. }
  525. if err := binary.Write(ws, llm.ByteOrder, ggufTypeString); err != nil {
  526. return err
  527. }
  528. if err := binary.Write(ws, llm.ByteOrder, uint64(len(v))); err != nil {
  529. return err
  530. }
  531. for _, e := range v {
  532. if err := binary.Write(ws, llm.ByteOrder, uint64(len(e))); err != nil {
  533. return err
  534. }
  535. if err := binary.Write(ws, llm.ByteOrder, []byte(e)); err != nil {
  536. return err
  537. }
  538. }
  539. default:
  540. return fmt.Errorf("improper type for '%s'", k)
  541. }
  542. if err != nil {
  543. return err
  544. }
  545. }
  546. for k, v := range kvCheck {
  547. if !v {
  548. return fmt.Errorf("didn't know how to write kv %s", k)
  549. }
  550. }
  551. for _, tensor := range tensors {
  552. if err := binary.Write(ws, llm.ByteOrder, uint64(len(tensor.Name))); err != nil {
  553. return err
  554. }
  555. if err := binary.Write(ws, llm.ByteOrder, []byte(tensor.Name)); err != nil {
  556. return err
  557. }
  558. var dims int
  559. for cnt := range len(tensor.Shape) {
  560. if tensor.Shape[cnt] > 0 {
  561. dims++
  562. }
  563. }
  564. if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {
  565. return err
  566. }
  567. for i := range dims {
  568. if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil {
  569. return err
  570. }
  571. }
  572. if err := binary.Write(ws, llm.ByteOrder, tensor.Kind); err != nil {
  573. return err
  574. }
  575. if err := binary.Write(ws, llm.ByteOrder, tensor.Offset); err != nil {
  576. return err
  577. }
  578. }
  579. var alignment int64 = 32
  580. for _, tensor := range tensors {
  581. offset, err := ws.Seek(0, io.SeekCurrent)
  582. if err != nil {
  583. return err
  584. }
  585. padding := llm.padding(offset, alignment)
  586. if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
  587. return err
  588. }
  589. if _, err := tensor.WriteTo(ws); err != nil {
  590. return err
  591. }
  592. }
  593. return nil
  594. }
  595. func (gguf) padding(offset, align int64) int64 {
  596. return (align - offset%align) % align
  597. }
// GGUFWriter pairs the metadata and the tensors that WriteTo serialises as a
// GGUF file. It declares both io.Reader and io.WriterTo, but only WriteTo is
// functional; Read panics.
type GGUFWriter struct {
	KV
	Tensors
}
  603. type writeOffset struct {
  604. io.Writer
  605. offset int
  606. }
  607. func (wo *writeOffset) Write(p []byte) (int, error) {
  608. n, err := wo.Writer.Write(p)
  609. wo.offset += n
  610. return n, err
  611. }
  612. var _ io.Reader = (*GGUFWriter)(nil)
  613. var _ io.WriterTo = (*GGUFWriter)(nil)
  614. func (GGUFWriter) Read([]byte) (int, error) {
  615. panic("not implemeneted")
  616. }
  617. func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
  618. wo := &writeOffset{Writer: w}
  619. if err := binary.Write(wo, binary.LittleEndian, []byte("GGUF")); err != nil {
  620. return 0, err
  621. }
  622. if err := binary.Write(wo, binary.LittleEndian, uint32(3)); err != nil {
  623. return 0, err
  624. }
  625. if err := binary.Write(wo, binary.LittleEndian, uint64(len(gguf.Tensors))); err != nil {
  626. return 0, err
  627. }
  628. if err := binary.Write(wo, binary.LittleEndian, uint64(len(gguf.KV)-1)); err != nil {
  629. return 0, err
  630. }
  631. keys := maps.Keys(gguf.KV)
  632. slices.Sort(keys)
  633. for _, key := range keys {
  634. fmt.Println(key)
  635. switch key {
  636. case "general.parameter_count":
  637. continue
  638. default:
  639. if err := ggufWriteKV(wo, key, gguf.KV[key]); err != nil {
  640. return 0, err
  641. }
  642. }
  643. }
  644. //sort.Sort(gguf.Tensors)
  645. var s uint64
  646. for _, t := range gguf.Tensors {
  647. t.Offset = s
  648. if err := ggufWriteTensorInfo(wo, t); err != nil {
  649. return 0, err
  650. }
  651. s += t.Size()
  652. }
  653. tensorOffset := wo.offset
  654. for _, t := range gguf.Tensors {
  655. if err := ggufWriteTensor(wo, t, wo.offset); err != nil {
  656. return 0, err
  657. }
  658. }
  659. return int64(tensorOffset), nil
  660. }
  661. func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error {
  662. if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
  663. return err
  664. }
  665. if err := binary.Write(ws, binary.LittleEndian, []byte(t.Name)); err != nil {
  666. return err
  667. }
  668. if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil {
  669. return err
  670. }
  671. for i := range len(t.Shape) {
  672. if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
  673. return err
  674. }
  675. }
  676. if err := binary.Write(ws, binary.LittleEndian, t.Kind); err != nil {
  677. return err
  678. }
  679. return binary.Write(ws, binary.LittleEndian, t.Offset)
  680. }
  681. func ggufWriteTensor(ws io.Writer, t *Tensor, offset int) error {
  682. slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
  683. fmt.Println(int(ggufPadding(int64(offset), 32)))
  684. /* if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(int64(offset), 32)))); err != nil {
  685. return err
  686. } */
  687. _, err := t.WriteTo(ws)
  688. return err
  689. }
  690. func ggufWriteKV(ws io.Writer, k string, v any) error {
  691. slog.Debug(k, "type", fmt.Sprintf("%T", v))
  692. if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
  693. return err
  694. }
  695. if err := binary.Write(ws, binary.LittleEndian, []byte(k)); err != nil {
  696. return err
  697. }
  698. var err error
  699. switch v := v.(type) {
  700. case uint32:
  701. err = writeGGUF(ws, ggufTypeUint32, v)
  702. case float32:
  703. err = writeGGUF(ws, ggufTypeFloat32, v)
  704. case bool:
  705. err = writeGGUF(ws, ggufTypeBool, v)
  706. case string:
  707. err = writeGGUFString(ws, v)
  708. case []int32:
  709. err = writeGGUFArray(ws, ggufTypeInt32, v)
  710. case []uint32:
  711. err = writeGGUFArray(ws, ggufTypeUint32, v)
  712. case []float32:
  713. err = writeGGUFArray(ws, ggufTypeFloat32, v)
  714. case []string:
  715. if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
  716. return err
  717. }
  718. if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
  719. return err
  720. }
  721. if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
  722. return err
  723. }
  724. for _, e := range v {
  725. if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
  726. return err
  727. }
  728. if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
  729. return err
  730. }
  731. }
  732. case *array:
  733. if v.size > 0 {
  734. switch v.values[0].(type) {
  735. case string:
  736. if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
  737. return err
  738. }
  739. if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
  740. return err
  741. }
  742. if err := binary.Write(ws, binary.LittleEndian, uint64(v.size)); err != nil {
  743. return err
  744. }
  745. for _, e := range v.values {
  746. if err := binary.Write(ws, binary.LittleEndian, uint64(len(e.(string)))); err != nil {
  747. return err
  748. }
  749. if err := binary.Write(ws, binary.LittleEndian, []byte(e.(string))); err != nil {
  750. return err
  751. }
  752. }
  753. default:
  754. err = writeGGUFArray(ws, v.datatype, v.values)
  755. }
  756. }
  757. default:
  758. fmt.Println("type is", v)
  759. return fmt.Errorf("improper type for '%s'", k)
  760. }
  761. return err
  762. }
  763. func ggufPadding(offset, align int64) int64 {
  764. return align - offset%align
  765. }