@@ -10,6 +10,7 @@ import (
 	"log/slog"
 	"os"
 	"slices"
+	"strings"
 
 	"golang.org/x/exp/maps"
 )
@@ -60,7 +61,25 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 		addedTokens[t.Content] = t
 	}
 
-	t.Merges = tt.Model.Merges
+	if len(tt.Model.Merges) == 0 {
+		// noop; merges is empty
+	} else if err := json.Unmarshal(tt.Model.Merges, &t.Merges); err == nil {
+		// noop; merges is []string
+	} else if merges, err := func() ([][]string, error) {
+		var merges [][]string
+		if err := json.Unmarshal(tt.Model.Merges, &merges); err != nil {
+			return nil, err
+		}
+
+		return merges, nil
+	}(); err == nil {
+		t.Merges = make([]string, len(merges))
+		for i := range merges {
+			t.Merges[i] = strings.Join(merges[i], " ")
+		}
+	} else {
+		return nil, fmt.Errorf("could not parse tokenizer merges. expected []string or [][]string: %w", err)
+	}
 
 	sha256sum := sha256.New()
 	for _, pt := range tt.PreTokenizer.PreTokenizers {
@@ -156,9 +175,9 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 type tokenizer struct {
 	AddedTokens []token `json:"added_tokens"`
 	Model       struct {
-		Type   string         `json:"type"`
-		Vocab  map[string]int `json:"vocab"`
-		Merges []string       `json:"merges"`
+		Type   string          `json:"type"`
+		Vocab  map[string]int  `json:"vocab"`
+		Merges json.RawMessage `json:"merges"`
 	} `json:"model"`
 
 	PreTokenizer struct {
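
For reference: newer releases of the Hugging Face tokenizers library appear to serialize model.merges in tokenizer.json as arrays of token pairs (e.g. [["a", "b"], ...]) rather than the older space-joined strings (["a b", ...]), which is why the patch defers decoding via json.RawMessage and tries both shapes. Below is a minimal standalone sketch of the same decode-either-shape approach; decodeMerges and the sample inputs are illustrative, not part of the patch.

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// decodeMerges tries the legacy flat form first ([]string), then the
// newer pair form ([][]string), joining each pair with a single space
// so downstream code always sees "left right" strings.
func decodeMerges(raw json.RawMessage) ([]string, error) {
	var flat []string
	if err := json.Unmarshal(raw, &flat); err == nil {
		return flat, nil
	}

	var pairs [][]string
	if err := json.Unmarshal(raw, &pairs); err != nil {
		return nil, fmt.Errorf("expected []string or [][]string: %w", err)
	}

	merges := make([]string, len(pairs))
	for i, pair := range pairs {
		merges[i] = strings.Join(pair, " ")
	}
	return merges, nil
}

func main() {
	legacy := json.RawMessage(`["h e", "he llo"]`)        // older tokenizer.json
	paired := json.RawMessage(`[["h","e"],["he","llo"]]`) // newer tokenizer.json

	// Both inputs decode to the same canonical form: ["h e", "he llo"].
	for _, raw := range []json.RawMessage{legacy, paired} {
		merges, err := decodeMerges(raw)
		fmt.Println(merges, err)
	}
}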