浏览代码

catch when model vocab size is set correctly (#6714)

Patrick Devine 7 月之前
父节点
当前提交
84b84ce2db
共有 1 个文件被更改,包括 7 次插入3 次删除
  1. 7 3
      convert/convert.go

+ 7 - 3
convert/convert.go

@@ -208,14 +208,18 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 		return err
 	}
 
-	if vocabSize := int(p.VocabSize); vocabSize > len(t.Vocabulary.Tokens) {
-		slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", p.VocabSize, "actual", len(t.Vocabulary.Tokens))
+	vocabSize := int(p.VocabSize)
+	switch {
+	case vocabSize > len(t.Vocabulary.Tokens):
+		slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
 		for i := range vocabSize - len(t.Vocabulary.Tokens) {
 			t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
 			t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
 			t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
 		}
-	} else {
+	case vocabSize < len(t.Vocabulary.Tokens):
+		return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
+	default:
 		slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
 	}