reader_safetensors.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package convert
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "io/fs"
  10. "slices"
  11. "strings"
  12. "github.com/d4l3k/go-bfloat16"
  13. "github.com/x448/float16"
  14. "golang.org/x/exp/maps"
  15. )
  16. type safetensorMetadata struct {
  17. Type string `json:"dtype"`
  18. Shape []uint64 `json:"shape"`
  19. Offsets []int64 `json:"data_offsets"`
  20. }
  21. func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]Tensor, error) {
  22. var ts []Tensor
  23. for _, p := range ps {
  24. f, err := fsys.Open(p)
  25. if err != nil {
  26. return nil, err
  27. }
  28. defer f.Close()
  29. var n int64
  30. if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
  31. return nil, err
  32. }
  33. b := bytes.NewBuffer(make([]byte, 0, n))
  34. if _, err = io.CopyN(b, f, n); err != nil {
  35. return nil, err
  36. }
  37. var headers map[string]safetensorMetadata
  38. if err := json.NewDecoder(b).Decode(&headers); err != nil {
  39. return nil, err
  40. }
  41. keys := maps.Keys(headers)
  42. slices.Sort(keys)
  43. names := make(map[string]struct{}, len(keys))
  44. for _, key := range keys {
  45. if value := headers[key]; value.Type != "" {
  46. // bitsandbytes quantized models are unsupported
  47. if len(value.Shape) == 0 {
  48. return nil, errors.New("unsupported safetensors model")
  49. }
  50. ggufName := replacer.Replace(key)
  51. if _, ok := names[ggufName]; ok {
  52. return nil, fmt.Errorf("duplicate tensor name '%s' was found for this model", ggufName)
  53. }
  54. names[ggufName] = struct{}{}
  55. ts = append(ts, safetensor{
  56. fs: fsys,
  57. path: p,
  58. dtype: value.Type,
  59. offset: safetensorsPad(n, value.Offsets[0]),
  60. size: safetensorsPad(n, value.Offsets[1]) - safetensorsPad(n, value.Offsets[0]),
  61. tensorBase: &tensorBase{
  62. name: ggufName,
  63. shape: value.Shape,
  64. },
  65. })
  66. }
  67. }
  68. }
  69. return ts, nil
  70. }
  71. // safetensorsPad returns the padded size of the safetensors file given a length n and offset s
  72. func safetensorsPad(n, offset int64) int64 {
  73. return 8 + n + offset
  74. }
  75. type safetensor struct {
  76. fs fs.FS
  77. path string
  78. dtype string
  79. offset int64
  80. size int64
  81. *tensorBase
  82. }
  83. func (st safetensor) WriteTo(w io.Writer) (int64, error) {
  84. f, err := st.fs.Open(st.path)
  85. if err != nil {
  86. return 0, err
  87. }
  88. defer f.Close()
  89. if seeker, ok := f.(io.Seeker); ok {
  90. if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
  91. return 0, err
  92. }
  93. } else {
  94. if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
  95. return 0, err
  96. }
  97. }
  98. var f32s []float32
  99. switch st.dtype {
  100. case "F32":
  101. f32s = make([]float32, st.size/4)
  102. if err = binary.Read(f, binary.LittleEndian, f32s); err != nil {
  103. return 0, err
  104. }
  105. case "F16":
  106. u16s := make([]uint16, st.size/2)
  107. if err = binary.Read(f, binary.LittleEndian, u16s); err != nil {
  108. return 0, err
  109. }
  110. f32s = make([]float32, len(u16s))
  111. for i := range u16s {
  112. f32s[i] = float16.Frombits(u16s[i]).Float32()
  113. }
  114. case "BF16":
  115. u8s := make([]uint8, st.size)
  116. if err = binary.Read(f, binary.LittleEndian, u8s); err != nil {
  117. return 0, err
  118. }
  119. f32s = bfloat16.DecodeFloat32(u8s)
  120. default:
  121. return 0, fmt.Errorf("unknown data type: %s", st.dtype)
  122. }
  123. if st.repacker != nil {
  124. f32s, err = st.repacker(st.Name(), f32s, st.Shape())
  125. if err != nil {
  126. return 0, err
  127. }
  128. }
  129. switch st.Kind() {
  130. case tensorKindF32:
  131. return 0, binary.Write(w, binary.LittleEndian, f32s)
  132. case tensorKindF16:
  133. f16s := make([]uint16, len(f32s))
  134. for i := range f32s {
  135. f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
  136. }
  137. return 0, binary.Write(w, binary.LittleEndian, f16s)
  138. default:
  139. return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
  140. }
  141. }