12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- package convert
- import (
- "errors"
- "io"
- "io/fs"
- "strings"
- )
- type Tensor interface {
- Name() string
- Shape() []uint64
- Kind() uint32
- SetRepacker(repacker)
- WriteTo(io.Writer) (int64, error)
- }
// tensorBase carries the name and shape of a tensor together with an optional
// repacker, and provides the Name, Shape, Kind and SetRepacker methods.
type tensorBase struct {
	name  string
	shape []uint64

	// repacker is embedded. Its zero value is nil; SetRepacker assigns it.
	repacker
}

// Name returns the tensor's name.
func (t tensorBase) Name() string {
	return t.name
}

// Shape returns the tensor's dimensions. The returned slice is the one stored
// in the tensor, not a copy.
func (t tensorBase) Shape() []uint64 {
	return t.shape
}

// Kind codes returned by tensorBase.Kind.
const (
	// tensorKindF32 is kind code 0.
	tensorKindF32 uint32 = iota
	// tensorKindF16 is kind code 1.
	tensorKindF16
)

// Kind returns the kind code for the tensor, derived from its name and from
// the number of dimensions in its shape:
//
//   - a name ending in ".block_sparse_moe.gate.weight" is always tensorKindF32;
//   - otherwise a one-dimensional tensor is tensorKindF32;
//   - otherwise a tensor with two or more dimensions is tensorKindF16.
//
// Kind panics with "invalid tensor shape" when the shape is empty. The name
// check runs first, so a matching name never reaches that panic.
func (t tensorBase) Kind() uint32 {
	if strings.HasSuffix(t.name, ".block_sparse_moe.gate.weight") {
		return tensorKindF32
	}

	switch len(t.shape) {
	case 0:
		panic("invalid tensor shape")
	case 1:
		return tensorKindF32
	default:
		return tensorKindF16
	}
}

// SetRepacker stores fn as the tensor's repacker, replacing any previous one.
// It needs a pointer receiver because it mutates the tensor.
func (t *tensorBase) SetRepacker(fn repacker) {
	t.repacker = fn
}

// repacker is a function that receives a string, a []float32 and a []uint64
// and returns a []float32 or an error.
// NOTE(review): presumably the arguments are the tensor name, its data and
// its shape, and the result is the rearranged data; confirm against the
// implementations.
type repacker func(string, []float32, []uint64) ([]float32, error)
- func parseTensors(fsys fs.FS) ([]Tensor, error) {
- patterns := []struct {
- Pattern string
- Func func(fs.FS, ...string) ([]Tensor, error)
- }{
- {"model-*-of-*.safetensors", parseSafetensors},
- {"model.safetensors", parseSafetensors},
- {"pytorch_model-*-of-*.bin", parseTorch},
- {"pytorch_model.bin", parseTorch},
- {"consolidated.*.pth", parseTorch},
- }
- for _, pattern := range patterns {
- matches, err := fs.Glob(fsys, pattern.Pattern)
- if err != nil {
- return nil, err
- }
- if len(matches) > 0 {
- return pattern.Func(fsys, matches...)
- }
- }
- return nil, errors.New("unknown tensor format")
- }
|