// tokenizer.go

package convert

import (
	"cmp"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"slices"
)
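
// Token types assigned to vocabulary entries. The leading blank identifier
// skips zero so tokenTypeNormal starts at 1.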
const (
	_ int32 = iota
	tokenTypeNormal
	tokenTypeUnknown
	tokenTypeControl
	tokenTypeUserDefined
	tokenTypeUnused
	tokenTypeByte
)
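
// Tokenizer collects everything parsed from a model directory: the base
// vocabulary plus its special tokens, the merge rules, the pretokenizer
// preset, and the chat template.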
type Tokenizer struct {
	*Vocabulary
	SpecialVocabulary []*SpecialVocabulary
	Merges            []string

	Pre      string
	Template string
}
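
// parseTokenizer reads tokenizer.json and tokenizer_config.json from the model
// directory d and builds the combined tokenizer metadata. specialTypes names
// the special token kinds to resolve, for example:
//
//	t, err := parseTokenizer("/path/to/model", []string{"bos", "eos", "unk", "pad"})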
func parseTokenizer(d string, specialTypes []string) (*Tokenizer, error) {
	v, err := parseVocabulary(d)
	if err != nil {
		return nil, err
	}

	t := &Tokenizer{
		Vocabulary: v,
		Pre:        "default",
	}

	addedTokens := make(map[string]token)
	if f, err := os.Open(filepath.Join(d, "tokenizer.json")); errors.Is(err, os.ErrNotExist) {
		// tokenizer.json is optional; keep the defaults
	} else if err != nil {
		return nil, err
	} else {
		defer f.Close()

		var tt tokenizer
		if err := json.NewDecoder(f).Decode(&tt); err != nil {
			return nil, err
		}

		for _, t := range tt.AddedTokens {
			addedTokens[t.Content] = t
		}

		t.Merges = tt.Model.Merges

		// identify the pretokenizer preset by hashing its Split regexes
		sha256sum := sha256.New()
		for _, pt := range tt.PreTokenizer.PreTokenizers {
			switch pt.Type {
			case "Split":
				if pt.Pattern.Regex != "" {
					sha256sum.Write([]byte(pt.Pattern.Regex))
				}
			}
		}

		switch digest := hex.EncodeToString(sha256sum.Sum(nil)); digest {
		case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
			t.Pre = "llama-bpe"
		case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
			t.Pre = "deepseek-llm"
		case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
			t.Pre = "deepseek-coder"
		case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
			// noop, empty pretokenizer
		default:
			slog.Warn("unknown pretokenizer, using default", "digest", digest)
		}
	}

	if f, err := os.Open(filepath.Join(d, "tokenizer_config.json")); errors.Is(err, os.ErrNotExist) {
		// tokenizer_config.json is optional; keep the defaults
	} else if err != nil {
		return nil, err
	} else {
		defer f.Close()

		var p map[string]json.RawMessage
		if err := json.NewDecoder(f).Decode(&p); err != nil {
			return nil, err
		}

		if template, ok := p["chat_template"]; ok {
			if err := json.Unmarshal(template, &t.Template); err != nil {
				return nil, err
			}
		}

		for _, st := range specialTypes {
			sv := SpecialVocabulary{Type: st}
			if bts, ok := p[fmt.Sprintf("add_%s_token", st)]; ok {
				if err := json.Unmarshal(bts, &sv.AddToken); err != nil {
					return nil, err
				}
			}

			if bts, ok := p[fmt.Sprintf("%s_token", st)]; ok {
				var content string
				if err := json.Unmarshal(bts, &content); err != nil {
					// the token value may be an object with a "content" field
					// instead of a plain string
					var mm map[string]any
					if err := json.Unmarshal(bts, &mm); err != nil {
						continue
					}

					content, ok = mm["content"].(string)
					if !ok {
						continue
					}
				}

				sv.Content = content
			}

			// only keep specials that resolve to a known added token ID
			if id, ok := addedTokens[sv.Content]; ok {
				sv.ID = id.ID
				t.SpecialVocabulary = append(t.SpecialVocabulary, &sv)
			}
		}
	}

	return t, nil
}
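
// tokenizer mirrors the subset of the tokenizer.json schema this package
// reads: added tokens, the model's vocab and merges, and the pretokenizer
// chain.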
type tokenizer struct {
	Version     string  `json:"version"`
	AddedTokens []token `json:"added_tokens"`

	Model struct {
		Type   string         `json:"type"`
		Vocab  map[string]int `json:"vocab"`
		Merges []string       `json:"merges"`
	} `json:"model"`

	PreTokenizer struct {
		PreTokenizers []struct {
			Type    string `json:"type"`
			Pattern struct {
				Regex string `json:"Regex"`
			} `json:"pattern"`
		} `json:"pretokenizers"`
	} `json:"pre_tokenizer"`
}
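
// token is a single tokenizer.json entry. UserDefined has no JSON tag; it is
// set in code for entries taken from the added_tokens list.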
type token struct {
	ID          int    `json:"id"`
	Content     string `json:"content"`
	Special     bool   `json:"special"`
	UserDefined bool
}
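
// Vocabulary is the flattened token list in ID order, with parallel slices
// holding each token's score and type.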
type Vocabulary struct {
	Model  string
	Tokens []string
	Scores []float32
	Types  []int32
}
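
// parseVocabularyFromTokenizer builds a Vocabulary from tokenizer.json in
// directory p. The base vocab and added tokens are merged, sorted by ID, and
// each token's ID doubles as its score.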
func parseVocabularyFromTokenizer(p string) (*Vocabulary, error) {
	f, err := os.Open(filepath.Join(p, "tokenizer.json"))
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var t tokenizer
	if err := json.NewDecoder(f).Decode(&t); err != nil {
		return nil, err
	}

	var tokens []token
	for k, v := range t.Model.Vocab {
		tokens = append(tokens, token{
			ID:      v,
			Content: k,
		})
	}

	for _, t := range t.AddedTokens {
		t.UserDefined = true
		tokens = append(tokens, t)
	}

	slices.SortFunc(tokens, func(i, j token) int {
		return cmp.Compare(i.ID, j.ID)
	})

	v := Vocabulary{Model: "gpt2"}
	for _, t := range tokens {
		v.Tokens = append(v.Tokens, t.Content)
		v.Scores = append(v.Scores, float32(t.ID))

		switch {
		case t.Special:
			v.Types = append(v.Types, tokenTypeControl)
		case t.UserDefined:
			v.Types = append(v.Types, tokenTypeUserDefined)
		default:
			v.Types = append(v.Types, tokenTypeNormal)
		}
	}

	return &v, nil
}
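
// parseVocabulary picks a parser based on which tokenizer file exists in the
// model directory: tokenizer.model (SentencePiece) or tokenizer.json.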
func parseVocabulary(d string) (*Vocabulary, error) {
	patterns := map[string]func(string) (*Vocabulary, error){
		"tokenizer.model": parseSentencePiece,
		"tokenizer.json":  parseVocabularyFromTokenizer,
	}

	for pattern, parseFn := range patterns {
		matches, err := filepath.Glob(filepath.Join(d, pattern))
		if err != nil {
			return nil, err
		}

		if len(matches) > 0 {
			return parseFn(d)
		}
	}

	return nil, errors.New("unknown tokenizer format")
}
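
// SpecialVocabulary describes one special token (bos, eos, unk, sep, pad, cls,
// or mask): its ID and content, and whether the tokenizer adds it
// automatically.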
type SpecialVocabulary struct {
	Type     string
	ID       int
	Content  string
	AddToken bool
}
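
// Key returns the canonical name for the special token type, preserving the
// upstream "seperator" spelling.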
func (sv SpecialVocabulary) Key() string {
	switch t := sv.Type; t {
	case "bos", "eos", "cls", "mask":
		return t
	case "unk":
		return "unknown"
	case "sep":
		//nolint:misspell // this is an upstream typo
		return "seperator"
	case "pad":
		return "padding"
	}

	panic("unknown special vocabulary type")
}