reader.go 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package convert
  2. import (
  3. "errors"
  4. "io"
  5. "io/fs"
  6. "strings"
  7. )
  8. type Tensor interface {
  9. Name() string
  10. Shape() []uint64
  11. Kind() uint32
  12. SetRepacker(repacker)
  13. WriteTo(io.Writer) (int64, error)
  14. }
  15. type tensorBase struct {
  16. name string
  17. shape []uint64
  18. repacker
  19. }
  20. func (t tensorBase) Name() string {
  21. return t.name
  22. }
  23. func (t tensorBase) Shape() []uint64 {
  24. return t.shape
  25. }
  // Tensor element kinds returned by tensorBase.Kind.
  const (
  	tensorKindF32 uint32 = 0 // 32-bit float
  	tensorKindF16 uint32 = 1 // 16-bit float
  )
  30. func (t tensorBase) Kind() uint32 {
  31. if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
  32. t.name == "token_types.weight" {
  33. // these tensors are always F32
  34. return 0
  35. }
  36. switch len(t.shape) {
  37. case 0:
  38. panic("invalid tensor shape")
  39. case 1:
  40. return tensorKindF32
  41. default:
  42. return tensorKindF16
  43. }
  44. }
  45. func (t *tensorBase) SetRepacker(fn repacker) {
  46. t.repacker = fn
  47. }
  // repacker is a hook that takes a string, a []float32 and a []uint64 and
  // returns a replacement []float32 or an error. NOTE(review): presumably
  // these are the tensor name, its data and its shape (the []uint64 matches
  // tensorBase.shape) — confirm against the callers that invoke it.
  // Parameter names here are documentation only and do not affect type
  // identity.
  type repacker func(name string, data []float32, shape []uint64) ([]float32, error)
  49. func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
  50. patterns := []struct {
  51. Pattern string
  52. Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
  53. }{
  54. {"model-*-of-*.safetensors", parseSafetensors},
  55. {"model.safetensors", parseSafetensors},
  56. {"pytorch_model-*-of-*.bin", parseTorch},
  57. {"pytorch_model.bin", parseTorch},
  58. {"consolidated.*.pth", parseTorch},
  59. }
  60. for _, pattern := range patterns {
  61. matches, err := fs.Glob(fsys, pattern.Pattern)
  62. if err != nil {
  63. return nil, err
  64. }
  65. if len(matches) > 0 {
  66. return pattern.Func(fsys, replacer, matches...)
  67. }
  68. }
  69. return nil, errors.New("unknown tensor format")
  70. }