model_external_test.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. // Package model_test provides external tests for the model package.
  2. // This test file specifically tests the forward pass functionality on models.
  3. // It is in a separate package (model_test) to avoid import cycles while still
  4. // being able to test the public API of the model package.
  5. package model_test
  6. import (
  7. "encoding/json"
  8. "fmt"
  9. "os"
  10. "path/filepath"
  11. "strings"
  12. "testing"
  13. "github.com/ollama/ollama/ml"
  14. "github.com/ollama/ollama/model"
  15. "github.com/ollama/ollama/sample"
  16. _ "github.com/ollama/ollama/model/models"
  17. )
  18. type modelTest struct {
  19. Prompt string `json:"prompt"`
  20. OutputContainsOne []string `json:"output_contains_one"`
  21. }
  22. func TestForwardSimple(t *testing.T) {
  23. if testing.Short() {
  24. t.Skip("skipping in short mode")
  25. }
  26. // Read all JSON files from testdata/models
  27. files, err := os.ReadDir("testdata/models")
  28. if err != nil {
  29. t.Fatal(err)
  30. }
  31. for _, file := range files {
  32. if !strings.HasSuffix(file.Name(), ".json") {
  33. continue
  34. }
  35. jsonPath := filepath.Join("testdata/models", file.Name())
  36. ggufPath := filepath.Join("testdata/models", strings.TrimSuffix(file.Name(), ".json")+".gguf")
  37. // Skip if no corresponding .gguf file exists
  38. if _, err := os.Stat(ggufPath); err != nil {
  39. t.Logf("skipping %s: no corresponding GGUF file found", file.Name())
  40. continue
  41. }
  42. data, err := os.ReadFile(jsonPath)
  43. if err != nil {
  44. t.Fatal(err)
  45. }
  46. var test modelTest
  47. if err := json.Unmarshal(data, &test); err != nil {
  48. t.Fatal(err)
  49. }
  50. t.Run(strings.TrimSuffix(file.Name(), ".json"), func(t *testing.T) {
  51. m, err := model.New(ggufPath)
  52. if err != nil {
  53. t.Fatal(err)
  54. }
  55. m.Config().Cache.Init(m.Backend(), ml.DTypeF32, 2048)
  56. inputs, err := m.(model.TextProcessor).Encode(test.Prompt)
  57. if err != nil {
  58. t.Fatal(err)
  59. }
  60. var result []string
  61. for len(result) < 100 { // Limit to 100 tokens max
  62. options := model.Options{
  63. Inputs: inputs,
  64. Positions: make([]int32, len(inputs)),
  65. Sequences: make([]int, len(inputs)),
  66. Outputs: []int32{int32(len(inputs) - 1)},
  67. }
  68. for i := range options.Positions {
  69. options.Positions[i] = int32(i)
  70. options.Sequences[i] = 0
  71. }
  72. ctx := m.Backend().NewContext()
  73. modelOutput, err := model.Forward(ctx, m, options)
  74. if err != nil {
  75. ctx.Close()
  76. t.Fatal(fmt.Errorf("forward pass failed: %v", err))
  77. }
  78. f32s := modelOutput.Floats()
  79. logits := make([]float64, len(f32s))
  80. for i, f32 := range f32s {
  81. logits[i] = float64(f32)
  82. }
  83. token, err := sample.Sample(logits, sample.Greedy())
  84. if err != nil {
  85. ctx.Close()
  86. t.Fatal(fmt.Errorf("sampling failed: %v", err))
  87. }
  88. ctx.Close()
  89. // Greedy sampling: take the token with the highest logit
  90. nextToken := int32(token[0])
  91. if m.(model.TextProcessor).Is(nextToken, model.SpecialEOS) {
  92. break
  93. }
  94. piece, err := m.(model.TextProcessor).Decode([]int32{nextToken})
  95. if err != nil {
  96. t.Fatal(err)
  97. }
  98. result = append(result, piece)
  99. output := strings.Join(result, "")
  100. for _, expectedOutput := range test.OutputContainsOne {
  101. if strings.Contains(output, expectedOutput) {
  102. t.Logf("Test passed with output: %q (matched expected: %q)", output, expectedOutput)
  103. return
  104. }
  105. }
  106. // Maintain full context by appending new token
  107. inputs = append(inputs, nextToken)
  108. }
  109. t.Fatalf("Expected output containing one of %q but got: %q", test.OutputContainsOne, strings.Join(result, ""))
  110. })
  111. }
  112. }