tokenizer_spm.go 2.3 KB

package convert

import (
	"cmp"
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"os"
	"slices"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
)
// parseSentencePiece builds a Vocabulary from the SentencePiece model at
// tokenizer.model, then appends any HuggingFace-style added tokens from
// added_tokens.json when that file is present.
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
	ast, err := parseAdditionalSpecialTokens(fsys)
	if err != nil {
		return nil, err
	}

	bts, err := fs.ReadFile(fsys, "tokenizer.model")
	if err != nil {
		return nil, err
	}

	var spm sentencepiece.ModelProto
	if err := proto.Unmarshal(bts, &spm); err != nil {
		return nil, err
	}

	v := Vocabulary{Model: "llama"}
	for _, piece := range spm.GetPieces() {
		v.Tokens = append(v.Tokens, piece.GetPiece())
		v.Scores = append(v.Scores, piece.GetScore())

		switch t := piece.GetType(); t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
			// Special piece types carry through unchanged.
			v.Types = append(v.Types, int32(t))
		default:
			// Normal pieces are promoted to CONTROL if they appear in
			// the additional special tokens list.
			tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
			if slices.Contains(ast, piece.GetPiece()) {
				tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
			}

			v.Types = append(v.Types, tt)
		}
	}

	// added_tokens.json is optional; without it the base vocabulary is
	// returned as-is.
	f, err := fsys.Open("added_tokens.json")
	if errors.Is(err, os.ErrNotExist) {
		return &v, nil
	} else if err != nil {
		return nil, err
	}
	defer f.Close()

	var atm map[string]int
	if err := json.NewDecoder(f).Decode(&atm); err != nil {
		return nil, err
	}

	type t struct {
		id      int
		content string
	}

	var ts []t
	for content, id := range atm {
		ts = append(ts, t{id, content})
	}

	slices.SortFunc(ts, func(i, j t) int {
		return cmp.Compare(i.id, j.id)
	})

	// Added token IDs must be contiguous and start immediately after the
	// base vocabulary.
	n := len(v.Tokens)
	for i, t := range ts {
		if t.id != i+n {
			return nil, fmt.Errorf("invalid token id: %d", t.id)
		}

		v.Tokens = append(v.Tokens, t.content)
		v.Scores = append(v.Scores, -1000.0)
		v.Types = append(v.Types, tokenTypeUserDefined)
	}

	return &v, nil
}
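
// For reference, added_tokens.json (when present) is a flat map from token
// text to integer ID, and the loop above requires those IDs to be contiguous
// and to follow directly after the base vocabulary. A minimal illustrative
// shape (token names and IDs here are made up, not from any real model):
//
//	{
//		"<|user|>": 32000,
//		"<|assistant|>": 32001
//	}
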
// parseAdditionalSpecialTokens reads the additional_special_tokens list from
// special_tokens_map.json. A missing file is not an error; it simply means
// there are no additional special tokens.
func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
	f, err := fsys.Open("special_tokens_map.json")
	if errors.Is(err, os.ErrNotExist) {
		return nil, nil
	} else if err != nil {
		return nil, err
	}
	defer f.Close()

	var m struct {
		AdditionalSpecialTokens []string `json:"additional_special_tokens"`
	}

	if err := json.NewDecoder(f).Decode(&m); err != nil {
		return nil, err
	}

	return m.AdditionalSpecialTokens, nil
}
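
// Usage sketch (not part of this file): a caller hands parseSentencePiece an
// fs.FS rooted at a model directory containing tokenizer.model and, optionally,
// added_tokens.json and special_tokens_map.json. The path below is illustrative.
//
//	vocab, err := parseSentencePiece(os.DirFS("/path/to/model"))
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Printf("%d tokens, %d types, %d scores\n",
//		len(vocab.Tokens), len(vocab.Types), len(vocab.Scores))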