gguf.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
  1. package llm
  2. import (
  3. "bytes"
  4. "cmp"
  5. "encoding/binary"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "log/slog"
  10. "slices"
  11. "strings"
  12. "golang.org/x/exp/maps"
  13. "github.com/ollama/ollama/api"
  14. )
  // containerGGUF holds the fixed-size GGUF header fields read after the magic
// number: the format version and the tensor / key-value counts. The width of
// the counts depends on the version, so each version has its own struct and
// Decode fills only the one selected by Version.
type containerGGUF struct {
	// ByteOrder is used to decode every header and metadata field.
	ByteOrder binary.ByteOrder

	// Version is the GGUF format version read from the file.
	Version uint32

	// V1 holds the 32-bit counts of a version 1 header.
	V1 struct {
		NumTensor uint32
		NumKV     uint32
	}

	// V2 holds the 64-bit counts of a version 2 header.
	V2 struct {
		NumTensor uint64
		NumKV     uint64
	}

	// V3 holds the 64-bit counts of a version 3 header; Decode also uses this
	// layout for any version other than 1 or 2.
	V3 struct {
		NumTensor uint64
		NumKV     uint64
	}

	// maxArraySize is the largest metadata array whose elements are retained
	// in memory; a negative value removes the limit.
	maxArraySize int
}

// canCollectArray reports whether an array of size elements may have its
// values retained, according to maxArraySize.
func (c *containerGGUF) canCollectArray(size int) bool {
	if c.maxArraySize < 0 {
		return true
	}
	return size <= c.maxArraySize
}

// Name returns the identifier of this container format.
func (c *containerGGUF) Name() string {
	return "gguf"
}
  38. func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
  39. if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil {
  40. return nil, err
  41. }
  42. var err error
  43. switch c.Version {
  44. case 1:
  45. err = binary.Read(rs, c.ByteOrder, &c.V1)
  46. case 2:
  47. err = binary.Read(rs, c.ByteOrder, &c.V2)
  48. default:
  49. err = binary.Read(rs, c.ByteOrder, &c.V3)
  50. }
  51. if err != nil {
  52. return nil, err
  53. }
  54. model := newGGUF(c)
  55. if err := model.Decode(rs); err != nil {
  56. return nil, err
  57. }
  58. return model, nil
  59. }
  // GGUF metadata value type tags. Each metadata value (and each array's
