|
@@ -208,14 +208,18 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
- if vocabSize := int(p.VocabSize); vocabSize > len(t.Vocabulary.Tokens) {
|
|
|
- slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", p.VocabSize, "actual", len(t.Vocabulary.Tokens))
|
|
|
+ vocabSize := int(p.VocabSize)
|
|
|
+ switch {
|
|
|
+ case vocabSize > len(t.Vocabulary.Tokens):
|
|
|
+ slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
|
|
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
|
|
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
|
|
|
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
|
|
|
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
|
|
|
}
|
|
|
- } else {
|
|
|
+ case vocabSize < len(t.Vocabulary.Tokens):
|
|
|
+ return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
|
|
|
+ default:
|
|
|
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
|
|
}
|
|
|
|