reader_safetensors.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. package convert
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "io/fs"
  9. "slices"
  10. "github.com/d4l3k/go-bfloat16"
  11. "github.com/x448/float16"
  12. "golang.org/x/exp/maps"
  13. )
  14. type safetensorMetadata struct {
  15. Type string `json:"dtype"`
  16. Shape []uint64 `json:"shape"`
  17. Offsets []int64 `json:"data_offsets"`
  18. }
  19. func parseSafetensors(fsys fs.FS, ps ...string) ([]Tensor, error) {
  20. var ts []Tensor
  21. for _, p := range ps {
  22. f, err := fsys.Open(p)
  23. if err != nil {
  24. return nil, err
  25. }
  26. defer f.Close()
  27. var n int64
  28. if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
  29. return nil, err
  30. }
  31. b := bytes.NewBuffer(make([]byte, 0, n))
  32. if _, err = io.CopyN(b, f, n); err != nil {
  33. return nil, err
  34. }
  35. var headers map[string]safetensorMetadata
  36. if err := json.NewDecoder(b).Decode(&headers); err != nil {
  37. return nil, err
  38. }
  39. keys := maps.Keys(headers)
  40. slices.Sort(keys)
  41. for _, key := range keys {
  42. if value := headers[key]; value.Type != "" {
  43. ts = append(ts, safetensor{
  44. fs: fsys,
  45. path: p,
  46. dtype: value.Type,
  47. offset: safetensorsPad(n, value.Offsets[0]),
  48. size: safetensorsPad(n, value.Offsets[1]) - safetensorsPad(n, value.Offsets[0]),
  49. tensorBase: &tensorBase{
  50. name: key,
  51. shape: value.Shape,
  52. },
  53. })
  54. }
  55. }
  56. }
  57. return ts, nil
  58. }
  59. // safetensorsPad returns the padded size of the safetensors file given a length n and offset s
  60. func safetensorsPad(n, offset int64) int64 {
  61. return 8 + n + offset
  62. }
  63. type safetensor struct {
  64. fs fs.FS
  65. path string
  66. dtype string
  67. offset int64
  68. size int64
  69. *tensorBase
  70. }
  71. func (st safetensor) WriteTo(w io.Writer) (int64, error) {
  72. f, err := st.fs.Open(st.path)
  73. if err != nil {
  74. return 0, err
  75. }
  76. defer f.Close()
  77. if seeker, ok := f.(io.Seeker); ok {
  78. if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
  79. return 0, err
  80. }
  81. } else {
  82. if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
  83. return 0, err
  84. }
  85. }
  86. var f32s []float32
  87. switch st.dtype {
  88. case "F32":
  89. f32s = make([]float32, st.size/4)
  90. if err = binary.Read(f, binary.LittleEndian, f32s); err != nil {
  91. return 0, err
  92. }
  93. case "F16":
  94. u16s := make([]uint16, st.size/2)
  95. if err = binary.Read(f, binary.LittleEndian, u16s); err != nil {
  96. return 0, err
  97. }
  98. f32s = make([]float32, len(u16s))
  99. for i := range u16s {
  100. f32s[i] = float16.Frombits(u16s[i]).Float32()
  101. }
  102. case "BF16":
  103. u8s := make([]uint8, st.size)
  104. if err = binary.Read(f, binary.LittleEndian, u8s); err != nil {
  105. return 0, err
  106. }
  107. f32s = bfloat16.DecodeFloat32(u8s)
  108. default:
  109. return 0, fmt.Errorf("unknown data type: %s", st.dtype)
  110. }
  111. if st.repacker != nil {
  112. f32s, err = st.repacker(st.Name(), f32s, st.Shape())
  113. if err != nil {
  114. return 0, err
  115. }
  116. }
  117. switch st.Kind() {
  118. case tensorKindF32:
  119. return 0, binary.Write(w, binary.LittleEndian, f32s)
  120. case tensorKindF16:
  121. f16s := make([]uint16, len(f32s))
  122. for i := range f32s {
  123. f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
  124. }
  125. return 0, binary.Write(w, binary.LittleEndian, f16s)
  126. default:
  127. return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
  128. }
  129. }