// element type) is preceded on disk by one of these as a uint32.
const (
	ggufTypeUint8   uint32 = 0
	ggufTypeInt8    uint32 = 1
	ggufTypeUint16  uint32 = 2
	ggufTypeInt16   uint32 = 3
	ggufTypeUint32  uint32 = 4
	ggufTypeInt32   uint32 = 5
	ggufTypeFloat32 uint32 = 6
	ggufTypeBool    uint32 = 7
	ggufTypeString  uint32 = 8
	ggufTypeArray   uint32 = 9
	ggufTypeUint64  uint32 = 10
	ggufTypeInt64   uint32 = 11
	ggufTypeFloat64 uint32 = 12
)
  75. type gguf struct {
  76. *containerGGUF
  77. kv KV
  78. tensors []*Tensor
  79. parameters uint64
  80. tensorOffset uint64
  81. scratch [16 << 10]byte
  82. }
  83. func newGGUF(container *containerGGUF) *gguf {
  84. return &gguf{
  85. containerGGUF: container,
  86. kv: make(KV),
  87. }
  88. }
  89. func (llm *gguf) KV() KV {
  90. return llm.kv
  91. }
  92. func (llm *gguf) Tensors() Tensors {
  93. return Tensors{
  94. Items: llm.tensors,
  95. Offset: llm.tensorOffset,
  96. }
  97. }
  98. func (llm *gguf) numTensor() uint64 {
  99. switch llm.Version {
  100. case 1:
  101. return uint64(llm.V1.NumTensor)
  102. case 2:
  103. return llm.V2.NumTensor
  104. default:
  105. return llm.V3.NumTensor
  106. }
  107. }
  108. func (llm *gguf) numKV() uint64 {
  109. switch llm.Version {
  110. case 1:
  111. return uint64(llm.V1.NumKV)
  112. case 2:
  113. return llm.V2.NumKV
  114. default:
  115. return llm.V3.NumKV
  116. }
  117. }
  118. func (llm *gguf) Decode(rs io.ReadSeeker) error {
  119. // decode key-values
  120. for i := 0; uint64(i) < llm.numKV(); i++ {
  121. k, err := readGGUFString(llm, rs)
  122. if err != nil {
  123. return err
  124. }
  125. t, err := readGGUF[uint32](llm, rs)
  126. if err != nil {
  127. return err
  128. }
  129. var v any
  130. switch t {
  131. case ggufTypeUint8:
  132. v, err = readGGUF[uint8](llm, rs)
  133. case ggufTypeInt8:
  134. v, err = readGGUF[int8](llm, rs)
  135. case ggufTypeUint16:
  136. v, err = readGGUF[uint16](llm, rs)
  137. case ggufTypeInt16:
  138. v, err = readGGUF[int16](llm, rs)
  139. case ggufTypeUint32:
  140. v, err = readGGUF[uint32](llm, rs)
  141. case ggufTypeInt32:
  142. v, err = readGGUF[int32](llm, rs)
  143. case ggufTypeUint64:
  144. v, err = readGGUF[uint64](llm, rs)
  145. case ggufTypeInt64:
  146. v, err = readGGUF[int64](llm, rs)
  147. case ggufTypeFloat32:
  148. v, err = readGGUF[float32](llm, rs)
  149. case ggufTypeFloat64:
  150. v, err = readGGUF[float64](llm, rs)
  151. case ggufTypeBool:
  152. v, err = readGGUF[bool](llm, rs)
  153. case ggufTypeString:
  154. v, err = readGGUFString(llm, rs)
  155. case ggufTypeArray:
  156. v, err = readGGUFArray(llm, rs)
  157. default:
  158. return fmt.Errorf("invalid type: %d", t)
  159. }
  160. if err != nil {
  161. return err
  162. }
  163. llm.kv[k] = v
  164. }
  165. // decode tensors
  166. for range llm.numTensor() {
  167. name, err := readGGUFString(llm, rs)
  168. if err != nil {
  169. return fmt.Errorf("failed to read tensor name: %w", err)
  170. }
  171. // dims is the number of dimensions in the tensor
  172. dims, err := readGGUF[uint32](llm, rs)
  173. if err != nil {
  174. return fmt.Errorf("failed to read tensor dimensions: %w", err)
  175. }
  176. shape := make([]uint64, dims)
  177. for i := 0; uint32(i) < dims; i++ {
  178. shape[i], err = readGGUF[uint64](llm, rs)
  179. if err != nil {
  180. return fmt.Errorf("failed to read tensor shape: %w", err)
  181. }
  182. }
  183. kind, err := readGGUF[uint32](llm, rs)
  184. if err != nil {
  185. return fmt.Errorf("failed to read tensor kind: %w", err)
  186. }
  187. offset, err := readGGUF[uint64](llm, rs)
  188. if err != nil {
  189. return fmt.Errorf("failed to read tensor offset: %w", err)
  190. }
  191. tensor := Tensor{
  192. Name: name,
  193. Kind: kind,
  194. Offset: offset,
  195. Shape: shape[:],
  196. }
  197. llm.tensors = append(llm.tensors, &tensor)
  198. llm.parameters += tensor.parameters()
  199. }
  200. // patch KV with parameter count
  201. llm.kv["general.parameter_count"] = llm.parameters
  202. alignment, ok := llm.kv["general.alignment"].(uint32)
  203. if !ok {
  204. alignment = 32
  205. }
  206. offset, err := rs.Seek(0, io.SeekCurrent)
  207. if err != nil {
  208. return err
  209. }
  210. padding := ggufPadding(offset, int64(alignment))
  211. llm.tensorOffset = uint64(offset + padding)
  212. for _, tensor := range llm.tensors {
  213. offset, err := rs.Seek(0, io.SeekCurrent)
  214. if err != nil {
  215. return fmt.Errorf("failed to get current offset: %w", err)
  216. }
  217. padding := ggufPadding(offset, int64(alignment))
  218. if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
  219. return fmt.Errorf("failed to seek to init padding: %w", err)
  220. }
  221. if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
  222. return fmt.Errorf("failed to seek to tensor: %w", err)
  223. }
  224. }
  225. return nil
  226. }
  227. func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
  228. var t T
  229. err := binary.Read(r, llm.ByteOrder, &t)
  230. return t, err
  231. }
  // writeGGUF writes one typed GGUF metadata value to w: the type tag t,
