瀏覽代碼

model: Don't unconditionally add special tokens

We sometimes tokenize partial strings. For example, with
multimodal inputs, we split the input string around the images
and then tokenize each piece. In these cases, we should only add
the special tokens on the first piece.
Jesse Gross 1 月之前
父節點
當前提交
b70fc4d51e
共有 4 個文件被更改,包括 12 次插入12 次删除
  1. 1 1
      llm/server.go
  2. 3 3
      model/process_text.go
  3. 7 7
      model/process_text_test.go
  4. 1 1
      runner/ollamarunner/runner.go

+ 1 - 1
llm/server.go

@@ -973,7 +973,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
 		return s.llamaModel.Tokenize(content, false, true)
 	}
 	if s.textProcessor != nil {
-		tokens, err := s.textProcessor.Encode(content)
+		tokens, err := s.textProcessor.Encode(content, false)
 		if err != nil {
 			return nil, err
 		}

+ 3 - 3
model/process_text.go

@@ -19,7 +19,7 @@ const (
 )
 
 type TextProcessor interface {
-	Encode(string) ([]int32, error)
+	Encode(s string, addSpecial bool) ([]int32, error)
 	Decode([]int32) (string, error)
 	Is(int32, Special) bool
 }
@@ -144,7 +144,7 @@ type merge struct {
 	runes []rune
 }
 
-func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
+func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
 	fragments := []fragment{{value: s}}
 	for _, special := range bpe.vocab.SpecialVocabulary() {
 		// TODO: process special tokens concurrently
@@ -282,7 +282,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
 		}
 	}
 
-	if len(ids) > 0 {
+	if addSpecial && len(ids) > 0 {
 		if bpe.vocab.AddBOS {
 			if ids[0] == bpe.vocab.BOS {
 				slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)

+ 7 - 7
model/process_text_test.go

@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
 	t.Run("simple", func(t *testing.T) {
 		t.Parallel()
 
-		ids, err := tokenizer.Encode("hello world")
+		ids, err := tokenizer.Encode("hello world", true)
 		if err != nil {
 			t.Error(err)
 		}
@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
 			t.Errorf("got %q, want hello world", s)
 		}
 
-		ids, err = tokenizer.Encode("hello <|end_of_text|>")
+		ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
 		if err != nil {
 			t.Error(err)
 		}
@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
 		}
 
 		for s, want := range cases {
-			ids, err := tokenizer.Encode(s)
+			ids, err := tokenizer.Encode(s, true)
 			if err != nil {
 				t.Error(err)
 			}
@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
 		}
 
 		for _, want := range cases {
-			ids, err := tokenizer.Encode(want)
+			ids, err := tokenizer.Encode(want, true)
 			if err != nil {
 				t.Error(err)
 			}
@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
 		}
 
 		for s, want := range cases {
-			ids, err := tokenizer.Encode(s)
+			ids, err := tokenizer.Encode(s, true)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
 		b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
 			b.ResetTimer()
 			for range b.N {
-				_, err := tokenizer.Encode(string(bts))
+				_, err := tokenizer.Encode(string(bts), true)
 				if err != nil {
 					b.Fatal(err)
 				}
@@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
 		})
 
 		b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
-			ids, err := tokenizer.Encode(string(bts))
+			ids, err := tokenizer.Encode(string(bts), true)
 			if err != nil {
 				b.Fatal(err)
 			}

+ 1 - 1
runner/ollamarunner/runner.go

@@ -161,7 +161,7 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 
 	for i, part := range parts {
 		// text - tokenize
-		tokens, err := s.model.(model.TextProcessor).Encode(part)
+		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
 		if err != nil {
 			return nil, err
 		}