process_text_spm_test.go

package model

import (
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"testing"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
)
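
// loadSentencePieceVocab reads the gemma2 tokenizer.model protobuf from
// testdata and builds a SentencePieceModel from its pieces.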
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
	t.Helper()

	bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
	if err != nil {
		t.Fatal(err)
	}

	var spm sentencepiece.ModelProto
	if err := proto.Unmarshal(bts, &spm); err != nil {
		t.Fatal(err)
	}

	preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`

	var v Vocabulary
	for _, piece := range spm.GetPieces() {
		v.Values = append(v.Values, piece.GetPiece())
		v.Scores = append(v.Scores, piece.GetScore())
		switch t := piece.GetType(); t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
			v.Types = append(v.Types, uint32(t))
		default:
			tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
			// TODO: parse the special tokens file
			//   - this will roundtrip correctly but the <start_of_turn> and
			//     <end_of_turn> tokens aren't processed
			v.Types = append(v.Types, tt)
		}
	}

	return NewSentencePieceModel(preTokenizer, &v)
}
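
// markSpecialTokens re-types the named pieces as CONTROL so the tokenizer can
// treat them as special tokens. This is only a minimal sketch of how the TODO
// in loadSentencePieceVocab might be addressed once the special tokens file is
// parsed; the token names are assumed to come from that file (for example
// markSpecialTokens(&v, "<start_of_turn>", "<end_of_turn>")) and the helper is
// not used by the tests below.
func markSpecialTokens(v *Vocabulary, specials ...string) {
	for i, value := range v.Values {
		if slices.Contains(specials, value) {
			v.Types[i] = uint32(sentencepiece.ModelProto_SentencePiece_CONTROL)
		}
	}
}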

// TestSentencePieceEncode checks that Encode/Decode round-trips a variety of
// inputs and that special tokens map to their expected ids.
func TestSentencePieceEncode(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
	slog.SetDefault(logger)

	tokenizer := loadSentencePieceVocab(t)

	t.Run("basic roundtrip", func(t *testing.T) {
		t.Parallel()

		cases := []string{
			"hello",
			"hello ",
			"hello  ",
			" hello",
			" hello ",
			" hello  ",
			"hello world",
			"请考试我的软件!12345",
			"你好",
			"Hello 你好 world!",
			"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
			"Multilingual: 你好 こんにちは Привет Hola مرحبا",
			"Numbers and symbols: 123456789 +- */",
			"Special tokens: <bos> text <eos>",
			"Code snippets: func main() { fmt.Println(\"Hello World\") }",
			"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
				"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
				"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
		}

		// Every case must survive an Encode then Decode round trip unchanged.
		for _, want := range cases {
			ids, err := tokenizer.Encode(want, true)
			if err != nil {
				t.Fatal(err)
			}

			if got, err := tokenizer.Decode(ids); err != nil {
				t.Fatal(err)
			} else if got != want {
				t.Errorf("got %q, want %q [%#v]", got, want, ids)
			}
		}
	})

	t.Run("special tokens", func(t *testing.T) {
		type candidate struct {
			token string
			ids   []int32
		}

		cases := []candidate{
			{"<bos>", []int32{2}},
			{"<eos>", []int32{1}},
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want.token, true)
			if err != nil {
				t.Fatal(err)
			}

			if !slices.Equal(ids, want.ids) {
				t.Errorf("got %#v, want %#v", ids, want.ids)
			}
		}
	})
}