// followed by the encoding of v, both little-endian. v must be something
// binary.Write can encode.
func writeGGUF[V any](w io.Writer, t uint32, v V) error {
	err := binary.Write(w, binary.LittleEndian, t)
	if err == nil {
		err = binary.Write(w, binary.LittleEndian, v)
	}
	return err
}
  238. func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
  239. var length uint64
  240. if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
  241. return "", err
  242. }
  243. var b bytes.Buffer
  244. if _, err := io.CopyN(&b, r, int64(length)); err != nil {
  245. return "", err
  246. }
  247. // gguf v1 strings are null-terminated
  248. b.Truncate(b.Len() - 1)
  249. return b.String(), nil
  250. }
  251. func discardGGUFString(llm *gguf, r io.Reader) error {
  252. buf := llm.scratch[:8]
  253. _, err := io.ReadFull(r, buf)
  254. if err != nil {
  255. return err
  256. }
  257. size := int(llm.ByteOrder.Uint64(buf))
  258. for size > 0 {
  259. n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))])
  260. if err != nil {
  261. return err
  262. }
  263. size -= n
  264. }
  265. return nil
  266. }
  267. func readGGUFString(llm *gguf, r io.Reader) (string, error) {
  268. if llm.Version == 1 {
  269. return readGGUFV1String(llm, r)
  270. }
  271. buf := llm.scratch[:8]
  272. _, err := io.ReadFull(r, buf)
  273. if err != nil {
  274. return "", err
  275. }
  276. length := int(llm.ByteOrder.Uint64(buf))
  277. if length > len(llm.scratch) {
  278. buf = make([]byte, length)
  279. } else {
  280. buf = llm.scratch[:length]
  281. }
  282. clear(buf)
  283. _, err = io.ReadFull(r, buf)
  284. if err != nil {
  285. return "", err
  286. }
  287. return string(buf), nil
  288. }
  289. func writeGGUFString(w io.Writer, s string) error {
  290. if err := binary.Write(w, binary.LittleEndian, ggufTypeString); err != nil {
  291. return err
  292. }
  293. if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
  294. return err
  295. }
  296. _, err := io.Copy(w, strings.NewReader(s))
  297. return err
  298. }
  // array is a decoded GGUF metadata array. size is the element count read
// from the file. values holds the elements only when the decoder chose to
// retain them, and is nil otherwise.
type array struct {
	size   int
	values []any
}

