reader_torch.go 875 B

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. package convert
import (
	"errors"
	"fmt"
	"io"
	"io/fs"
	"strings"

	"github.com/nlpodyssey/gopickle/pytorch"
	"github.com/nlpodyssey/gopickle/types"
)
  9. func parseTorch(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]Tensor, error) {
  10. var ts []Tensor
  11. for _, p := range ps {
  12. pt, err := pytorch.Load(p)
  13. if err != nil {
  14. return nil, err
  15. }
  16. for _, k := range pt.(*types.Dict).Keys() {
  17. t := pt.(*types.Dict).MustGet(k)
  18. var shape []uint64
  19. for dim := range t.(*pytorch.Tensor).Size {
  20. shape = append(shape, uint64(dim))
  21. }
  22. ts = append(ts, torch{
  23. storage: t.(*pytorch.Tensor).Source,
  24. tensorBase: &tensorBase{
  25. name: replacer.Replace(k.(string)),
  26. shape: shape,
  27. },
  28. })
  29. }
  30. }
  31. return ts, nil
  32. }
  33. type torch struct {
  34. storage pytorch.StorageInterface
  35. *tensorBase
  36. }
  37. func (pt torch) WriteTo(w io.Writer) (int64, error) {
  38. return 0, nil
  39. }