reader.go 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package convert
  2. import (
  3. "errors"
  4. "io"
  5. "io/fs"
  6. "strings"
  7. )
  8. type Tensor interface {
  9. Name() string
  10. Shape() []uint64
  11. Kind() uint32
  12. SetRepacker(repacker)
  13. WriteTo(io.Writer) (int64, error)
  14. }
  15. type tensorBase struct {
  16. name string
  17. shape []uint64
  18. repacker
  19. }
  20. func (t tensorBase) Name() string {
  21. return t.name
  22. }
  23. func (t tensorBase) Shape() []uint64 {
  24. return t.shape
  25. }
  26. const (
  27. tensorKindF32 uint32 = iota
  28. tensorKindF16
  29. )
  30. func (t tensorBase) Kind() uint32 {
  31. if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
  32. t.name == "token_types.weight" {
  33. // these tensors are always F32
  34. return 0
  35. }
  36. switch len(t.shape) {
  37. case 0:
  38. panic("invalid tensor shape")
  39. case 1:
  40. return tensorKindF32
  41. default:
  42. return tensorKindF16
  43. }
  44. }
  45. func (t *tensorBase) SetRepacker(fn repacker) {
  46. t.repacker = fn
  47. }
  48. type repacker func(string, []float32, []uint64) ([]float32, error)
  49. func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
  50. patterns := []struct {
  51. Pattern string
  52. Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
  53. }{
  54. {"model-*-of-*.safetensors", parseSafetensors},
  55. {"model.safetensors", parseSafetensors},
  56. {"adapters.safetensors", parseSafetensors},
  57. {"adapter_model.safetensors", parseSafetensors},
  58. {"pytorch_model-*-of-*.bin", parseTorch},
  59. {"pytorch_model.bin", parseTorch},
  60. {"consolidated.*.pth", parseTorch},
  61. }
  62. for _, pattern := range patterns {
  63. matches, err := fs.Glob(fsys, pattern.Pattern)
  64. if err != nil {
  65. return nil, err
  66. }
  67. if len(matches) > 0 {
  68. return pattern.Func(fsys, replacer, matches...)
  69. }
  70. }
  71. return nil, errors.New("unknown tensor format")
  72. }