瀏覽代碼

convert: fix parse functions

Michael Yang 9 月之前
父節點
當前提交
d8e2664c33
共有 2 個文件被更改,包括 21 次插入15 次删除
  1. 12 9
      convert/reader.go
  2. 9 6
      convert/tokenizer.go

+ 12 - 9
convert/reader.go

@@ -56,22 +56,25 @@ func (t *tensorBase) SetRepacker(fn repacker) {
 type repacker func(string, []float32, []uint64) ([]float32, error)
 
 func parseTensors(fsys fs.FS) ([]Tensor, error) {
-	patterns := map[string]func(fs.FS, ...string) ([]Tensor, error){
-		"model-*-of-*.safetensors": parseSafetensors,
-		"model.safetensors":        parseSafetensors,
-		"pytorch_model-*-of-*.bin": parseTorch,
-		"pytorch_model.bin":        parseTorch,
-		"consolidated.*.pth":       parseTorch,
+	patterns := []struct {
+		Pattern string
+		Func    func(fs.FS, ...string) ([]Tensor, error)
+	}{
+		{"model-*-of-*.safetensors", parseSafetensors},
+		{"model.safetensors", parseSafetensors},
+		{"pytorch_model-*-of-*.bin", parseTorch},
+		{"pytorch_model.bin", parseTorch},
+		{"consolidated.*.pth", parseTorch},
 	}
 
-	for pattern, parseFn := range patterns {
-		matches, err := fs.Glob(fsys, pattern)
+	for _, pattern := range patterns {
+		matches, err := fs.Glob(fsys, pattern.Pattern)
 		if err != nil {
 			return nil, err
 		}
 
 		if len(matches) > 0 {
-			return parseFn(fsys, matches...)
+			return pattern.Func(fsys, matches...)
 		}
 	}
 

+ 9 - 6
convert/tokenizer.go

@@ -220,19 +220,22 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
 }
 
 func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
-	patterns := map[string]func(fs.FS) (*Vocabulary, error){
-		"tokenizer.model": parseSentencePiece,
-		"tokenizer.json":  parseVocabularyFromTokenizer,
+	patterns := []struct {
+		Pattern string
+		Func    func(fs.FS) (*Vocabulary, error)
+	}{
+		{"tokenizer.model", parseSentencePiece},
+		{"tokenizer.json", parseVocabularyFromTokenizer},
 	}
 
-	for pattern, parseFn := range patterns {
-		if _, err := fs.Stat(fsys, pattern); errors.Is(err, os.ErrNotExist) {
+	for _, pattern := range patterns {
+		if _, err := fs.Stat(fsys, pattern.Pattern); errors.Is(err, os.ErrNotExist) {
 			continue
 		} else if err != nil {
 			return nil, err
 		}
 
-		return parseFn(fsys)
+		return pattern.Func(fsys)
 	}
 
 	return nil, errors.New("unknown tensor format")