process_text_spm_test.go

package model

import (
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"testing"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
)

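// loadSentencePieceVocab reads the Gemma 2 tokenizer.model protobuf from
// testdata and builds a Vocabulary for a SentencePieceModel.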
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
	t.Helper()

	bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
	if err != nil {
		t.Fatal(err)
	}

	var spm sentencepiece.ModelProto
	if err := proto.Unmarshal(bts, &spm); err != nil {
		t.Fatal(err)
	}

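	// Pre-tokenizer split pattern: English contractions, letter runs,
	// 1-3 digit groups, punctuation runs, and whitespace. Note the (?!\S)
	// lookahead, which Go's standard regexp package does not support.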
	preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`

	var v Vocabulary
	for _, piece := range spm.GetPieces() {
		v.Values = append(v.Values, piece.GetPiece())
		v.Scores = append(v.Scores, piece.GetScore())
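		// Keep the real type for special pieces; everything else is NORMAL.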
		switch t := piece.GetType(); t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
			v.Types = append(v.Types, uint32(t))
		default:
			tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
			// TODO: parse the special tokens file
			//   - this will roundtrip correctly but the <start_of_turn> and
			//     <end_of_turn> tokens aren't processed
			v.Types = append(v.Types, tt)
		}
	}

	return NewSentencePieceModel(preTokenizer, &v)
}
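
// TestSentencePieceEncode verifies that Encode/Decode roundtrips plain text
// and that special tokens map to their expected ids.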
func TestSentencePieceEncode(t *testing.T) {
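	// Emit tokenizer debug logs to stdout while the test runs.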
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
	slog.SetDefault(logger)

	tokenizer := loadSentencePieceVocab(t)

	t.Run("basic roundtrip", func(t *testing.T) {
		t.Parallel()

		cases := []string{
			"hello",
			"hello ",
			"hello  ",
			" hello",
			" hello ",
			" hello  ",
			"hello world",
			"请考试我的软件!12345",
			"你好",
			"Hello 你好 world!",
		}
		for _, want := range cases {
			ids, err := tokenizer.Encode(want)
			if err != nil {
				t.Fatal(err)
			}

			if got, err := tokenizer.Decode(ids); err != nil {
				t.Fatal(err)
			} else if got != want {
				t.Errorf("got %q, want %q [%#v]", got, want, ids)
			}
		}
	})
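
	// Expected ids come from the Gemma 2 tokenizer.model in testdata:
	// <eos> is id 1 and <bos> is id 2.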
	t.Run("special tokens", func(t *testing.T) {
		type candidate struct {
			token string
			ids   []int32
		}

		cases := []candidate{
			{"<bos>", []int32{2}},
			{"<eos>", []int32{1}},
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want.token)
			if err != nil {
				t.Fatal(err)
			}

			if !slices.Equal(ids, want.ids) {
				t.Errorf("got %#v, want %#v", ids, want.ids)
			}
		}
	})
}