tokenizer_spm.go

package convert

import (
	"cmp"
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"log/slog"
	"os"
	"reflect"
	"slices"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
)
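// parseSentencePiece reads a SentencePiece tokenizer.model protobuf from fsys
// and converts it into a Vocabulary, folding in any added or special tokens
// found alongside it (added_tokens.json, special_tokens_map.json). A minimal
// usage sketch, assuming fsys is a directory containing tokenizer.model:
//
//	vocab, err := parseSentencePiece(os.DirFS("/path/to/model"))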
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
	slog.Debug("using spm vocabulary")

	ast, err := parseAdditionalSpecialTokens(fsys)
	if err != nil {
		return nil, err
	}

	bts, err := fs.ReadFile(fsys, "tokenizer.model")
	if err != nil {
		return nil, err
	}

	var spm sentencepiece.ModelProto
	if err := proto.Unmarshal(bts, &spm); err != nil {
		return nil, err
	}

	v := Vocabulary{Model: "llama"}
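	// Copy every piece's text and score into the vocabulary, preserving the
	// token type for non-normal pieces. Normal pieces that also appear in
	// additional_special_tokens are promoted to CONTROL.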
	for _, piece := range spm.GetPieces() {
		v.Tokens = append(v.Tokens, piece.GetPiece())
		v.Scores = append(v.Scores, piece.GetScore())

		switch t := piece.GetType(); t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
			v.Types = append(v.Types, int32(t))
		default:
			tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
			for _, t := range ast {
				if t.Content == piece.GetPiece() {
					tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
					break
				}
			}

			v.Types = append(v.Types, tt)
		}
	}
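	// added_tokens.json, when present, lists extra tokens that extend the
	// base vocabulary; its absence is not an error.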
	f, err := fsys.Open("added_tokens.json")
	if errors.Is(err, os.ErrNotExist) {
		return &v, nil
	} else if err != nil {
		return nil, err
	}
	defer f.Close()

	var atm map[string]int
	if err := json.NewDecoder(f).Decode(&atm); err != nil {
		return nil, err
	}
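	// Sort the added tokens by id so they can be validated and appended in
	// order.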
	type t struct {
		id      int
		content string
	}

	var ts []t
	for content, id := range atm {
		ts = append(ts, t{id, content})
	}

	slices.SortFunc(ts, func(i, j t) int {
		return cmp.Compare(i.id, j.id)
	})
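	// Each added token must either duplicate an existing token at the same
	// id or extend the vocabulary contiguously at the next free id.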
	for _, t := range ts {
		if t.id < len(v.Tokens) {
			if v.Tokens[t.id] == t.content {
				slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
				continue
			}

			return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
		}

		if t.id != len(v.Tokens) {
			return nil, fmt.Errorf("invalid token id: [%d] at pos [%d]", t.id, len(v.Tokens))
		}

		v.Tokens = append(v.Tokens, t.content)
		v.Scores = append(v.Scores, -1000.0)
		v.Types = append(v.Types, tokenTypeUserDefined)
	}

	return &v, nil
}
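// specialToken mirrors the token object layout used in HuggingFace
// special_tokens_map.json files.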
type specialToken struct {
	Content    string `json:"content"`
	Lstrip     bool   `json:"lstrip"`
	Normalized bool   `json:"normalized"`
	Rstrip     bool   `json:"rstrip"`
	SingleWord bool   `json:"single_word"`
}
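// parseAdditionalSpecialTokens reads special_tokens_map.json, if present, and
// returns its additional_special_tokens entries.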
func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
	f, err := fsys.Open("special_tokens_map.json")
	if errors.Is(err, os.ErrNotExist) {
		return nil, nil
	} else if err != nil {
		return nil, err
	}
	defer f.Close()

	var m struct {
		AdditionalSpecialTokens any `json:"additional_special_tokens"`
	}

	if err := json.NewDecoder(f).Decode(&m); err != nil {
		return nil, err
	}
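	// additional_special_tokens may be either a list of plain strings or a
	// list of token objects; handle both shapes.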
	var ast []specialToken
	switch st := m.AdditionalSpecialTokens.(type) {
	case []string:
		for _, s := range st {
			ast = append(ast, specialToken{Content: s})
		}
	case []any:
		for _, s := range st {
			switch s := s.(type) {
			case string:
				// a bare string is shorthand for a token with default flags
				ast = append(ast, specialToken{Content: s})
			case map[string]any:
				// marshal and unmarshal the object to decode it into a specialToken
				data, err := json.Marshal(s)
				if err != nil {
					return nil, err
				}

				var token specialToken
				if err := json.Unmarshal(data, &token); err != nil {
					return nil, err
				}

				ast = append(ast, token)
			default:
				slog.Warn("special token", "unknown token", reflect.TypeOf(s))
			}
		}
	default:
		slog.Warn("special token", "unknown token", reflect.TypeOf(st))
	}
	slog.Debug("spm tokenizer", "additional tokens", ast)

	return ast, nil
}