123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- package convert
- import (
- "bytes"
- "encoding/binary"
- "encoding/json"
- "fmt"
- "io"
- "io/fs"
- "slices"
- "github.com/d4l3k/go-bfloat16"
- "github.com/x448/float16"
- "golang.org/x/exp/maps"
- )
// safetensorMetadata is the JSON description of a single entry in a
// safetensors header, as decoded by parseSafetensors.
type safetensorMetadata struct {
	// Type is the entry's "dtype" string. parseSafetensors skips entries
	// where it is empty.
	Type string `json:"dtype"`

	// Shape is the entry's "shape" array.
	Shape []uint64 `json:"shape"`

	// Offsets is the entry's "data_offsets" array. parseSafetensors reads
	// element 0 as the start and element 1 as the end of the tensor's bytes,
	// both relative to the first byte after the header.
	Offsets []int64 `json:"data_offsets"`
}
- func parseSafetensors(fsys fs.FS, ps ...string) ([]Tensor, error) {
- var ts []Tensor
- for _, p := range ps {
- f, err := fsys.Open(p)
- if err != nil {
- return nil, err
- }
- defer f.Close()
- var n int64
- if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
- return nil, err
- }
- b := bytes.NewBuffer(make([]byte, 0, n))
- if _, err = io.CopyN(b, f, n); err != nil {
- return nil, err
- }
- var headers map[string]safetensorMetadata
- if err := json.NewDecoder(b).Decode(&headers); err != nil {
- return nil, err
- }
- keys := maps.Keys(headers)
- slices.Sort(keys)
- for _, key := range keys {
- if value := headers[key]; value.Type != "" {
- ts = append(ts, safetensor{
- fs: fsys,
- path: p,
- dtype: value.Type,
- offset: safetensorsPad(n, value.Offsets[0]),
- size: safetensorsPad(n, value.Offsets[1]) - safetensorsPad(n, value.Offsets[0]),
- tensorBase: &tensorBase{
- name: key,
- shape: value.Shape,
- },
- })
- }
- }
- }
- return ts, nil
- }
// safetensorsPad turns an offset that is relative to the end of the
// safetensors header into an absolute position in the file. The file begins
// with an 8-byte length field followed by n bytes of header, so both are
// added to offset.
func safetensorsPad(n, offset int64) int64 {
	// Width in bytes of the length field that precedes the header.
	const lengthFieldSize = 8

	headerEnd := lengthFieldSize + n
	return headerEnd + offset
}
- type safetensor struct {
- fs fs.FS
- path string
- dtype string
- offset int64
- size int64
- *tensorBase
- }
- func (st safetensor) WriteTo(w io.Writer) (int64, error) {
- f, err := st.fs.Open(st.path)
- if err != nil {
- return 0, err
- }
- defer f.Close()
- if seeker, ok := f.(io.Seeker); ok {
- if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
- return 0, err
- }
- } else {
- if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
- return 0, err
- }
- }
- var f32s []float32
- switch st.dtype {
- case "F32":
- f32s = make([]float32, st.size/4)
- if err = binary.Read(f, binary.LittleEndian, f32s); err != nil {
- return 0, err
- }
- case "F16":
- u16s := make([]uint16, st.size/2)
- if err = binary.Read(f, binary.LittleEndian, u16s); err != nil {
- return 0, err
- }
- f32s = make([]float32, len(u16s))
- for i := range u16s {
- f32s[i] = float16.Frombits(u16s[i]).Float32()
- }
- case "BF16":
- u8s := make([]uint8, st.size)
- if err = binary.Read(f, binary.LittleEndian, u8s); err != nil {
- return 0, err
- }
- f32s = bfloat16.DecodeFloat32(u8s)
- default:
- return 0, fmt.Errorf("unknown data type: %s", st.dtype)
- }
- if st.repacker != nil {
- f32s, err = st.repacker(st.Name(), f32s, st.Shape())
- if err != nil {
- return 0, err
- }
- }
- switch st.Kind() {
- case tensorKindF32:
- return 0, binary.Write(w, binary.LittleEndian, f32s)
- case tensorKindF16:
- f16s := make([]uint16, len(f32s))
- for i := range f32s {
- f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
- }
- return 0, binary.Write(w, binary.LittleEndian, f16s)
- default:
- return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
- }
- }
|