|
@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
|
|
|
t.Run("simple", func(t *testing.T) {
|
|
|
t.Parallel()
|
|
|
|
|
|
- ids, err := tokenizer.Encode("hello world")
|
|
|
+ ids, err := tokenizer.Encode("hello world", true)
|
|
|
if err != nil {
|
|
|
t.Error(err)
|
|
|
}
|
|
@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
|
|
|
t.Errorf("got %q, want hello world", s)
|
|
|
}
|
|
|
|
|
|
- ids, err = tokenizer.Encode("hello <|end_of_text|>")
|
|
|
+ ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
|
|
|
if err != nil {
|
|
|
t.Error(err)
|
|
|
}
|
|
@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
for s, want := range cases {
|
|
|
- ids, err := tokenizer.Encode(s)
|
|
|
+ ids, err := tokenizer.Encode(s, true)
|
|
|
if err != nil {
|
|
|
t.Error(err)
|
|
|
}
|
|
@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
for _, want := range cases {
|
|
|
- ids, err := tokenizer.Encode(want)
|
|
|
+ ids, err := tokenizer.Encode(want, true)
|
|
|
if err != nil {
|
|
|
t.Error(err)
|
|
|
}
|
|
@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
for s, want := range cases {
|
|
|
- ids, err := tokenizer.Encode(s)
|
|
|
+ ids, err := tokenizer.Encode(s, true)
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
|
|
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
|
|
b.ResetTimer()
|
|
|
for range b.N {
|
|
|
- _, err := tokenizer.Encode(string(bts))
|
|
|
+ _, err := tokenizer.Encode(string(bts), true)
|
|
|
if err != nil {
|
|
|
b.Fatal(err)
|
|
|
}
|
|
@@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
|
|
})
|
|
|
|
|
|
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
|
|
- ids, err := tokenizer.Encode(string(bts))
|
|
|
+ ids, err := tokenizer.Encode(string(bts), true)
|
|
|
if err != nil {
|
|
|
b.Fatal(err)
|
|
|
}
|