Jelajahi Sumber

model: support for mistral-small in the ollama runner

Mistral is a popular research lab making open source models. This updates
the forward pass of llama architecture models to support both llama models
and mistral models by accounting for additional metadata present in mistral
models, and finding the correct dimensions for the output projection.
Bruce MacDonald 1 bulan lalu
induk
melakukan
191b1b1eb3
2 mengubah file dengan 340 tambahan dan 5 penghapusan
  1. 20 5
      model/models/llama/model.go
  2. 320 0
      model/process_text_test.go

+ 20 - 5
model/models/llama/model.go

@@ -13,9 +13,9 @@ import (
 )
 
 type Options struct {
-	hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	hiddenSize, numHeads, numKVHeads, headDim int
+	eps, ropeBase, ropeScale                  float32
+	ropeDim                                   uint32
 }
 
 type Model struct {
@@ -37,6 +37,8 @@ func New(c ml.Config) (model.Model, error) {
 
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
+			// TODO: need to set this in the conversion for mistral:
+			// tokenizer.ggml.pretokenizer = [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
@@ -53,6 +55,7 @@ func New(c ml.Config) (model.Model, error) {
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			headDim:    int(c.Uint("attention.key_length")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:   c.Float("rope.freq_base"),
 			ropeScale:  c.Float("rope.freq_scale", 1),
@@ -75,24 +78,36 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	headDim := opts.hiddenSize / opts.numHeads
 	ropeType := uint32(0)
+	// Get head dimension - use explicit value if available, otherwise calculate
+	headDim := opts.headDim
+	if headDim == 0 {
+		headDim = opts.hiddenSize / opts.numHeads
+	}
 
+	// Query projection and reshape
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
+	// Key projection and reshape
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
+	// Value projection and reshape
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
+	// Attention computation
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
 	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
-	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
 
+	// Reshape attention output for final projection
+	outputDim := headDim * opts.numHeads
+	kqv = kqv.Reshape(ctx, outputDim, batchSize)
+
+	// Apply output projection
 	return sa.Output.Forward(ctx, kqv)
 }
 

+ 320 - 0
model/process_text_test.go

@@ -209,6 +209,326 @@ func TestLlama(t *testing.T) {
 	})
 }
 
+// tekken loads the Tekken tokenizer for testing
+func tekken(t testing.TB) TextProcessor {
+	t.Helper()
+
+	// Load tokenizer config from mistral-small
+	tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
+	configFile, err := os.Open(tokenizerConfigPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer configFile.Close()
+
+	var config struct {
+		AddBosToken bool `json:"add_bos_token"`
+		AddEosToken bool `json:"add_eos_token"`
+		BosToken    struct {
+			Content string `json:"content"`
+		} `json:"bos_token"`
+		EosToken struct {
+			Content string `json:"content"`
+		} `json:"eos_token"`
+	}
+	if err := json.NewDecoder(configFile).Decode(&config); err != nil {
+		t.Fatal(err)
+	}
+
+	// Load tokenizer.json which contains the vocabulary and other settings
+	tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
+	tokenizerFile, err := os.Open(tokenizerJsonPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer tokenizerFile.Close()
+
+	var tokenizerData struct {
+		Model struct {
+			Type   string           `json:"type"`
+			Vocab  map[string]int32 `json:"vocab"`
+			Merges []string         `json:"merges"`
+		} `json:"model"`
+		AddedTokens []struct {
+			Id      int32  `json:"id"`
+			Content string `json:"content"`
+			Special bool   `json:"special"`
+		} `json:"added_tokens"`
+		PreTokenizer struct {
+			Type          string `json:"type"`
+			Pretokenizers []struct {
+				Type    string `json:"type"`
+				Pattern struct {
+					String string `json:"String"`
+				} `json:"pattern"`
+				Behavior string `json:"behavior"`
+			} `json:"pretokenizers"`
+		} `json:"pre_tokenizer"`
+	}
+	if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
+		t.Fatal(err)
+	}
+
+	// Extract the pattern from pre_tokenizer if available
+	var pattern string
+	if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
+		pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
+	}
+
+	// Combine regular vocab and added tokens
+	vocab := tokenizerData.Model.Vocab
+
+	// Add special tokens from added_tokens
+	for _, token := range tokenizerData.AddedTokens {
+		vocab[token.Content] = token.Id
+	}
+
+	// Create vocabulary arrays
+	maxId := int32(-1)
+	for _, id := range vocab {
+		if id > maxId {
+			maxId = id
+		}
+	}
+
+	vocabSize := int(maxId + 1)
+	types := make([]uint32, vocabSize)
+	tokens := make([]string, vocabSize)
+	scores := make([]float32, vocabSize)
+
+	for token, id := range vocab {
+		tokens[id] = token
+		types[id] = TOKEN_TYPE_NORMAL
+
+		// Assign appropriate token types for special tokens
+		if token == "<s>" {
+			types[id] = TOKEN_TYPE_CONTROL
+		} else if token == "</s>" {
+			types[id] = TOKEN_TYPE_CONTROL
+		} else if token == "[INST]" || token == "[/INST]" {
+			types[id] = TOKEN_TYPE_CONTROL
+		}
+	}
+
+	// In Tekken, we don't need to load merges separately as they're part of the model
+	var merges []string
+
+	// Create vocabulary object
+	vocabObj := &Vocabulary{
+		Values: tokens,
+		Types:  types,
+		Scores: scores,
+		Merges: merges,
+		BOS:    vocab[config.BosToken.Content],
+		EOS:    vocab[config.EosToken.Content],
+		AddBOS: config.AddBosToken,
+		AddEOS: config.AddEosToken,
+	}
+
+	// Use pattern from tokenizer.json if available
+	if pattern != "" {
+		// Ensure pattern has proper escaping for Go regexp
+		pattern = strings.ReplaceAll(pattern, "p{", "\\p{")
+		return NewBytePairEncoding(pattern, vocabObj)
+	}
+
+	// Fallback pattern if not found
+	return NewBytePairEncoding(
+		`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
+		vocabObj,
+	)
+}
+
+func TestTekken(t *testing.T) {
+	// Skip if the test data isn't available
+	if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
+		t.Skip("Mistral-small test data not available")
+	}
+
+	tokenizer := tekken(t)
+
+	t.Run("whitespace_handling", func(t *testing.T) {
+		t.Parallel()
+
+		// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
+		cases := []struct {
+			input    string
+			expected string
+		}{
+			{" hello", " hello"},
+			{"hello ", "hello "},
+			{"hello world", "hello world"},
+			{" hello world ", " hello world "},
+		}
+
+		for _, tc := range cases {
+			ids, err := tokenizer.Encode(tc.input, false)
+			if err != nil {
+				t.Errorf("Failed to encode %q: %v", tc.input, err)
+				continue
+			}
+
+			decoded, err := tokenizer.Decode(ids)
+			if err != nil {
+				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
+				continue
+			}
+
+			if decoded != tc.expected {
+				t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
+			}
+		}
+	})
+
+	t.Run("chat_templates", func(t *testing.T) {
+		t.Parallel()
+
+		// Test the Tekken chat template format which doesn't have spaces after special tokens
+		templates := []struct {
+			input       string
+			expectSpace bool // whether we expect a space after special tokens
+		}{
+			{"<s>[INST]user message[/INST]", false},
+			{"<s>[INST] user message[/INST]", true},
+			{"<s>[INST]user message [/INST]", true},
+		}
+
+		for _, tc := range templates {
+			ids, err := tokenizer.Encode(tc.input, false)
+			if err != nil {
+				t.Errorf("Failed to encode %q: %v", tc.input, err)
+				continue
+			}
+
+			decoded, err := tokenizer.Decode(ids)
+			if err != nil {
+				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
+				continue
+			}
+
+			// Check if there's a space after special tokens
+			hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")
+
+			if hasSpaceAfterINST != tc.expectSpace {
+				t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
+					hasSpaceAfterINST, tc.expectSpace, tc.input)
+			}
+		}
+	})
+
+	t.Run("special_tokens", func(t *testing.T) {
+		t.Parallel()
+
+		// Test how Tekken handles special tokens
+		cases := []struct {
+			input    string
+			expected []string // We'll check if these tokens are in the decoded output
+		}{
+			{"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
+			{"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
+			{"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
+		}
+
+		for _, tc := range cases {
+			ids, err := tokenizer.Encode(tc.input, false)
+			if err != nil {
+				t.Errorf("Failed to encode %q: %v", tc.input, err)
+				continue
+			}
+
+			decoded, err := tokenizer.Decode(ids)
+			if err != nil {
+				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
+				continue
+			}
+
+			for _, expected := range tc.expected {
+				if !strings.Contains(decoded, expected) {
+					t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
+				}
+			}
+		}
+	})
+
+	t.Run("vocabulary_coverage", func(t *testing.T) {
+		t.Parallel()
+
+		// Tekken has a larger vocabulary, so test coverage of various token types
+		samples := []string{
+			"Hello world!",
+			"This is a test of the Tekken tokenizer.",
+			"It has a considerably larger vocabulary size.",
+			"Special characters: !@#$%^&*()",
+			"Numbers: 1234567890",
+			"Multiple languages: こんにちは 你好 안녕하세요",
+			"Code snippets: def function(): return True",
+		}
+
+		for _, sample := range samples {
+			ids, err := tokenizer.Encode(sample, false)
+			if err != nil {
+				t.Errorf("Failed to encode %q: %v", sample, err)
+				continue
+			}
+
+			decoded, err := tokenizer.Decode(ids)
+			if err != nil {
+				t.Errorf("Failed to decode tokens for %q: %v", sample, err)
+				continue
+			}
+
+			if decoded != sample {
+				t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
+			}
+		}
+	})
+
+	t.Run("splitting_behavior", func(t *testing.T) {
+		t.Parallel()
+
+		// Test the splitting behavior which might differ from SentencePiece
+		cases := map[string][]string{
+			"Hello World!": {"Hello", " World", "!"},
+			"user message": {"user", " message"},
+			"[INST]hello":  {"[INST]", "hello"},
+			"hello[/INST]": {"hello", "[/INST]"},
+		}
+
+		for s, want := range cases {
+			got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
+			if diff := cmp.Diff(want, got); diff != "" {
+				t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
+			}
+		}
+	})
+
+	t.Run("full_chat_sequence", func(t *testing.T) {
+		t.Parallel()
+
+		// Test a complete chat sequence with Tekken's format
+		chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"
+
+		ids, err := tokenizer.Encode(chatSequence, false)
+		if err != nil {
+			t.Fatalf("Failed to encode chat sequence: %v", err)
+		}
+
+		decoded, err := tokenizer.Decode(ids)
+		if err != nil {
+			t.Fatalf("Failed to decode chat sequence tokens: %v", err)
+		}
+
+		// In Tekken, the whitespace shouldn't be added after special tokens
+		if strings.Contains(decoded, "[INST] ") {
+			t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
+		}
+
+		if strings.Contains(decoded, "[/INST] ") {
+			t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
+		}
+	})
+}
+
 func BenchmarkBytePairEncoding(b *testing.B) {
 	tokenizer := llama(b)
 	bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))