// MarshalJSON encodes only the retained values. An array whose values were not
// retained (nil) encodes as JSON null; size is never serialized.
func (a *array) MarshalJSON() ([]byte, error) {
	vals := a.values
	return json.Marshal(vals)
}
  306. func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
  307. t, err := readGGUF[uint32](llm, r)
  308. if err != nil {
  309. return nil, err
  310. }
  311. n, err := readGGUF[uint32](llm, r)
  312. if err != nil {
  313. return nil, err
  314. }
  315. a := &array{size: int(n)}
  316. if llm.canCollectArray(int(n)) {
  317. a.values = make([]any, 0, int(n))
  318. }
  319. for i := range n {
  320. var e any
  321. switch t {
  322. case ggufTypeUint8:
  323. e, err = readGGUF[uint8](llm, r)
  324. case ggufTypeInt8:
  325. e, err = readGGUF[int8](llm, r)
  326. case ggufTypeUint16:
  327. e, err = readGGUF[uint16](llm, r)
  328. case ggufTypeInt16:
  329. e, err = readGGUF[int16](llm, r)
  330. case ggufTypeUint32:
  331. e, err = readGGUF[uint32](llm, r)
  332. case ggufTypeInt32:
  333. e, err = readGGUF[int32](llm, r)
  334. case ggufTypeUint64:
  335. e, err = readGGUF[uint64](llm, r)
  336. case ggufTypeInt64:
  337. e, err = readGGUF[int64](llm, r)
  338. case ggufTypeFloat32:
  339. e, err = readGGUF[float32](llm, r)
  340. case ggufTypeFloat64:
  341. e, err = readGGUF[float64](llm, r)
  342. case ggufTypeBool:
  343. e, err = readGGUF[bool](llm, r)
  344. case ggufTypeString:
  345. e, err = readGGUFV1String(llm, r)
  346. default:
  347. return nil, fmt.Errorf("invalid array type: %d", t)
  348. }
  349. if err != nil {
  350. return nil, err
  351. }
  352. if a.values != nil {
  353. a.values[i] = e
  354. }
  355. }
  356. return a, nil
  357. }
  358. func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
  359. if llm.Version == 1 {
  360. return readGGUFV1Array(llm, r)
  361. }
  362. t, err := readGGUF[uint32](llm, r)
  363. if err != nil {
  364. return nil, err
  365. }
  366. n, err := readGGUF[uint64](llm, r)
  367. if err != nil {
  368. return nil, err
  369. }
  370. a := &array{size: int(n)}
  371. if llm.canCollectArray(int(n)) {
  372. a.values = make([]any, int(n))
  373. }
  374. for i := range n {
  375. var e any
  376. switch t {
  377. case ggufTypeUint8:
  378. e, err = readGGUF[uint8](llm, r)
  379. case ggufTypeInt8:
  380. e, err = readGGUF[int8](llm, r)
  381. case ggufTypeUint16:
  382. e, err = readGGUF[uint16](llm, r)
  383. case ggufTypeInt16:
  384. e, err = readGGUF[int16](llm, r)
  385. case ggufTypeUint32:
  386. e, err = readGGUF[uint32](llm, r)
  387. case ggufTypeInt32:
  388. e, err = readGGUF[int32](llm, r)
  389. case ggufTypeUint64:
  390. e, err = readGGUF[uint64](llm, r)
  391. case ggufTypeInt64:
  392. e, err = readGGUF[int64](llm, r)
  393. case ggufTypeFloat32:
  394. e, err = readGGUF[float32](llm, r)
  395. case ggufTypeFloat64:
  396. e, err = readGGUF[float64](llm, r)
  397. case ggufTypeBool:
  398. e, err = readGGUF[bool](llm, r)
  399. case ggufTypeString:
  400. if a.values != nil {
  401. e, err = readGGUFString(llm, r)
  402. } else {
  403. err = discardGGUFString(llm, r)
  404. }
  405. default:
  406. return nil, fmt.Errorf("invalid array type: %d", t)
  407. }
  408. if err != nil {
  409. return nil, err
  410. }
  411. if a.values != nil {
  412. a.values[i] = e
  413. }
  414. }
  415. return a, nil
  416. }
  417. // writeGGUFArray writes a slice s of type E to the write with a gguf type of t
  418. func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
  419. if err := binary.Write(w, binary.LittleEndian, ggufTypeArray); err != nil {
  420. return err
  421. }
  422. if err := binary.Write(w, binary.LittleEndian, t); err != nil {
  423. return err
  424. }
  425. if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
  426. return err
  427. }
  428. return binary.Write(w, binary.LittleEndian, s)
  429. }
  430. func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor, fn func(api.ProgressResponse)) error {
  431. if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
  432. return err
  433. }
  434. if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
  435. return err
  436. }
  437. if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
  438. return err
  439. }
  440. if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
  441. return err
  442. }
  443. keys := maps.Keys(kv)
  444. slices.Sort(keys)
  445. for _, key := range keys {
  446. if err := ggufWriteKV(ws, key, kv[key]); err != nil {
  447. return err
  448. }
  449. }
  450. slices.SortStableFunc(ts, func(a, b Tensor) int {
  451. if i, j := a.block(), b.block(); i < 0 && j > 0 {
  452. return 1
  453. } else if i > 0 && j < 0 {
  454. return -1
  455. } else {
  456. return cmp.Compare(i, j)
  457. }
  458. })
  459. var s uint64
  460. for _, t := range ts {
  461. t.Offset = s
  462. if err := ggufWriteTensorInfo(ws, t); err != nil {
  463. return err
  464. }
  465. s += t.Size()
  466. }
  467. var alignment int64 = 32
  468. for i, t := range ts {
  469. fn(api.ProgressResponse{
  470. Status: fmt.Sprintf("converting model %d%%", 100*(i+1)/len(ts)),
  471. })
  472. if err := ggufWriteTensor(ws, t, alignment); err != nil {
  473. return err
  474. }
  475. }
  476. return nil
  477. }
  478. func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
  479. slog.Debug(k, "type", fmt.Sprintf("%T", v))
  480. if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
  481. return err
  482. }
  483. if err := binary.Write(ws, binary.LittleEndian, []byte(k)); err != nil {
  484. return err
  485. }
  486. var err error
  487. switch v := v.(type) {
  488. case uint32:
  489. err = writeGGUF(ws, ggufTypeUint32, v)
  490. case float32:
  491. err = writeGGUF(ws, ggufTypeFloat32, v)
  492. case bool:
  493. err = writeGGUF(ws, ggufTypeBool, v)
  494. case string:
  495. err = writeGGUFString(ws, v)
  496. case []int32:
  497. err = writeGGUFArray(ws, ggufTypeInt32, v)
  498. case []uint32:
  499. err = writeGGUFArray(ws, ggufTypeUint32, v)
  500. case []float32:
  501. err = writeGGUFArray(ws, ggufTypeFloat32, v)
  502. case []string:
  503. if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
  504. return err
  505. }
  506. if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
  507. return err
  508. }
  509. if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
  510. return err
  511. }
  512. for _, e := range v {
  513. if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
  514. return err
  515. }
  516. if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
  517. return err
  518. }
  519. }
  520. default:
  521. return fmt.Errorf("improper type for '%s'", k)
  522. }
  523. return err
  524. }
  525. func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
  526. slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
  527. if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
  528. return err
  529. }
  530. if err := binary.Write(ws, binary.LittleEndian, []byte(t.Name)); err != nil {
  531. return err
  532. }
  533. if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil {
  534. return err
  535. }
  536. for i := range len(t.Shape) {
  537. if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
  538. return err
  539. }
  540. }
  541. if err := binary.Write(ws, binary.LittleEndian, t.Kind); err != nil {
  542. return err
  543. }
  544. return binary.Write(ws, binary.LittleEndian, t.Offset)
  545. }
  546. func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
  547. offset, err := ws.Seek(0, io.SeekCurrent)
  548. if err != nil {
  549. return err
  550. }
  551. if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil {
  552. return err
  553. }
  554. _, err = t.WriteTo(ws)
  555. return err
  556. }
  // ggufPadding returns how many bytes must be added to offset to reach the
// next multiple of align, or 0 if offset is already aligned. A non-positive
// align means no alignment requirement and yields 0. Previously a zero align
// caused a division-by-zero panic and a negative align could produce negative
// padding.
func ggufPadding(offset, align int64) int64 {
	if align <= 0 {
		return 0
	}
	return (align - offset%align) % align
}