123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- package convert
- import (
- "errors"
- "io"
- "io/fs"
- "strings"
- )
- type Tensor interface {
- Name() string
- Shape() []uint64
- Kind() uint32
- SetRepacker(repacker)
- WriteTo(io.Writer) (int64, error)
- }
- type tensorBase struct {
- name string
- shape []uint64
- repacker
- }
- func (t tensorBase) Name() string {
- return t.name
- }
- func (t tensorBase) Shape() []uint64 {
- return t.shape
- }
- const (
- tensorKindF32 uint32 = iota
- tensorKindF16
- )
- func (t tensorBase) Kind() uint32 {
- if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
- t.name == "token_types.weight" {
- // these tensors are always F32
- return 0
- }
- switch len(t.shape) {
- case 0:
- panic("invalid tensor shape")
- case 1:
- return tensorKindF32
- default:
- return tensorKindF16
- }
- }
- func (t *tensorBase) SetRepacker(fn repacker) {
- t.repacker = fn
- }
- type repacker func(string, []float32, []uint64) ([]float32, error)
- func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
- patterns := []struct {
- Pattern string
- Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
- }{
- {"model-*-of-*.safetensors", parseSafetensors},
- {"model.safetensors", parseSafetensors},
- {"pytorch_model-*-of-*.bin", parseTorch},
- {"pytorch_model.bin", parseTorch},
- {"consolidated.*.pth", parseTorch},
- }
- for _, pattern := range patterns {
- matches, err := fs.Glob(fsys, pattern.Pattern)
- if err != nil {
- return nil, err
- }
- if len(matches) > 0 {
- return pattern.Func(fsys, replacer, matches...)
- }
- }
- return nil, errors.New("unknown tensor format")
- }
|