Browse Source

model: add a test for model forward pass during implementation

Adds a new test file to verify model forward pass behavior through
JSON-specified test cases. The framework loads model files (.gguf) and their
corresponding test specifications to validate expected outputs using greedy
sampling.
Bruce MacDonald 2 months ago
parent
commit
7fa9694359
4 changed files with 158 additions and 0 deletions
  1. 3 0
      .gitignore
  2. 138 0
      model/model_external_test.go
  3. 10 0
      model/testdata/models/README.md
  4. 7 0
      model/testdata/models/qwen2_5.json

+ 3 - 0
.gitignore

@@ -14,3 +14,6 @@ test_data
 __debug_bin*
 llama/build
 llama/vendor
+model/testdata/models/*
+!model/testdata/models/*.md
+!model/testdata/models/*.json

+ 138 - 0
model/model_external_test.go

@@ -0,0 +1,138 @@
+// Package model_test provides external tests for the model package.
+// This test file specifically tests the forward pass functionality on models.
+// It is in a separate package (model_test) to avoid import cycles while still
+// being able to test the public API of the model package.
+package model_test
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/sample"
+
+	_ "github.com/ollama/ollama/model/models"
+)
+
+type modelTest struct {
+	Prompt            string   `json:"prompt"`
+	OutputContainsOne []string `json:"output_contains_one"`
+}
+
+func TestForwardSimple(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping in short mode")
+	}
+
+	// Read all JSON files from testdata/models
+	files, err := os.ReadDir("testdata/models")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for _, file := range files {
+		if !strings.HasSuffix(file.Name(), ".json") {
+			continue
+		}
+
+		jsonPath := filepath.Join("testdata/models", file.Name())
+		ggufPath := filepath.Join("testdata/models", strings.TrimSuffix(file.Name(), ".json")+".gguf")
+
+		// Skip if no corresponding .gguf file exists
+		if _, err := os.Stat(ggufPath); err != nil {
+			t.Logf("skipping %s: no corresponding GGUF file found", file.Name())
+			continue
+		}
+
+		data, err := os.ReadFile(jsonPath)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		var test modelTest
+		if err := json.Unmarshal(data, &test); err != nil {
+			t.Fatal(err)
+		}
+
+		t.Run(strings.TrimSuffix(file.Name(), ".json"), func(t *testing.T) {
+			m, err := model.New(ggufPath)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			m.Config().Cache.Init(m.Backend(), ml.DTypeF32, 2048)
+
+			inputs, err := m.(model.TextProcessor).Encode(test.Prompt)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			var result []string
+			for len(result) < 100 { // Limit to 100 tokens max
+				options := model.Options{
+					Inputs:    inputs,
+					Positions: make([]int32, len(inputs)),
+					Sequences: make([]int, len(inputs)),
+					Outputs:   []int32{int32(len(inputs) - 1)},
+				}
+				for i := range options.Positions {
+					options.Positions[i] = int32(i)
+					options.Sequences[i] = 0
+				}
+
+				ctx := m.Backend().NewContext()
+
+				modelOutput, err := model.Forward(ctx, m, options)
+				if err != nil {
+					ctx.Close()
+					t.Fatal(fmt.Errorf("forward pass failed: %v", err))
+				}
+
+				f32s := modelOutput.Floats()
+				logits := make([]float64, len(f32s))
+				for i, f32 := range f32s {
+					logits[i] = float64(f32)
+				}
+
+				token, err := sample.Sample(logits, sample.Greedy())
+				if err != nil {
+					ctx.Close()
+					t.Fatal(fmt.Errorf("sampling failed: %v", err))
+				}
+
+				ctx.Close()
+
+				// Greedy sampling: take the token with the highest logit
+				nextToken := int32(token[0])
+				if m.(model.TextProcessor).Is(nextToken, model.SpecialEOS) {
+					break
+				}
+
+				piece, err := m.(model.TextProcessor).Decode([]int32{nextToken})
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				result = append(result, piece)
+				output := strings.Join(result, "")
+
+				for _, expectedOutput := range test.OutputContainsOne {
+					if strings.Contains(output, expectedOutput) {
+						t.Logf("Test passed with output: %q (matched expected: %q)", output, expectedOutput)
+						return
+					}
+				}
+
+				// Maintain full context by appending new token
+				inputs = append(inputs, nextToken)
+			}
+
+			t.Fatalf("Expected output containing one of %q but got: %q", test.OutputContainsOne, strings.Join(result, ""))
+		})
+	}
+}

+ 10 - 0
model/testdata/models/README.md

@@ -0,0 +1,10 @@
+# Test Model Directory
+
+This directory is used for storing model files (like `.gguf` files) that are required to run the tests in `model_external_test.go`. 
+
+## Usage
+
+- Place any model files you need for testing in this directory
+- The test file will look for any model files here (e.g., `llama3.gguf`)
+- All non-markdown files in this directory are git-ignored to prevent large model files from being committed to the repository
+- Only `.md` files (like this README) will be tracked in git

+ 7 - 0
model/testdata/models/qwen2_5.json

@@ -0,0 +1,7 @@
+{
+  "prompt": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n",
+  "output_contains_one": [
+    "Hello",
+    "Hi"
+  ]
+}