reader.go 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. package convert
  2. import (
  3. "errors"
  4. "io"
  5. "io/fs"
  6. "strings"
  7. )
  8. type Tensor interface {
  9. Name() string
  10. Shape() []uint64
  11. Kind() uint32
  12. SetRepacker(repacker)
  13. WriteTo(io.Writer) (int64, error)
  14. }
  // tensorBase carries the fields and accessors shared by tensor
  // implementations: a name, a shape, and an optional repacker.
  type tensorBase struct {
  	name  string
  	shape []uint64
  	repacker
  }

  // Name returns the tensor's name.
  func (t tensorBase) Name() string {
  	return t.name
  }

  // Shape returns the tensor's dimensions. The returned slice is the one
  // stored in the tensor, not a copy.
  func (t tensorBase) Shape() []uint64 {
  	return t.shape
  }

  // Element kind codes returned by Kind.
  const (
  	tensorKindF32 uint32 = iota
  	tensorKindF16
  )

  // Kind picks the element kind from the tensor's name and rank:
  //
  //   - names ending in ".block_sparse_moe.gate.weight" are always
  //     tensorKindF32, whatever their shape;
  //   - one-dimensional tensors are tensorKindF32;
  //   - tensors with two or more dimensions are tensorKindF16.
  //
  // It panics if the shape is empty (and the name is not a gate weight),
  // because the uint32 return value leaves no room to report an error.
  func (t tensorBase) Kind() uint32 {
  	if strings.HasSuffix(t.name, ".block_sparse_moe.gate.weight") {
  		// Named constant rather than a bare 0, matching the branches below.
  		return tensorKindF32
  	}

  	switch len(t.shape) {
  	case 0:
  		panic("invalid tensor shape")
  	case 1:
  		return tensorKindF32
  	default:
  		return tensorKindF16
  	}
  }

  // SetRepacker stores fn as the tensor's repacker. It needs a pointer
  // receiver because it mutates the tensor.
  func (t *tensorBase) SetRepacker(fn repacker) {
  	t.repacker = fn
  }

  // repacker is a hook that takes a string, a float32 slice and a uint64
  // slice and returns a float32 slice or an error.
  // NOTE(review): the arguments are presumably the tensor name, its data and
  // its shape — confirm against the call sites.
  type repacker func(string, []float32, []uint64) ([]float32, error)
  47. func parseTensors(fsys fs.FS) ([]Tensor, error) {
  48. patterns := []struct {
  49. Pattern string
  50. Func func(fs.FS, ...string) ([]Tensor, error)
  51. }{
  52. {"model-*-of-*.safetensors", parseSafetensors},
  53. {"model.safetensors", parseSafetensors},
  54. {"pytorch_model-*-of-*.bin", parseTorch},
  55. {"pytorch_model.bin", parseTorch},
  56. {"consolidated.*.pth", parseTorch},
  57. }
  58. for _, pattern := range patterns {
  59. matches, err := fs.Glob(fsys, pattern.Pattern)
  60. if err != nil {
  61. return nil, err
  62. }
  63. if len(matches) > 0 {
  64. return pattern.Func(fsys, matches...)
  65. }
  66. }
  67. return nil, errors.New("unknown tensor format")
  68. }