tokenizer.go

package convert

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"log/slog"
	"os"
	"slices"

	"golang.org/x/exp/maps"
)
// token types recorded in Vocabulary.Types
const (
	_ int32 = iota
	tokenTypeNormal
	tokenTypeUnknown
	tokenTypeControl
	tokenTypeUserDefined
	tokenTypeUnused
	tokenTypeByte
)
// Tokenizer holds everything parsed from the model's tokenizer files: the
// vocabulary, BPE merges, pretokenizer id, and chat template.
type Tokenizer struct {
	*Vocabulary
	SpecialVocabulary []*SpecialVocabulary
	Merges            []string

	Pre      string
	Template string
}
func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) {
	v, err := parseVocabulary(fsys)
	if err != nil {
		return nil, err
	}

	t := &Tokenizer{
		Vocabulary: v,
		Pre:        "default",
	}

	addedTokens := make(map[string]token)

	// tokenizer.json, if present, supplies merges, added tokens, and the pretokenizer
	if f, err := fsys.Open("tokenizer.json"); errors.Is(err, os.ErrNotExist) {
	} else if err != nil {
		return nil, err
	} else {
		defer f.Close()

		var tt tokenizer
		if err := json.NewDecoder(f).Decode(&tt); err != nil {
			return nil, err
		}

		for _, t := range tt.AddedTokens {
			addedTokens[t.Content] = t
		}

		t.Merges = tt.Model.Merges

		sha256sum := sha256.New()
		for _, pt := range tt.PreTokenizer.PreTokenizers {
			switch pt.Type {
			case "Split":
				if pt.Pattern.Regex != "" {
					// create a checksum of all Split pretokenizers which should be sufficient
					// to identify the pretokenizer
					sha256sum.Write([]byte(pt.Pattern.Regex))
				}
			}
		}

		switch digest := hex.EncodeToString(sha256sum.Sum(nil)); digest {
		case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
			t.Pre = "llama-bpe"
		case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
			t.Pre = "deepseek-llm"
		case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
			t.Pre = "deepseek-coder"
		case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
			// noop, empty pretokenizer
		default:
			slog.Warn("unknown pretokenizer, using default", "digest", digest)
		}
	}

	// tokenizer_config.json, if present, supplies the chat template and special tokens
	if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
	} else if err != nil {
		return nil, err
	} else {
		defer f.Close()

		var p map[string]json.RawMessage
		if err := json.NewDecoder(f).Decode(&p); err != nil {
			return nil, err
		}

		// chat_template is either a single string or a list of named templates;
		// prefer the "default" entry in the latter case
		if template, ok := p["chat_template"]; ok {
			var s []struct {
				Name     string `json:"name"`
				Template string `json:"template"`
			}
			if err := json.Unmarshal(template, &t.Template); err == nil {
				// noop
			} else if err := json.Unmarshal(template, &s); err == nil {
				for _, e := range s {
					if e.Name == "default" {
						t.Template = e.Template
						break
					}
				}
			} else {
				return nil, fmt.Errorf("invalid chat_template: %w", err)
			}
		}

		// collect special tokens (e.g. bos, eos) along with whether they
		// should be added automatically
		for _, st := range specialTokenTypes {
			sv := SpecialVocabulary{Type: st}
			if bts, ok := p[fmt.Sprintf("add_%s_token", st)]; ok {
				if err := json.Unmarshal(bts, &sv.AddToken); err != nil {
					return nil, err
				}
			}

			if bts, ok := p[fmt.Sprintf("%s_token", st)]; ok {
				var content string
				if err := json.Unmarshal(bts, &content); err != nil {
					// the token may also be an object with a "content" field
					var mm map[string]any
					if err := json.Unmarshal(bts, &mm); err != nil {
						continue
					}

					content, ok = mm["content"].(string)
					if !ok {
						continue
					}
				}

				sv.Content = content
			}

			// only keep special tokens that resolve to a known added token id
			if id, ok := addedTokens[sv.Content]; ok {
				sv.ID = id.ID
				t.SpecialVocabulary = append(t.SpecialVocabulary, &sv)
			}
		}
	}

	return t, nil
}
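// Illustrative usage sketch (not part of the original file; the directory
// path and the list of special token types below are assumptions):
//
//	t, err := parseTokenizer(os.DirFS("/path/to/model"), []string{"bos", "eos", "unk", "pad"})
//	if err != nil {
//		// handle error
//	}
//	_ = t.Pre      // pretokenizer id, e.g. "default" or "llama-bpe"
//	_ = t.Template // chat template from tokenizer_config.json, if any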
type tokenizer struct {
	AddedTokens []token `json:"added_tokens"`
	Model       struct {
		Type   string         `json:"type"`
		Vocab  map[string]int `json:"vocab"`
		Merges []string       `json:"merges"`
	} `json:"model"`
	PreTokenizer struct {
		PreTokenizers []struct {
			Type    string `json:"type"`
			Pattern struct {
				Regex string `json:"Regex"`
			} `json:"pattern"`
		} `json:"pretokenizers"`
	} `json:"pre_tokenizer"`
}
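// For orientation, the struct above decodes tokenizer.json documents of
// roughly this shape (abridged sketch; the values are illustrative):
//
//	{
//	  "added_tokens": [{"id": 2, "content": "<|endoftext|>", "special": true}],
//	  "model": {"type": "BPE", "vocab": {"hello": 0}, "merges": ["h e"]},
//	  "pre_tokenizer": {"pretokenizers": [{"type": "Split", "pattern": {"Regex": "..."}}]}
//	}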
type token struct {
	ID          int    `json:"id"`
	Content     string `json:"content"`
	Special     bool   `json:"special"`
	UserDefined bool
}
type Vocabulary struct {
	Model  string
	Tokens []string
	Scores []float32
	Types  []int32
}
// parseVocabularyFromTokenizer builds a gpt2-style vocabulary from tokenizer.json.
func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
	f, err := fsys.Open("tokenizer.json")
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var t tokenizer
	if err := json.NewDecoder(f).Decode(&t); err != nil {
		return nil, err
	}

	tokens := make(map[int]token, len(t.Model.Vocab))
	for k, v := range t.Model.Vocab {
		tokens[v] = token{
			ID:      v,
			Content: k,
		}
	}

	// added tokens extend or override the base vocab and are marked user defined
	for _, token := range t.AddedTokens {
		token.UserDefined = true
		tokens[token.ID] = token
	}

	// emit tokens in id order; scores mirror the ids
	keys := maps.Keys(tokens)
	slices.Sort(keys)

	v := Vocabulary{Model: "gpt2"}
	for _, k := range keys {
		token := tokens[k]
		v.Tokens = append(v.Tokens, token.Content)
		v.Scores = append(v.Scores, float32(token.ID))

		switch {
		case token.Special:
			v.Types = append(v.Types, tokenTypeControl)
		case token.UserDefined:
			v.Types = append(v.Types, tokenTypeUserDefined)
		default:
			v.Types = append(v.Types, tokenTypeNormal)
		}
	}

	return &v, nil
}
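// Worked sketch (illustrative values): a vocab of {"hello": 0, "world": 1}
// plus an added token {"id": 2, "content": "<eos>", "special": true} yields
//
//	Tokens: ["hello", "world", "<eos>"]
//	Scores: [0, 1, 2]  // scores mirror the token ids
//	Types:  [tokenTypeNormal, tokenTypeNormal, tokenTypeControl]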
// parseVocabulary probes for a supported tokenizer file, preferring
// tokenizer.model (SentencePiece) over tokenizer.json.
func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
	patterns := []struct {
		Pattern string
		Func    func(fs.FS) (*Vocabulary, error)
	}{
		{"tokenizer.model", parseSentencePiece},
		{"tokenizer.json", parseVocabularyFromTokenizer},
	}

	for _, pattern := range patterns {
		if _, err := fs.Stat(fsys, pattern.Pattern); errors.Is(err, os.ErrNotExist) {
			continue
		} else if err != nil {
			return nil, err
		}

		return pattern.Func(fsys)
	}

	return nil, errors.New("unknown tokenizer format")
}
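// Note: when both files are present, tokenizer.model (SentencePiece) wins
// because it is listed first above. A typical call (the path is an
// assumption) looks like:
//
//	v, err := parseVocabulary(os.DirFS("/path/to/model"))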
type SpecialVocabulary struct {
	Type     string
	ID       int
	Content  string
	AddToken bool
}
func (sv SpecialVocabulary) Key() string {
	switch t := sv.Type; t {
	case "bos", "eos", "cls", "mask":
		return t
	case "unk":
		return "unknown"
	case "sep":
		//nolint:misspell // this is an upstream typo
		return "seperator"
	case "pad":
		return "padding"
	}

	panic("unknown special vocabulary type")
}
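// Example (sketch): SpecialVocabulary{Type: "unk"}.Key() returns "unknown",
// and Type "sep" returns "seperator", keeping the upstream misspelling noted
// above so downstream consumers see the key they expect.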