Przeglądaj źródła

convert: fix parse functions

Michael Yang 9 miesięcy temu
rodzic
commit
d8e2664c33
2 zmienionych plików z 21 dodań i 15 usunięć
  1. 12 9
      convert/reader.go
  2. 9 6
      convert/tokenizer.go

+ 12 - 9
convert/reader.go

@@ -56,22 +56,25 @@ func (t *tensorBase) SetRepacker(fn repacker) {
 type repacker func(string, []float32, []uint64) ([]float32, error)
 
 func parseTensors(fsys fs.FS) ([]Tensor, error) {
-	patterns := map[string]func(fs.FS, ...string) ([]Tensor, error){
-		"model-*-of-*.safetensors": parseSafetensors,
-		"model.safetensors":        parseSafetensors,
-		"pytorch_model-*-of-*.bin": parseTorch,
-		"pytorch_model.bin":        parseTorch,
-		"consolidated.*.pth":       parseTorch,
+	patterns := []struct {
+		Pattern string
+		Func    func(fs.FS, ...string) ([]Tensor, error)
+	}{
+		{"model-*-of-*.safetensors", parseSafetensors},
+		{"model.safetensors", parseSafetensors},
+		{"pytorch_model-*-of-*.bin", parseTorch},
+		{"pytorch_model.bin", parseTorch},
+		{"consolidated.*.pth", parseTorch},
 	}
 
-	for pattern, parseFn := range patterns {
-		matches, err := fs.Glob(fsys, pattern)
+	for _, pattern := range patterns {
+		matches, err := fs.Glob(fsys, pattern.Pattern)
 		if err != nil {
 			return nil, err
 		}
 
 		if len(matches) > 0 {
-			return parseFn(fsys, matches...)
+			return pattern.Func(fsys, matches...)
 		}
 	}
 

+ 9 - 6
convert/tokenizer.go

@@ -220,19 +220,22 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
 }
 
 func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
-	patterns := map[string]func(fs.FS) (*Vocabulary, error){
-		"tokenizer.model": parseSentencePiece,
-		"tokenizer.json":  parseVocabularyFromTokenizer,
+	patterns := []struct {
+		Pattern string
+		Func    func(fs.FS) (*Vocabulary, error)
+	}{
+		{"tokenizer.model", parseSentencePiece},
+		{"tokenizer.json", parseVocabularyFromTokenizer},
 	}
 
-	for pattern, parseFn := range patterns {
-		if _, err := fs.Stat(fsys, pattern); errors.Is(err, os.ErrNotExist) {
+	for _, pattern := range patterns {
+		if _, err := fs.Stat(fsys, pattern.Pattern); errors.Is(err, os.ErrNotExist) {
 			continue
 		} else if err != nil {
 			return nil, err
 		}
 
-		return parseFn(fsys)
+		return pattern.Func(fsys)
 	}
 
 	return nil, errors.New("unknown tensor format")