|
@@ -56,22 +56,25 @@ func (t *tensorBase) SetRepacker(fn repacker) {
|
|
type repacker func(string, []float32, []uint64) ([]float32, error)
|
|
type repacker func(string, []float32, []uint64) ([]float32, error)
|
|
|
|
|
|
func parseTensors(fsys fs.FS) ([]Tensor, error) {
|
|
func parseTensors(fsys fs.FS) ([]Tensor, error) {
|
|
- patterns := map[string]func(fs.FS, ...string) ([]Tensor, error){
|
|
|
|
- "model-*-of-*.safetensors": parseSafetensors,
|
|
|
|
- "model.safetensors": parseSafetensors,
|
|
|
|
- "pytorch_model-*-of-*.bin": parseTorch,
|
|
|
|
- "pytorch_model.bin": parseTorch,
|
|
|
|
- "consolidated.*.pth": parseTorch,
|
|
|
|
|
|
+ patterns := []struct {
|
|
|
|
+ Pattern string
|
|
|
|
+ Func func(fs.FS, ...string) ([]Tensor, error)
|
|
|
|
+ }{
|
|
|
|
+ {"model-*-of-*.safetensors", parseSafetensors},
|
|
|
|
+ {"model.safetensors", parseSafetensors},
|
|
|
|
+ {"pytorch_model-*-of-*.bin", parseTorch},
|
|
|
|
+ {"pytorch_model.bin", parseTorch},
|
|
|
|
+ {"consolidated.*.pth", parseTorch},
|
|
}
|
|
}
|
|
|
|
|
|
- for pattern, parseFn := range patterns {
|
|
|
|
- matches, err := fs.Glob(fsys, pattern)
|
|
|
|
|
|
+ for _, pattern := range patterns {
|
|
|
|
+ matches, err := fs.Glob(fsys, pattern.Pattern)
|
|
if err != nil {
|
|
if err != nil {
|
|
return nil, err
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
|
|
if len(matches) > 0 {
|
|
if len(matches) > 0 {
|
|
- return parseFn(fsys, matches...)
|
|
|
|
|
|
+ return pattern.Func(fsys, matches...)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|