reader_torch.go 818 B

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. package convert
import (
	"fmt"
	"io"
	"io/fs"

	"github.com/nlpodyssey/gopickle/pytorch"
	"github.com/nlpodyssey/gopickle/types"
)
  8. func parseTorch(fsys fs.FS, ps ...string) ([]Tensor, error) {
  9. var ts []Tensor
  10. for _, p := range ps {
  11. pt, err := pytorch.Load(p)
  12. if err != nil {
  13. return nil, err
  14. }
  15. for _, k := range pt.(*types.Dict).Keys() {
  16. t := pt.(*types.Dict).MustGet(k)
  17. var shape []uint64
  18. for dim := range t.(*pytorch.Tensor).Size {
  19. shape = append(shape, uint64(dim))
  20. }
  21. ts = append(ts, torch{
  22. storage: t.(*pytorch.Tensor).Source,
  23. tensorBase: &tensorBase{
  24. name: k.(string),
  25. shape: shape,
  26. },
  27. })
  28. }
  29. }
  30. return ts, nil
  31. }
  32. type torch struct {
  33. storage pytorch.StorageInterface
  34. *tensorBase
  35. }
  36. func (pt torch) WriteTo(w io.Writer) (int64, error) {
  37. return 0, nil
  38. }