123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- // Package model_test provides external tests for the model package.
- // This test file specifically tests the forward pass functionality on models.
- // It is in a separate package (model_test) to avoid import cycles while still
- // being able to test the public API of the model package.
- package model_test
- import (
- "encoding/json"
- "fmt"
- "os"
- "path/filepath"
- "strings"
- "testing"
- "github.com/ollama/ollama/ml"
- "github.com/ollama/ollama/model"
- "github.com/ollama/ollama/sample"
- _ "github.com/ollama/ollama/model/models"
- )
// modelTest is the schema of a forward-pass fixture: one JSON file in
// testdata/models that sits next to a GGUF model file of the same base name.
type modelTest struct {
	// Prompt is the text that gets encoded and used as the initial model input.
	Prompt string `json:"prompt"`

	// OutputContainsOne lists acceptable substrings. The test succeeds as soon
	// as the generated output contains any one of them.
	OutputContainsOne []string `json:"output_contains_one"`
}
- func TestForwardSimple(t *testing.T) {
- if testing.Short() {
- t.Skip("skipping in short mode")
- }
- // Read all JSON files from testdata/models
- files, err := os.ReadDir("testdata/models")
- if err != nil {
- t.Fatal(err)
- }
- for _, file := range files {
- if !strings.HasSuffix(file.Name(), ".json") {
- continue
- }
- jsonPath := filepath.Join("testdata/models", file.Name())
- ggufPath := filepath.Join("testdata/models", strings.TrimSuffix(file.Name(), ".json")+".gguf")
- // Skip if no corresponding .gguf file exists
- if _, err := os.Stat(ggufPath); err != nil {
- t.Logf("skipping %s: no corresponding GGUF file found", file.Name())
- continue
- }
- data, err := os.ReadFile(jsonPath)
- if err != nil {
- t.Fatal(err)
- }
- var test modelTest
- if err := json.Unmarshal(data, &test); err != nil {
- t.Fatal(err)
- }
- t.Run(strings.TrimSuffix(file.Name(), ".json"), func(t *testing.T) {
- m, err := model.New(ggufPath)
- if err != nil {
- t.Fatal(err)
- }
- m.Config().Cache.Init(m.Backend(), ml.DTypeF32, 2048)
- inputs, err := m.(model.TextProcessor).Encode(test.Prompt)
- if err != nil {
- t.Fatal(err)
- }
- var result []string
- for len(result) < 100 { // Limit to 100 tokens max
- options := model.Options{
- Inputs: inputs,
- Positions: make([]int32, len(inputs)),
- Sequences: make([]int, len(inputs)),
- Outputs: []int32{int32(len(inputs) - 1)},
- }
- for i := range options.Positions {
- options.Positions[i] = int32(i)
- options.Sequences[i] = 0
- }
- ctx := m.Backend().NewContext()
- modelOutput, err := model.Forward(ctx, m, options)
- if err != nil {
- ctx.Close()
- t.Fatal(fmt.Errorf("forward pass failed: %v", err))
- }
- f32s := modelOutput.Floats()
- logits := make([]float64, len(f32s))
- for i, f32 := range f32s {
- logits[i] = float64(f32)
- }
- token, err := sample.Sample(logits, sample.Greedy())
- if err != nil {
- ctx.Close()
- t.Fatal(fmt.Errorf("sampling failed: %v", err))
- }
- ctx.Close()
- // Greedy sampling: take the token with the highest logit
- nextToken := int32(token[0])
- if m.(model.TextProcessor).Is(nextToken, model.SpecialEOS) {
- break
- }
- piece, err := m.(model.TextProcessor).Decode([]int32{nextToken})
- if err != nil {
- t.Fatal(err)
- }
- result = append(result, piece)
- output := strings.Join(result, "")
- for _, expectedOutput := range test.OutputContainsOne {
- if strings.Contains(output, expectedOutput) {
- t.Logf("Test passed with output: %q (matched expected: %q)", output, expectedOutput)
- return
- }
- }
- // Maintain full context by appending new token
- inputs = append(inputs, nextToken)
- }
- t.Fatalf("Expected output containing one of %q but got: %q", test.OutputContainsOne, strings.Join(result, ""))
- })
- }
- }
|