123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662 |
- package llm
- import (
- "bytes"
- "cmp"
- "encoding/binary"
- "encoding/json"
- "fmt"
- "io"
- "log/slog"
- "slices"
- "strings"
- "golang.org/x/exp/maps"
- )
// containerGGUF holds the GGUF header fields read by Decode together with the
// options that control decoding.
type containerGGUF struct {
	// ByteOrder is the byte order used for every binary read.
	ByteOrder binary.ByteOrder

	// Version is the format version read from the header. It selects which
	// of V1, V2 or V3 carries the counts.
	Version uint32

	// V1 holds the 32-bit counts used when Version is 1.
	V1 struct {
		NumTensor, NumKV uint32
	}

	// V2 holds the 64-bit counts used when Version is 2.
	V2 struct {
		NumTensor, NumKV uint64
	}

	// V3 holds the 64-bit counts used for every other version.
	V3 struct {
		NumTensor, NumKV uint64
	}

	// maxArraySize is the largest array whose elements are collected while
	// decoding. A negative value disables the limit.
	maxArraySize int
}
- func (c *containerGGUF) canCollectArray(size int) bool {
- return c.maxArraySize < 0 || size <= c.maxArraySize
- }
- func (c *containerGGUF) Name() string {
- return "gguf"
- }
- func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
- if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil {
- return nil, err
- }
- var err error
- switch c.Version {
- case 1:
- err = binary.Read(rs, c.ByteOrder, &c.V1)
- case 2:
- err = binary.Read(rs, c.ByteOrder, &c.V2)
- default:
- err = binary.Read(rs, c.ByteOrder, &c.V3)
- }
- if err != nil {
- return nil, err
- }
- model := newGGUF(c)
- if err := model.Decode(rs); err != nil {
- return nil, err
- }
- return model, nil
- }
// Type tags that precede each metadata value (and each array's element type)
// in a GGUF file.
const (
	ggufTypeUint8   uint32 = 0
	ggufTypeInt8    uint32 = 1
	ggufTypeUint16  uint32 = 2
	ggufTypeInt16   uint32 = 3
	ggufTypeUint32  uint32 = 4
	ggufTypeInt32   uint32 = 5
	ggufTypeFloat32 uint32 = 6
	ggufTypeBool    uint32 = 7
	ggufTypeString  uint32 = 8
	ggufTypeArray   uint32 = 9
	ggufTypeUint64  uint32 = 10
	ggufTypeInt64   uint32 = 11
	ggufTypeFloat64 uint32 = 12
)
- type gguf struct {
- *containerGGUF
- kv KV
- tensors []*Tensor
- parameters uint64
- tensorOffset uint64
- scratch [16 << 10]byte
- }
- func newGGUF(container *containerGGUF) *gguf {
- return &gguf{
- containerGGUF: container,
- kv: make(KV),
- }
- }
- func (llm *gguf) KV() KV {
- return llm.kv
- }
- func (llm *gguf) Tensors() *Tensors {
- return &Tensors{
- Items: llm.tensors,
- Offset: llm.tensorOffset,
- }
- }
- func (llm *gguf) numTensor() uint64 {
- switch llm.Version {
- case 1:
- return uint64(llm.V1.NumTensor)
- case 2:
- return llm.V2.NumTensor
- default:
- return llm.V3.NumTensor
- }
- }
- func (llm *gguf) numKV() uint64 {
- switch llm.Version {
- case 1:
- return uint64(llm.V1.NumKV)
- case 2:
- return llm.V2.NumKV
- default:
- return llm.V3.NumKV
- }
- }
- func (llm *gguf) Decode(rs io.ReadSeeker) error {
- // decode key-values
- for i := 0; uint64(i) < llm.numKV(); i++ {
- k, err := readGGUFString(llm, rs)
- if err != nil {
- return err
- }
- t, err := readGGUF[uint32](llm, rs)
- if err != nil {
- return err
- }
- var v any
- switch t {
- case ggufTypeUint8:
- v, err = readGGUF[uint8](llm, rs)
- case ggufTypeInt8:
- v, err = readGGUF[int8](llm, rs)
- case ggufTypeUint16:
- v, err = readGGUF[uint16](llm, rs)
- case ggufTypeInt16:
- v, err = readGGUF[int16](llm, rs)
- case ggufTypeUint32:
- v, err = readGGUF[uint32](llm, rs)
- case ggufTypeInt32:
- v, err = readGGUF[int32](llm, rs)
- case ggufTypeUint64:
- v, err = readGGUF[uint64](llm, rs)
- case ggufTypeInt64:
- v, err = readGGUF[int64](llm, rs)
- case ggufTypeFloat32:
- v, err = readGGUF[float32](llm, rs)
- case ggufTypeFloat64:
- v, err = readGGUF[float64](llm, rs)
- case ggufTypeBool:
- v, err = readGGUF[bool](llm, rs)
- case ggufTypeString:
- v, err = readGGUFString(llm, rs)
- case ggufTypeArray:
- v, err = readGGUFArray(llm, rs)
- default:
- return fmt.Errorf("invalid type: %d", t)
- }
- if err != nil {
- return err
- }
- llm.kv[k] = v
- }
- // decode tensors
- for range llm.numTensor() {
- name, err := readGGUFString(llm, rs)
- if err != nil {
- return fmt.Errorf("failed to read tensor name: %w", err)
- }
- // dims is the number of dimensions in the tensor
- dims, err := readGGUF[uint32](llm, rs)
- if err != nil {
- return fmt.Errorf("failed to read tensor dimensions: %w", err)
- }
- shape := make([]uint64, dims)
- for i := 0; uint32(i) < dims; i++ {
- shape[i], err = readGGUF[uint64](llm, rs)
- if err != nil {
- return fmt.Errorf("failed to read tensor shape: %w", err)
- }
- }
- kind, err := readGGUF[uint32](llm, rs)
- if err != nil {
- return fmt.Errorf("failed to read tensor kind: %w", err)
- }
- offset, err := readGGUF[uint64](llm, rs)
- if err != nil {
- return fmt.Errorf("failed to read tensor offset: %w", err)
- }
- tensor := Tensor{
- Name: name,
- Kind: kind,
- Offset: offset,
- Shape: shape[:],
- }
- llm.tensors = append(llm.tensors, &tensor)
- llm.parameters += tensor.parameters()
- }
- // patch KV with parameter count
- llm.kv["general.parameter_count"] = llm.parameters
- alignment, ok := llm.kv["general.alignment"].(uint32)
- if !ok {
- alignment = 32
- }
- offset, err := rs.Seek(0, io.SeekCurrent)
- if err != nil {
- return err
- }
- padding := ggufPadding(offset, int64(alignment))
- llm.tensorOffset = uint64(offset + padding)
- for _, tensor := range llm.tensors {
- offset, err := rs.Seek(0, io.SeekCurrent)
- if err != nil {
- return fmt.Errorf("failed to get current offset: %w", err)
- }
- padding := ggufPadding(offset, int64(alignment))
- if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
- return fmt.Errorf("failed to seek to init padding: %w", err)
- }
- if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
- return fmt.Errorf("failed to seek to tensor: %w", err)
- }
- }
- return nil
- }
- func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
- var t T
- err := binary.Read(r, llm.ByteOrder, &t)
- return t, err
- }
// writeGGUF writes the type tag t followed by the value v to w, both in
// little-endian byte order. It stops at the first write that fails.
func writeGGUF[V any](w io.Writer, t uint32, v V) error {
	err := binary.Write(w, binary.LittleEndian, t)
	if err == nil {
		err = binary.Write(w, binary.LittleEndian, v)
	}
	return err
}
- func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
- var length uint64
- if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
- return "", err
- }
- var b bytes.Buffer
- if _, err := io.CopyN(&b, r, int64(length)); err != nil {
- return "", err
- }
- // gguf v1 strings are null-terminated
- b.Truncate(b.Len() - 1)
- return b.String(), nil
- }
- func discardGGUFString(llm *gguf, r io.Reader) error {
- buf := llm.scratch[:8]
- _, err := io.ReadFull(r, buf)
- if err != nil {
- return err
- }
- size := int(llm.ByteOrder.Uint64(buf))
- for size > 0 {
- n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))])
- if err != nil {
- return err
- }
- size -= n
- }
- return nil
- }
- func readGGUFString(llm *gguf, r io.Reader) (string, error) {
- if llm.Version == 1 {
- return readGGUFV1String(llm, r)
- }
- buf := llm.scratch[:8]
- _, err := io.ReadFull(r, buf)
- if err != nil {
- return "", err
- }
- length := int(llm.ByteOrder.Uint64(buf))
- if length > len(llm.scratch) {
- buf = make([]byte, length)
- } else {
- buf = llm.scratch[:length]
- }
- clear(buf)
- _, err = io.ReadFull(r, buf)
- if err != nil {
- return "", err
- }
- return string(buf), nil
- }
- func writeGGUFString(w io.Writer, s string) error {
- if err := binary.Write(w, binary.LittleEndian, ggufTypeString); err != nil {
- return err
- }
- if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
- return err
- }
- _, err := io.Copy(w, strings.NewReader(s))
- return err
- }
// array is an array value decoded from the metadata section.
type array struct {
	// size is the element count declared in the file.
	size int

	// values holds the decoded elements. It stays nil when the array was
	// too large to collect.
	values []any
}
- func (a *array) MarshalJSON() ([]byte, error) {
- return json.Marshal(a.values)
- }
- func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
- t, err := readGGUF[uint32](llm, r)
- if err != nil {
- return nil, err
- }
- n, err := readGGUF[uint32](llm, r)
- if err != nil {
- return nil, err
- }
- a := &array{size: int(n)}
- if llm.canCollectArray(int(n)) {
- a.values = make([]any, 0, int(n))
- }
- for i := range n {
- var e any
- switch t {
- case ggufTypeUint8:
- e, err = readGGUF[uint8](llm, r)
- case ggufTypeInt8:
- e, err = readGGUF[int8](llm, r)
- case ggufTypeUint16:
- e, err = readGGUF[uint16](llm, r)
- case ggufTypeInt16:
- e, err = readGGUF[int16](llm, r)
- case ggufTypeUint32:
- e, err = readGGUF[uint32](llm, r)
- case ggufTypeInt32:
- e, err = readGGUF[int32](llm, r)
- case ggufTypeUint64:
- e, err = readGGUF[uint64](llm, r)
- case ggufTypeInt64:
- e, err = readGGUF[int64](llm, r)
- case ggufTypeFloat32:
- e, err = readGGUF[float32](llm, r)
- case ggufTypeFloat64:
- e, err = readGGUF[float64](llm, r)
- case ggufTypeBool:
- e, err = readGGUF[bool](llm, r)
- case ggufTypeString:
- e, err = readGGUFV1String(llm, r)
- default:
- return nil, fmt.Errorf("invalid array type: %d", t)
- }
- if err != nil {
- return nil, err
- }
- if a.values != nil {
- a.values[i] = e
- }
- }
- return a, nil
- }
- func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
- if llm.Version == 1 {
- return readGGUFV1Array(llm, r)
- }
- t, err := readGGUF[uint32](llm, r)
- if err != nil {
- return nil, err
- }
- n, err := readGGUF[uint64](llm, r)
- if err != nil {
- return nil, err
- }
- a := &array{size: int(n)}
- if llm.canCollectArray(int(n)) {
- a.values = make([]any, int(n))
- }
- for i := range n {
- var e any
- switch t {
- case ggufTypeUint8:
- e, err = readGGUF[uint8](llm, r)
- case ggufTypeInt8:
- e, err = readGGUF[int8](llm, r)
- case ggufTypeUint16:
- e, err = readGGUF[uint16](llm, r)
- case ggufTypeInt16:
- e, err = readGGUF[int16](llm, r)
- case ggufTypeUint32:
- e, err = readGGUF[uint32](llm, r)
- case ggufTypeInt32:
- e, err = readGGUF[int32](llm, r)
- case ggufTypeUint64:
- e, err = readGGUF[uint64](llm, r)
- case ggufTypeInt64:
- e, err = readGGUF[int64](llm, r)
- case ggufTypeFloat32:
- e, err = readGGUF[float32](llm, r)
- case ggufTypeFloat64:
- e, err = readGGUF[float64](llm, r)
- case ggufTypeBool:
- e, err = readGGUF[bool](llm, r)
- case ggufTypeString:
- if a.values != nil {
- e, err = readGGUFString(llm, r)
- } else {
- err = discardGGUFString(llm, r)
- }
- default:
- return nil, fmt.Errorf("invalid array type: %d", t)
- }
- if err != nil {
- return nil, err
- }
- if a.values != nil {
- a.values[i] = e
- }
- }
- return a, nil
- }
- // writeGGUFArray writes a slice s of type E to the write with a gguf type of t
- func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
- if err := binary.Write(w, binary.LittleEndian, ggufTypeArray); err != nil {
- return err
- }
- if err := binary.Write(w, binary.LittleEndian, t); err != nil {
- return err
- }
- if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
- return err
- }
- return binary.Write(w, binary.LittleEndian, s)
- }
- func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
- if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
- return err
- }
- if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
- return err
- }
- if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
- return err
- }
- if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
- return err
- }
- keys := maps.Keys(kv)
- slices.Sort(keys)
- for _, key := range keys {
- if err := ggufWriteKV(ws, key, kv[key]); err != nil {
- return err
- }
- }
- slices.SortStableFunc(ts, func(a, b Tensor) int {
- if i, j := a.block(), b.block(); i < 0 && j > 0 {
- return 1
- } else if i > 0 && j < 0 {
- return -1
- } else {
- return cmp.Compare(i, j)
- }
- })
- var s uint64
- for _, t := range ts {
- t.Offset = s
- if err := ggufWriteTensorInfo(ws, t); err != nil {
- return err
- }
- s += t.Size()
- }
- var alignment int64 = 32
- for _, t := range ts {
- if err := ggufWriteTensor(ws, t, alignment); err != nil {
- return err
- }
- }
- return nil
- }
- func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
- slog.Debug(k, "type", fmt.Sprintf("%T", v))
- if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
- return err
- }
- if err := binary.Write(ws, binary.LittleEndian, []byte(k)); err != nil {
- return err
- }
- var err error
- switch v := v.(type) {
- case uint32:
- err = writeGGUF(ws, ggufTypeUint32, v)
- case float32:
- err = writeGGUF(ws, ggufTypeFloat32, v)
- case bool:
- err = writeGGUF(ws, ggufTypeBool, v)
- case string:
- err = writeGGUFString(ws, v)
- case []int32:
- err = writeGGUFArray(ws, ggufTypeInt32, v)
- case []uint32:
- err = writeGGUFArray(ws, ggufTypeUint32, v)
- case []float32:
- err = writeGGUFArray(ws, ggufTypeFloat32, v)
- case []string:
- if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
- return err
- }
- if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
- return err
- }
- if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
- return err
- }
- for _, e := range v {
- if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
- return err
- }
- if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
- return err
- }
- }
- default:
- return fmt.Errorf("improper type for '%s'", k)
- }
- return err
- }
- func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
- slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
- if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
- return err
- }
- if err := binary.Write(ws, binary.LittleEndian, []byte(t.Name)); err != nil {
- return err
- }
- if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil {
- return err
- }
- for i := range len(t.Shape) {
- if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
- return err
- }
- }
- if err := binary.Write(ws, binary.LittleEndian, t.Kind); err != nil {
- return err
- }
- return binary.Write(ws, binary.LittleEndian, t.Offset)
- }
- func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
- offset, err := ws.Seek(0, io.SeekCurrent)
- if err != nil {
- return err
- }
- if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil {
- return err
- }
- _, err = t.WriteTo(ws)
- return err
- }
// ggufPadding returns the number of bytes that must follow offset to reach the
// next multiple of align. The result is 0 when offset is already aligned.
//
// A non-positive align cannot describe an alignment; 0 is returned for it
// instead of dividing by zero.
func ggufPadding(offset, align int64) int64 {
	if align <= 0 {
		return 0
	}
	return (align - offset%align) % align
}
|