tokenizer.go

package convert

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"log/slog"
	"os"
	"slices"

	"golang.org/x/exp/maps"
)
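
// Token type values recorded for each vocabulary entry. The leading blank
// identifier keeps the constants 1-based (normal = 1 ... byte = 6).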
const (
	_ int32 = iota
	tokenTypeNormal
	tokenTypeUnknown
	tokenTypeControl
	tokenTypeUserDefined
	tokenTypeUnused
	tokenTypeByte
)
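
// Tokenizer collects everything parsed from a model's tokenizer files: the
// base vocabulary, special tokens, BPE merges, the pretokenizer preset, and
// the chat template.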
type Tokenizer struct {
	*Vocabulary
	SpecialVocabulary []*SpecialVocabulary
	Merges            []string

	Pre      string
	Template string
}
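
// parseTokenizer assembles a Tokenizer from fsys. The vocabulary is parsed
// first; tokenizer.json and tokenizer_config.json are then consulted if they
// exist. specialTokenTypes lists the special token kinds (e.g. "bos", "eos")
// to look up in tokenizer_config.json.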
func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) {
	v, err := parseVocabulary(fsys)
	if err != nil {
		return nil, err
	}

	t := &Tokenizer{
		Vocabulary: v,
		Pre:        "default",
	}
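
	// tokenizer.json, when present, supplies the added tokens, the BPE merges,
	// and the Split pretokenizer regexes used to select a pretokenizer preset.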
	addedTokens := make(map[string]token)
	if f, err := fsys.Open("tokenizer.json"); errors.Is(err, os.ErrNotExist) {
	} else if err != nil {
		return nil, err
	} else {
		defer f.Close()

		var tt tokenizer
		if err := json.NewDecoder(f).Decode(&tt); err != nil {
			return nil, err
		}

		for _, t := range tt.AddedTokens {
			addedTokens[t.Content] = t
		}

		t.Merges = tt.Model.Merges

		sha256sum := sha256.New()
		for _, pt := range tt.PreTokenizer.PreTokenizers {
			switch pt.Type {
			case "Split":
				if pt.Pattern.Regex != "" {
					// create a checksum of all Split pretokenizers which should be sufficient
					// to identify the pretokenizer
					sha256sum.Write([]byte(pt.Pattern.Regex))
				}
			}
		}

		switch digest := hex.EncodeToString(sha256sum.Sum(nil)); digest {
		case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
			t.Pre = "llama-bpe"
		case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
			t.Pre = "deepseek-llm"
		case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
			t.Pre = "deepseek-coder"
		case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
			// noop, empty pretokenizer
		default:
			slog.Warn("unknown pretokenizer, using default", "digest", digest)
		}
	}
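
	// tokenizer_config.json, when present, supplies the chat template and the
	// per-type special token settings.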
	if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
	} else if err != nil {
		return nil, err
	} else {
		defer f.Close()

		var p map[string]json.RawMessage
		if err := json.NewDecoder(f).Decode(&p); err != nil {
			return nil, err
		}

		if template, ok := p["chat_template"]; ok {
			if err := json.Unmarshal(template, &t.Template); err != nil {
				return nil, err
			}
		}

		for _, st := range specialTokenTypes {
			sv := SpecialVocabulary{Type: st}
			if bts, ok := p[fmt.Sprintf("add_%s_token", st)]; ok {
				if err := json.Unmarshal(bts, &sv.AddToken); err != nil {
					return nil, err
				}
			}

			if bts, ok := p[fmt.Sprintf("%s_token", st)]; ok {
				var content string
				if err := json.Unmarshal(bts, &content); err != nil {
					var mm map[string]any
					if err := json.Unmarshal(bts, &mm); err != nil {
						continue
					}

					content, ok = mm["content"].(string)
					if !ok {
						continue
					}
				}

				sv.Content = content
			}

			if id, ok := addedTokens[sv.Content]; ok {
				sv.ID = id.ID
				t.SpecialVocabulary = append(t.SpecialVocabulary, &sv)
			}
		}
	}

	return t, nil
}
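
// tokenizer mirrors the subset of the tokenizer.json schema used here.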
type tokenizer struct {
	Version     string  `json:"version"`
	AddedTokens []token `json:"added_tokens"`

	Model struct {
		Type   string         `json:"type"`
		Vocab  map[string]int `json:"vocab"`
		Merges []string       `json:"merges"`
	} `json:"model"`

	PreTokenizer struct {
		PreTokenizers []struct {
			Type    string `json:"type"`
			Pattern struct {
				Regex string `json:"Regex"`
			} `json:"pattern"`
		} `json:"pretokenizers"`
	} `json:"pre_tokenizer"`
}
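
// token is a single vocabulary entry. UserDefined is not part of the JSON;
// it marks entries that came from added_tokens.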
type token struct {
	ID          int    `json:"id"`
	Content     string `json:"content"`
	Special     bool   `json:"special"`
	UserDefined bool
}
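
// Vocabulary holds the flattened vocabulary as parallel slices of token text,
// scores, and token types, plus the tokenizer model name.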
type Vocabulary struct {
	Model  string
	Tokens []string
	Scores []float32
	Types  []int32
}
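
// parseVocabularyFromTokenizer reads tokenizer.json and flattens its vocab
// and added tokens into a Vocabulary ordered by token id. Added tokens
// overwrite base entries with the same id and are marked user-defined.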
func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
	f, err := fsys.Open("tokenizer.json")
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var t tokenizer
	if err := json.NewDecoder(f).Decode(&t); err != nil {
		return nil, err
	}

	tokens := make(map[int]token, len(t.Model.Vocab))
	for k, v := range t.Model.Vocab {
		tokens[v] = token{
			ID:      v,
			Content: k,
		}
	}

	for _, token := range t.AddedTokens {
		token.UserDefined = true
		tokens[token.ID] = token
	}

	keys := maps.Keys(tokens)
	slices.Sort(keys)

	v := Vocabulary{Model: "gpt2"}
	for _, k := range keys {
		token := tokens[k]
		v.Tokens = append(v.Tokens, token.Content)
		v.Scores = append(v.Scores, float32(token.ID))

		switch {
		case token.Special:
			v.Types = append(v.Types, tokenTypeControl)
		case token.UserDefined:
			v.Types = append(v.Types, tokenTypeUserDefined)
		default:
			v.Types = append(v.Types, tokenTypeNormal)
		}
	}

	return &v, nil
}
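
// parseVocabulary selects a parser based on which tokenizer file is present,
// preferring tokenizer.model (SentencePiece) over tokenizer.json (BPE).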
func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
	patterns := []struct {
		Pattern string
		Func    func(fs.FS) (*Vocabulary, error)
	}{
		{"tokenizer.model", parseSentencePiece},
		{"tokenizer.json", parseVocabularyFromTokenizer},
	}

	for _, pattern := range patterns {
		if _, err := fs.Stat(fsys, pattern.Pattern); errors.Is(err, os.ErrNotExist) {
			continue
		} else if err != nil {
			return nil, err
		}

		return pattern.Func(fsys)
	}

	return nil, errors.New("unknown tokenizer format")
}
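
// SpecialVocabulary describes one special token and whether it should be
// added automatically.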
type SpecialVocabulary struct {
	Type     string
	ID       int
	Content  string
	AddToken bool
}
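
// Key maps the special token type to the name used for it in the output.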
func (sv SpecialVocabulary) Key() string {
	switch t := sv.Type; t {
	case "bos", "eos", "cls", "mask":
		return t
	case "unk":
		return "unknown"
	case "sep":
		//nolint:misspell // this is an upstream typo
		return "seperator"
	case "pad":
		return "padding"
	}

	panic("unknown special vocabulary type")
}