tokenizer_spm.go

package convert

import (
	"cmp"
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"log/slog"
	"os"
	"reflect"
	"slices"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
)
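
// parseSentencePiece builds a Vocabulary from a SentencePiece model. It reads
// the tokenizer.model protobuf, records each piece's text, score, and type,
// promotes additional special tokens to CONTROL, and appends any entries from
// added_tokens.json as user-defined tokens.
//
// A minimal usage sketch (the model directory path is illustrative only):
//
//	vocab, err := parseSentencePiece(os.DirFS("/path/to/model"))
//	if err != nil {
//		return nil, err
//	}
//	slog.Info("loaded vocabulary", "tokens", len(vocab.Tokens))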
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
	slog.Debug("using spm vocabulary")

	ast, err := parseAdditionalSpecialTokens(fsys)
	if err != nil {
		return nil, err
	}

	bts, err := fs.ReadFile(fsys, "tokenizer.model")
	if err != nil {
		return nil, err
	}

	var spm sentencepiece.ModelProto
	if err := proto.Unmarshal(bts, &spm); err != nil {
		return nil, err
	}

	v := Vocabulary{Model: "llama"}
	for _, piece := range spm.GetPieces() {
		v.Tokens = append(v.Tokens, piece.GetPiece())
		v.Scores = append(v.Scores, piece.GetScore())
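
		// Pieces explicitly tagged UNKNOWN, CONTROL, UNUSED, or BYTE keep their
		// protobuf type; everything else defaults to NORMAL and is promoted to
		// CONTROL when it matches a known special token.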
		switch t := piece.GetType(); t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
			v.Types = append(v.Types, int32(t))
		default:
			tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)

			// temporary fix to handle gemma3 broken configs
			if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
				tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
			}

			for _, t := range ast {
				if t.Content == piece.GetPiece() {
					tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
					break
				}
			}

			v.Types = append(v.Types, tt)
		}
	}
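
	// added_tokens.json, when present, maps token text to id for tokens added
	// after the base vocabulary; each id must either duplicate an existing
	// token or extend the vocabulary contiguously.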
	f, err := fsys.Open("added_tokens.json")
	if errors.Is(err, os.ErrNotExist) {
		return &v, nil
	} else if err != nil {
		return nil, err
	}
	defer f.Close()

	var atm map[string]int
	if err := json.NewDecoder(f).Decode(&atm); err != nil {
		return nil, err
	}

	type t struct {
		id      int
		content string
	}

	var ts []t
	for content, id := range atm {
		ts = append(ts, t{id, content})
	}

	slices.SortFunc(ts, func(i, j t) int {
		return cmp.Compare(i.id, j.id)
	})

	for _, t := range ts {
		if t.id < len(v.Tokens) {
			if v.Tokens[t.id] == t.content {
				slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
				continue
			}

			return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
		}

		if t.id != len(v.Tokens) {
			return nil, fmt.Errorf("invalid token id: [%d] at pos [%d]", t.id, len(v.Tokens))
		}

		v.Tokens = append(v.Tokens, t.content)
		v.Scores = append(v.Scores, -1000.0)
		v.Types = append(v.Types, tokenTypeUserDefined)
	}

	return &v, nil
}
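
// specialToken mirrors the object form of entries in the
// additional_special_tokens field of special_tokens_map.json.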
type specialToken struct {
	Content    string `json:"content"`
	Lstrip     bool   `json:"lstrip"`
	Normalized bool   `json:"normalized"`
	Rstrip     bool   `json:"rstrip"`
	SingleWord bool   `json:"single_word"`
}
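
// parseAdditionalSpecialTokens reads special_tokens_map.json, if present, and
// returns its additional_special_tokens entries. The field may be a list of
// plain strings or a list of token objects, e.g. (shape is illustrative only):
//
//	{"additional_special_tokens": [{"content": "<start_of_turn>", "normalized": false}]}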
func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
	f, err := fsys.Open("special_tokens_map.json")
	if errors.Is(err, os.ErrNotExist) {
		return nil, nil
	} else if err != nil {
		return nil, err
	}
	defer f.Close()

	var m struct {
		AdditionalSpecialTokens any `json:"additional_special_tokens"`
	}

	if err := json.NewDecoder(f).Decode(&m); err != nil {
		return nil, err
	}

	var ast []specialToken

	switch st := m.AdditionalSpecialTokens.(type) {
	case []string:
		for _, s := range st {
			ast = append(ast, specialToken{Content: s})
		}
	case []any:
		for _, s := range st {
			// marshal and unmarshal the object to get the special token
			tMap := s.(map[string]any)
			data, err := json.Marshal(tMap)
			if err != nil {
				return nil, err
			}

			var token specialToken
			err = json.Unmarshal(data, &token)
			if err != nil {
				return nil, err
			}

			ast = append(ast, token)
		}
	default:
		slog.Warn("special token", "unknown token", reflect.TypeOf(st))
	}

	slog.Debug("spm tokenizer", "additional tokens", ast)

	return ast, nil
}