tokenizer.go

package convert

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"log/slog"
	"os"
	"slices"

	"golang.org/x/exp/maps"
)
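
// Token type values recorded in Vocabulary.Types for each token. The zero
// value is skipped so an unset type is never a valid one.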
const (
	_ int32 = iota
	tokenTypeNormal
	tokenTypeUnknown
	tokenTypeControl
	tokenTypeUserDefined
	tokenTypeUnused
	tokenTypeByte
)
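
// Tokenizer combines the parsed vocabulary with the special tokens, BPE
// merges, pretokenizer name, and chat template read from a Hugging Face
// tokenizer directory.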
type Tokenizer struct {
	*Vocabulary
	SpecialVocabulary []*SpecialVocabulary
	Merges            []string

	Pre      string
	Template string
}
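
// parseTokenizer reads the vocabulary plus, when present, tokenizer.json and
// tokenizer_config.json from fsys and assembles a Tokenizer. specialTokenTypes
// lists the special token types (e.g. "bos", "eos") to resolve from the config.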
func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) {
	v, err := parseVocabulary(fsys)
	if err != nil {
		return nil, err
	}

	t := &Tokenizer{
		Vocabulary: v,
		Pre:        "default",
	}
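
	// tokenizer.json, when present, supplies the added tokens, BPE merges, and
	// the pretokenizer regexes used to identify a known pretokenizer.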
	addedTokens := make(map[string]token)
	if f, err := fsys.Open("tokenizer.json"); errors.Is(err, os.ErrNotExist) {
	} else if err != nil {
		return nil, err
	} else {
		defer f.Close()

		var tt tokenizer
		if err := json.NewDecoder(f).Decode(&tt); err != nil {
			return nil, err
		}

		for _, t := range tt.AddedTokens {
			addedTokens[t.Content] = t
		}

		t.Merges = tt.Model.Merges

		sha256sum := sha256.New()
		for _, pt := range tt.PreTokenizer.PreTokenizers {
			switch pt.Type {
			case "Split":
				if pt.Pattern.Regex != "" {
					// create a checksum of all Split pretokenizers which should be sufficient
					// to identify the pretokenizer
					sha256sum.Write([]byte(pt.Pattern.Regex))
				}
			}
		}

		switch digest := hex.EncodeToString(sha256sum.Sum(nil)); digest {
		case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
			t.Pre = "llama-bpe"
		case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
			t.Pre = "deepseek-llm"
		case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
			t.Pre = "deepseek-coder"
		case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
			// noop, empty pretokenizer
		default:
			slog.Warn("unknown pretokenizer, using default", "digest", digest)
		}
	}
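
	// tokenizer_config.json, when present, supplies the chat template and the
	// special tokens requested via specialTokenTypes.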
	if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
	} else if err != nil {
		return nil, err
	} else {
		defer f.Close()

		var p map[string]json.RawMessage
		if err := json.NewDecoder(f).Decode(&p); err != nil {
			return nil, err
		}

		if template, ok := p["chat_template"]; ok {
			var s []struct {
				Name     string `json:"name"`
				Template string `json:"template"`
			}
			if err := json.Unmarshal(template, &t.Template); err == nil {
				// noop
			} else if err := json.Unmarshal(template, &s); err == nil {
				for _, e := range s {
					if e.Name == "default" {
						t.Template = e.Template
						break
					}
				}
			} else {
				return nil, fmt.Errorf("invalid chat_template: %w", err)
			}
		}
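
		// Resolve each requested special token type to its content and add flag,
		// then look up its id among the tokens added in tokenizer.json.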
		for _, st := range specialTokenTypes {
			sv := SpecialVocabulary{Type: st}
			if bts, ok := p[fmt.Sprintf("add_%s_token", st)]; ok {
				if err := json.Unmarshal(bts, &sv.AddToken); err != nil {
					return nil, err
				}
			}

			if bts, ok := p[fmt.Sprintf("%s_token", st)]; ok {
				var content string
				if err := json.Unmarshal(bts, &content); err != nil {
					var mm map[string]any
					if err := json.Unmarshal(bts, &mm); err != nil {
						continue
					}

					content, ok = mm["content"].(string)
					if !ok {
						continue
					}
				}

				sv.Content = content
			}

			if id, ok := addedTokens[sv.Content]; ok {
				sv.ID = id.ID
				t.SpecialVocabulary = append(t.SpecialVocabulary, &sv)
			}
		}
	}

	return t, nil
}
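
// tokenizer mirrors the subset of the Hugging Face tokenizer.json schema that
// the converter needs.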
type tokenizer struct {
	Version     string  `json:"version"`
	AddedTokens []token `json:"added_tokens"`

	Model struct {
		Type   string         `json:"type"`
		Vocab  map[string]int `json:"vocab"`
		Merges []string       `json:"merges"`
	} `json:"model"`

	PreTokenizer struct {
		PreTokenizers []struct {
			Type    string `json:"type"`
			Pattern struct {
				Regex string `json:"Regex"`
			} `json:"pattern"`
		} `json:"pretokenizers"`
	} `json:"pre_tokenizer"`
}
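
// token is a single vocabulary entry; UserDefined is set for entries that come
// from added_tokens rather than the base vocabulary.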
type token struct {
	ID          int    `json:"id"`
	Content     string `json:"content"`
	Special     bool   `json:"special"`
	UserDefined bool
}
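
// Vocabulary holds the flattened token list along with per-token scores and
// types, ordered by token id.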
type Vocabulary struct {
	Model  string
	Tokens []string
	Scores []float32
	Types  []int32
}
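
// parseVocabularyFromTokenizer builds a "gpt2"-style vocabulary from
// tokenizer.json, ordering tokens by id, using each token's id as its score,
// and marking special and user-defined tokens with the appropriate token type.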
func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
	f, err := fsys.Open("tokenizer.json")
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var t tokenizer
	if err := json.NewDecoder(f).Decode(&t); err != nil {
		return nil, err
	}

	tokens := make(map[int]token, len(t.Model.Vocab))
	for k, v := range t.Model.Vocab {
		tokens[v] = token{
			ID:      v,
			Content: k,
		}
	}

	for _, token := range t.AddedTokens {
		token.UserDefined = true
		tokens[token.ID] = token
	}

	keys := maps.Keys(tokens)
	slices.Sort(keys)

	v := Vocabulary{Model: "gpt2"}
	for _, k := range keys {
		token := tokens[k]
		v.Tokens = append(v.Tokens, token.Content)
		v.Scores = append(v.Scores, float32(token.ID))

		switch {
		case token.Special:
			v.Types = append(v.Types, tokenTypeControl)
		case token.UserDefined:
			v.Types = append(v.Types, tokenTypeUserDefined)
		default:
			v.Types = append(v.Types, tokenTypeNormal)
		}
	}

	return &v, nil
}
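
// parseVocabulary picks the vocabulary parser based on which tokenizer file is
// present: tokenizer.model (SentencePiece) takes precedence over tokenizer.json.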
func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
	patterns := []struct {
		Pattern string
		Func    func(fs.FS) (*Vocabulary, error)
	}{
		{"tokenizer.model", parseSentencePiece},
		{"tokenizer.json", parseVocabularyFromTokenizer},
	}

	for _, pattern := range patterns {
		if _, err := fs.Stat(fsys, pattern.Pattern); errors.Is(err, os.ErrNotExist) {
			continue
		} else if err != nil {
			return nil, err
		}

		return pattern.Func(fsys)
	}

	return nil, errors.New("unknown tokenizer format")
}
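
// SpecialVocabulary describes one special token: its type (bos, eos, ...), id,
// content, and the add_<type>_token flag from the tokenizer config.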
type SpecialVocabulary struct {
	Type     string
	ID       int
	Content  string
	AddToken bool
}
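
// Key maps the special token type to its full key name (e.g. "unk" to
// "unknown"), preserving the upstream "seperator" spelling for "sep".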
func (sv SpecialVocabulary) Key() string {
	switch t := sv.Type; t {
	case "bos", "eos", "cls", "mask":
		return t
	case "unk":
		return "unknown"
	case "sep":
		//nolint:misspell // this is an upstream typo
		return "seperator"
	case "pad":
		return "padding"
	}

	panic("unknown special vocabulary type")
}