@@ -1,15 +1,30 @@
 package convert
 
 import (
+	"cmp"
+	"crypto/sha256"
 	"encoding/json"
-	"io/ioutil"
+	"fmt"
+	"log/slog"
 	"os"
+	"slices"
+
+	"golang.org/x/exp/maps"
 )
 
 type Tokenizer struct {
 	Version     string         `json:"version"`
 	AddedTokens []Token        `json:"added_tokens"`
 	Model       TokenizerModel `json:"model"`
+
+	PreTokenizer struct {
+		PreTokenizers []struct {
+			Type    string `json:"type"`
+			Pattern struct {
+				Regex string `json:"Regex"`
+			} `json:"pattern"`
+		} `json:"pretokenizers"`
+	} `json:"pre_tokenizer"`
 }
 
 type TokenizerModel struct {
@@ -26,47 +41,85 @@ type Token struct {
 	UserDefined bool
 }
 
-func (t *Tokenizer) getMaxID() int {
-	var maxID int
-	for _, v := range t.Model.Vocab {
-		maxID = max(maxID, v)
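+// Type maps a token to the llama.cpp token type written into the GGUF
+// metadata: 3 (control) for special tokens, 4 (user defined) for added
+// tokens, and 1 (normal) for everything else.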
+func (t *Token) Type() int32 {
+	switch {
+	case t.Special:
+		return 3
+	case t.UserDefined:
+		return 4
+	default:
+		return 1
 	}
+}
 
-	for _, v := range t.AddedTokens {
-		maxID = max(maxID, v.ID)
-	}
-	return maxID
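+// maxID returns the largest token ID across the base vocabulary and the
+// added tokens. slices.Max and slices.MaxFunc panic on empty inputs, so
+// this assumes the model has both a vocabulary and added tokens.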
+func (t *Tokenizer) maxID() int {
+	return max(
+		slices.Max(maps.Values(t.Model.Vocab)),
+		slices.MaxFunc(t.AddedTokens, func(a, b Token) int {
+			return cmp.Compare(a.ID, b.ID)
+		}).ID,
+	)
 }
 
-func newTokenizer(dirpath string) (*Tokenizer, error) {
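+// parseTokens decodes the tokenizer.json at the given path and returns
+// the detected pre-tokenizer name, the tokens laid out by ID, and the
+// BPE merge rules.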
+func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, err error) {
 	f, err := os.Open(dirpath)
 	if err != nil {
-		panic(err)
+		return "", nil, nil, err
 	}
 	defer f.Close()
 
-	data, err := ioutil.ReadAll(f)
-	if err != nil {
-		return nil, err
+	var t Tokenizer
+	if err := json.NewDecoder(f).Decode(&t); err != nil {
+		return "", nil, nil, err
 	}
 
-	var tdata Tokenizer
-
-	if err := json.Unmarshal(data, &tdata); err != nil {
-		return nil, err
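+	// Flatten the vocabulary into a dense slice indexed by token ID so
+	// later lookups are positional; added tokens overwrite base entries.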
+	tokens = make([]Token, t.maxID()+1)
+	for k, v := range t.Model.Vocab {
+		tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
 	}
 
-	maxID := tdata.getMaxID()
-	tdata.Model.Tokens = make([]Token, maxID+1)
+	for _, v := range t.AddedTokens {
+		v.UserDefined = true
+		tokens[v.ID] = v
+	}
 
-	for k, v := range tdata.Model.Vocab {
-		tdata.Model.Tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
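+	// Fingerprint the pre-tokenizer by hashing its Split regexes in
+	// order; the digest is matched below against known llama.cpp
+	// pre-tokenizer configurations.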
+	sha256sum := sha256.New()
+	for _, pt := range t.PreTokenizer.PreTokenizers {
+		switch pt.Type {
+		case "Split":
+			if pt.Pattern.Regex != "" {
+				sha256sum.Write([]byte(pt.Pattern.Regex))
+			}
+		}
 	}
 
-	for _, v := range tdata.AddedTokens {
-		v.UserDefined = true
-		tdata.Model.Tokens[v.ID] = v
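+	// Known digests map to named llama.cpp pre-tokenizers; anything
+	// unrecognized falls back to "default" with a warning.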
+	switch digest := fmt.Sprintf("%x", sha256sum.Sum(nil)); digest {
+	case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
+		pre = "llama-bpe"
+	case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
+		pre = "deepseek-llm"
+	case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
+		pre = "deepseek-coder"
+	default:
+		slog.Warn("unknown pretokenizer, using default", "digest", digest)
+		pre = "default"
 	}
 
-	return &tdata, nil
+	return pre, tokens, t.Model.Merges, nil
 }