reader_safetensors.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. package convert
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "io/fs"
  10. "slices"
  11. "strings"
  12. "github.com/d4l3k/go-bfloat16"
  13. "github.com/x448/float16"
  14. "golang.org/x/exp/maps"
  15. )
  16. type safetensorMetadata struct {
  17. Type string `json:"dtype"`
  18. Shape []uint64 `json:"shape"`
  19. Offsets []int64 `json:"data_offsets"`
  20. }
  21. func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]Tensor, error) {
  22. var ts []Tensor
  23. for _, p := range ps {
  24. f, err := fsys.Open(p)
  25. if err != nil {
  26. return nil, err
  27. }
  28. defer f.Close()
  29. var n int64
  30. if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
  31. return nil, err
  32. }
  33. b := bytes.NewBuffer(make([]byte, 0, n))
  34. if _, err = io.CopyN(b, f, n); err != nil {
  35. return nil, err
  36. }
  37. var headers map[string]safetensorMetadata
  38. if err := json.NewDecoder(b).Decode(&headers); err != nil {
  39. return nil, err
  40. }
  41. keys := maps.Keys(headers)
  42. slices.Sort(keys)
  43. for _, key := range keys {
  44. if value := headers[key]; value.Type != "" {
  45. // bitsandbytes quantized models are unsupported
  46. if len(value.Shape) == 0 {
  47. return nil, errors.New("unsupported safetensors model")
  48. }
  49. ts = append(ts, safetensor{
  50. fs: fsys,
  51. path: p,
  52. dtype: value.Type,
  53. offset: safetensorsPad(n, value.Offsets[0]),
  54. size: safetensorsPad(n, value.Offsets[1]) - safetensorsPad(n, value.Offsets[0]),
  55. tensorBase: &tensorBase{
  56. name: replacer.Replace(key),
  57. shape: value.Shape,
  58. },
  59. })
  60. }
  61. }
  62. }
  63. return ts, nil
  64. }
  65. // safetensorsPad returns the padded size of the safetensors file given a length n and offset s
  66. func safetensorsPad(n, offset int64) int64 {
  67. return 8 + n + offset
  68. }
  69. type safetensor struct {
  70. fs fs.FS
  71. path string
  72. dtype string
  73. offset int64
  74. size int64
  75. *tensorBase
  76. }
  77. func (st safetensor) WriteTo(w io.Writer) (int64, error) {
  78. f, err := st.fs.Open(st.path)
  79. if err != nil {
  80. return 0, err
  81. }
  82. defer f.Close()
  83. if seeker, ok := f.(io.Seeker); ok {
  84. if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
  85. return 0, err
  86. }
  87. } else {
  88. if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
  89. return 0, err
  90. }
  91. }
  92. var f32s []float32
  93. switch st.dtype {
  94. case "F32":
  95. f32s = make([]float32, st.size/4)
  96. if err = binary.Read(f, binary.LittleEndian, f32s); err != nil {
  97. return 0, err
  98. }
  99. case "F16":
  100. u16s := make([]uint16, st.size/2)
  101. if err = binary.Read(f, binary.LittleEndian, u16s); err != nil {
  102. return 0, err
  103. }
  104. f32s = make([]float32, len(u16s))
  105. for i := range u16s {
  106. f32s[i] = float16.Frombits(u16s[i]).Float32()
  107. }
  108. case "BF16":
  109. u8s := make([]uint8, st.size)
  110. if err = binary.Read(f, binary.LittleEndian, u8s); err != nil {
  111. return 0, err
  112. }
  113. f32s = bfloat16.DecodeFloat32(u8s)
  114. default:
  115. return 0, fmt.Errorf("unknown data type: %s", st.dtype)
  116. }
  117. if st.repacker != nil {
  118. f32s, err = st.repacker(st.Name(), f32s, st.Shape())
  119. if err != nil {
  120. return 0, err
  121. }
  122. }
  123. switch st.Kind() {
  124. case tensorKindF32:
  125. return 0, binary.Write(w, binary.LittleEndian, f32s)
  126. case tensorKindF16:
  127. f16s := make([]uint16, len(f32s))
  128. for i := range f32s {
  129. f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
  130. }
  131. return 0, binary.Write(w, binary.LittleEndian, f16s)
  132. default:
  133. return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
  134. }
  135. }