package model

import (
	"bufio"
	"encoding/json"
	"math"
	"os"
	"path/filepath"
	"slices"
	"strconv"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)
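// llama loads the Llama 3.2 byte pair encoding tokenizer from testdata:
// encoder.json maps token strings to ids, and vocab.bpe lists the BPE merge
// rules in priority order, one pair per line.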
func llama(t testing.TB) BytePairEncoding {
	t.Helper()

	f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	vocab := make(map[string]int32)
	if err := json.NewDecoder(f).Decode(&vocab); err != nil {
		t.Fatal(err)
	}

	types := make([]uint32, len(vocab))
	tokens := make([]string, len(vocab))
	for token, id := range vocab {
		tokens[id] = token
		types[id] = TOKEN_TYPE_NORMAL
	}

	// Append the special tokens if the encoder doesn't already define them.
	for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
		if _, ok := vocab[token]; !ok {
			tokens = append(tokens, token)             //nolint:makezero
			types = append(types, TOKEN_TYPE_CONTROL) //nolint:makezero
			vocab[token] = int32(len(vocab))
		}
	}

	f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	merges := make([]string, 0, 50000)
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		// Skip comment lines such as the "#version" header.
		if !strings.HasPrefix(scanner.Text(), "#") {
			merges = append(merges, scanner.Text())
		}
	}
	if err := scanner.Err(); err != nil {
		t.Fatal(err)
	}
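	// The split pattern below is the Llama 3 pretokenizer regex: contractions,
	// letter runs with an optional leading non-letter byte, digit groups of at
	// most three, punctuation with an optional leading space, newline runs,
	// and residual whitespace.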
	return NewBytePairEncoding(
		`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
		&Vocabulary{
			Values: tokens,
			Types:  types,
			Merges: merges,
		},
	)
}
func TestLlama(t *testing.T) {
	tokenizer := llama(t)

	t.Run("simple", func(t *testing.T) {
		t.Parallel()

		ids, err := tokenizer.Encode("hello world", true)
		if err != nil {
			t.Error(err)
		}

		if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
			t.Errorf("no match (-want +got):\n%s", diff)
		}

		s, err := tokenizer.Decode([]int32{15339, 1917})
		if err != nil {
			t.Fatal(err)
		}

		if s != "hello world" {
			t.Errorf("got %q, want hello world", s)
		}

		ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
		if err != nil {
			t.Error(err)
		}

		if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
			t.Errorf("no match (-want +got):\n%s", diff)
		}
	})
- t.Run("simple repeated", func(t *testing.T) {
- t.Parallel()
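		// The pretokenizer splits digit runs into groups of at most three
		// ("\p{N}{1,3}"), so a run of n zeros encodes as n/3 copies of "000"
		// (931) followed by "00" (410) or "0" (15) for any remainder.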
		cases := map[string][]int32{
			strings.Repeat("0", 1):  {15},
			strings.Repeat("0", 2):  {410},
			strings.Repeat("0", 3):  {931},
			strings.Repeat("0", 4):  {931, 15},
			strings.Repeat("0", 5):  {931, 410},
			strings.Repeat("0", 6):  {931, 931},
			strings.Repeat("0", 7):  {931, 931, 15},
			strings.Repeat("0", 8):  {931, 931, 410},
			strings.Repeat("0", 9):  {931, 931, 931},
			strings.Repeat("0", 10): {931, 931, 931, 15},
			strings.Repeat("0", 11): {931, 931, 931, 410},
			strings.Repeat("0", 12): {931, 931, 931, 931},
			strings.Repeat("0", 13): {931, 931, 931, 931, 15},
			strings.Repeat("0", 14): {931, 931, 931, 931, 410},
			strings.Repeat("0", 15): {931, 931, 931, 931, 931},
			strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
			strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
		}

		for s, want := range cases {
			ids, err := tokenizer.Encode(s, true)
			if err != nil {
				t.Error(err)
			}

			if diff := cmp.Diff(want, ids); diff != "" {
				t.Errorf("%q no match (-want +got):\n%s", s, diff)
			}
		}
	})
- t.Run("basic roundtrip", func(t *testing.T) {
- t.Parallel()
- cases := []string{
- "hello",
- "hello ",
- "hello ",
- " hello",
- " hello ",
- " hello ",
- "hello world",
- "请考试我的软件!12345",
- }
- for _, want := range cases {
- ids, err := tokenizer.Encode(want, true)
- if err != nil {
- t.Error(err)
- }
- if got, err := tokenizer.Decode(ids); err != nil {
- t.Fatal(err)
- } else if got != want {
- t.Errorf("got %q, want %q", got, want)
- }
- }
- })
- t.Run("special", func(t *testing.T) {
- t.Parallel()
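		// In the Llama 3.2 vocabulary, 128000 is <|begin_of_text|> and 128001
		// is <|end_of_text|>; ids 32, 33, 426, and 0 are "A", "B", " B", and "!".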
		cases := map[string][]int32{
			"<|begin_of_text|>A B!":                                               {128000, 32, 426, 0},
			"<|begin_of_text|>A<|end_of_text|>B!":                                 {128000, 32, 128001, 33, 0},
			"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!":                {128000, 32, 128001, 33, 128000, 0},
			"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
		}

		for s, want := range cases {
			ids, err := tokenizer.Encode(s, true)
			if err != nil {
				t.Fatal(err)
			}

			if diff := cmp.Diff(want, ids); diff != "" {
				t.Errorf("no match (-want +got):\n%s", diff)
			}
		}
	})
- t.Run("split", func(t *testing.T) {
- t.Parallel()
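		// split exposes the pretokenizer directly: these cases pin down how
		// text is carved into chunks before any BPE merges are applied.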
		cases := map[string][]string{
			"Hello World!":                   {"Hello", " World", "!"},
			"I'm don't won't":                {"I", "'m", " don", "'t", " won", "'t"},
			"In 2024 there are 366 days":     {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
			"Hello!! ...world":               {"Hello", "!!", " ...", "world"},
			"Hello  World":                   {"Hello", " ", " World"},
			"Hello\nWorld":                   {"Hello", "\n", "World"},
			"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
		}

		for s, want := range cases {
			got := slices.Collect(tokenizer.split(s))
			if diff := cmp.Diff(want, got); diff != "" {
				t.Errorf("no match (-want +got):\n%s", diff)
			}
		}
	})
}
// tekken loads the Tekken tokenizer for testing
func tekken(t testing.TB) TextProcessor {
	t.Helper()

	// Load the tokenizer config from mistral-small
	tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
	configFile, err := os.Open(tokenizerConfigPath)
	if err != nil {
		t.Fatal(err)
	}
	defer configFile.Close()

	var config struct {
		AddBosToken bool   `json:"add_bos_token"`
		AddEosToken bool   `json:"add_eos_token"`
		BosToken    string `json:"bos_token"`
		EosToken    string `json:"eos_token"`
	}
	if err := json.NewDecoder(configFile).Decode(&config); err != nil {
		t.Fatal(err)
	}
	// Load tokenizer.json, which contains the vocabulary and other settings
	tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
	tokenizerFile, err := os.Open(tokenizerJsonPath)
	if err != nil {
		t.Fatal(err)
	}
	defer tokenizerFile.Close()
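	// Decode just the parts of the Hugging Face tokenizer.json schema this
	// test needs: the model's vocab and merges, the added special tokens,
	// and the pre-tokenizer split pattern.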
	var tokenizerData struct {
		Model struct {
			Type   string           `json:"type"`
			Vocab  map[string]int32 `json:"vocab"`
			Merges []string         `json:"merges"`
		} `json:"model"`
		AddedTokens []struct {
			ID      int32  `json:"id"`
			Content string `json:"content"`
			Special bool   `json:"special"`
		} `json:"added_tokens"`
		PreTokenizer struct {
			Type          string `json:"type"`
			Pretokenizers []struct {
				Type    string `json:"type"`
				Pattern struct {
					String string `json:"String"`
				} `json:"pattern"`
				Behavior string `json:"behavior"`
			} `json:"pretokenizers"`
		} `json:"pre_tokenizer"`
	}
	if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
		t.Fatal(err)
	}
	// Extract the split pattern from the pre_tokenizer if available
	var pattern string
	if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
		pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
	}

	// Combine the regular vocab with the special tokens from added_tokens
	vocab := tokenizerData.Model.Vocab
	for _, token := range tokenizerData.AddedTokens {
		vocab[token.Content] = token.ID
	}

	// Token ids may be sparse, so size the vocabulary arrays by the largest id
	maxID := int32(-1)
	for _, id := range vocab {
		if id > maxID {
			maxID = id
		}
	}
	vocabSize := int(maxID + 1)
	types := make([]uint32, vocabSize)
	tokens := make([]string, vocabSize)
	scores := make([]float32, vocabSize)

	for token, id := range vocab {
		tokens[id] = token
		// Mark the known special tokens as control tokens; everything else is normal.
		switch token {
		case "<s>", "</s>", "[INST]", "[/INST]":
			types[id] = TOKEN_TYPE_CONTROL
		default:
			types[id] = TOKEN_TYPE_NORMAL
		}
	}
	// In Tekken, the merges don't need to be loaded separately: they're part
	// of the model section in tokenizer.json and were decoded above.
	merges := tokenizerData.Model.Merges

	// Create the vocabulary object
	vocabObj := &Vocabulary{
		Values: tokens,
		Types:  types,
		Scores: scores,
		Merges: merges,
		BOS:    vocab[config.BosToken],
		EOS:    vocab[config.EosToken],
		AddBOS: config.AddBosToken,
		AddEOS: config.AddEosToken,
	}
	// Use the pattern from tokenizer.json if available
	if pattern != "" {
		// Restore the backslash before Unicode property classes if it was
		// lost in decoding; skip the rewrite when the pattern is already
		// escaped so it isn't double-escaped.
		if !strings.Contains(pattern, `\p{`) {
			pattern = strings.ReplaceAll(pattern, "p{", `\p{`)
		}
		return NewBytePairEncoding(pattern, vocabObj)
	}

	// Fall back to a simple pattern if none was found
	return NewBytePairEncoding(
		`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
		vocabObj,
	)
}
func TestTekken(t *testing.T) {
	// Skip if the test data isn't available
	if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
		t.Skip("Mistral-small test data not available")
	}

	tokenizer := tekken(t)

	t.Run("whitespace_handling", func(t *testing.T) {
		t.Parallel()

		// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
		cases := []struct {
			input    string
			expected string
		}{
			{" hello", " hello"},
			{"hello ", "hello "},
			{"hello world", "hello world"},
			{" hello world ", " hello world "},
		}

		for _, tc := range cases {
			ids, err := tokenizer.Encode(tc.input, false)
			if err != nil {
				t.Errorf("Failed to encode %q: %v", tc.input, err)
				continue
			}

			decoded, err := tokenizer.Decode(ids)
			if err != nil {
				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
				continue
			}

			if decoded != tc.expected {
				t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
			}
		}
	})
- t.Run("chat_templates", func(t *testing.T) {
- t.Parallel()
- // Test the Tekken chat template format which doesn't have spaces after special tokens
- templates := []struct {
- input string
- expectSpace bool // whether we expect a space after special tokens
- }{
- {"<s>[INST]user message[/INST]", false},
- {"<s>[INST] user message[/INST]", true},
- {"<s>[INST]user message [/INST]", true},
- }
- for _, tc := range templates {
- ids, err := tokenizer.Encode(tc.input, false)
- if err != nil {
- t.Errorf("Failed to encode %q: %v", tc.input, err)
- continue
- }
- decoded, err := tokenizer.Decode(ids)
- if err != nil {
- t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
- continue
- }
- // Check if there's a space after special tokens
- hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")
- if hasSpaceAfterINST != tc.expectSpace {
- t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
- hasSpaceAfterINST, tc.expectSpace, tc.input)
- }
- }
- })
- t.Run("special_tokens", func(t *testing.T) {
- t.Parallel()
- // Test how Tekken handles special tokens
- cases := []struct {
- input string
- expected []string // We'll check if these tokens are in the decoded output
- }{
- {"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
- {"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
- {"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
- }
- for _, tc := range cases {
- ids, err := tokenizer.Encode(tc.input, false)
- if err != nil {
- t.Errorf("Failed to encode %q: %v", tc.input, err)
- continue
- }
- decoded, err := tokenizer.Decode(ids)
- if err != nil {
- t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
- continue
- }
- for _, expected := range tc.expected {
- if !strings.Contains(decoded, expected) {
- t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
- }
- }
- }
- })
- t.Run("vocabulary_coverage", func(t *testing.T) {
- t.Parallel()
- // Tekken has a larger vocabulary, so test coverage of various token types
- samples := []string{
- "Hello world!",
- "This is a test of the Tekken tokenizer.",
- "It has a considerably larger vocabulary size.",
- "Special characters: !@#$%^&*()",
- "Numbers: 1234567890",
- "Multiple languages: こんにちは 你好 안녕하세요",
- "Code snippets: def function(): return True",
- }
- for _, sample := range samples {
- ids, err := tokenizer.Encode(sample, false)
- if err != nil {
- t.Errorf("Failed to encode %q: %v", sample, err)
- continue
- }
- decoded, err := tokenizer.Decode(ids)
- if err != nil {
- t.Errorf("Failed to decode tokens for %q: %v", sample, err)
- continue
- }
- if decoded != sample {
- t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
- }
- }
- })
- t.Run("splitting_behavior", func(t *testing.T) {
- t.Parallel()
- // Test the splitting behavior which might differ from SentencePiece
- cases := map[string][]string{
- "Hello World!": {"Hello", " World", "!"},
- "user message": {"user", " message"},
- "[INST]hello": {"[INST]", "hello"},
- "hello[/INST]": {"hello", "[/INST]"},
- }
- for s, want := range cases {
- got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
- }
- }
- })
- t.Run("full_chat_sequence", func(t *testing.T) {
- t.Parallel()
- // Test a complete chat sequence with Tekken's format
- chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"
- ids, err := tokenizer.Encode(chatSequence, false)
- if err != nil {
- t.Fatalf("Failed to encode chat sequence: %v", err)
- }
- decoded, err := tokenizer.Decode(ids)
- if err != nil {
- t.Fatalf("Failed to decode chat sequence tokens: %v", err)
- }
- // In Tekken, the whitespace shouldn't be added after special tokens
- if strings.Contains(decoded, "[INST] ") {
- t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
- }
- if strings.Contains(decoded, "[/INST] ") {
- t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
- }
- })
- }
func BenchmarkBytePairEncoding(b *testing.B) {
	tokenizer := llama(b)
	bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
	if err != nil {
		b.Fatal(err)
	}
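	// Benchmark prefixes of the corpus at sizes 10^0 through 10^7 bytes,
	// capped at the length of the file.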
	for i := range 8 {
		n := min(int(math.Pow10(i)), len(bts))
		bts := bts[:n]
		b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
			b.ResetTimer()
			for range b.N {
				_, err := tokenizer.Encode(string(bts), true)
				if err != nil {
					b.Fatal(err)
				}
			}
		})

		b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
			ids, err := tokenizer.Encode(string(bts), true)
			if err != nil {
				b.Fatal(err)
			}
			b.ResetTimer()

			for range b.N {
				_, err := tokenizer.Decode(ids)
				if err != nil {
					b.Fatal(err)
				}
			}
		})

		b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
			b.ResetTimer()
			for range b.N {
				slices.Collect(tokenizer.split(string(bts)))
			}
		})
	}
}