model_test.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package modeltest
  2. import (
  3. "encoding/json"
  4. "os"
  5. "path/filepath"
  6. "reflect"
  7. "testing"
  8. "github.com/ollama/ollama/cache"
  9. "github.com/ollama/ollama/convert"
  10. "github.com/ollama/ollama/ml"
  11. "github.com/ollama/ollama/model"
  12. _ "github.com/ollama/ollama/model/qwen2"
  13. )
  14. func TestForward(t *testing.T) {
  15. cases := []string{
  16. "qwen2",
  17. // Add more model architectures here...
  18. }
  19. for _, tt := range cases {
  20. t.Run(tt, func(t *testing.T) {
  21. t.Parallel()
  22. p := filepath.Join("testdata", tt)
  23. if testing.Short() {
  24. t.Skip("skipping in short mode")
  25. } else if _, err := os.Stat(p); err != nil {
  26. t.Skipf("%s not found", p)
  27. }
  28. f, err := os.CreateTemp(t.TempDir(), "f16")
  29. if err != nil {
  30. t.Fatal(err)
  31. }
  32. defer func() {
  33. f.Close()
  34. os.Remove(f.Name())
  35. }()
  36. if err := convert.ConvertModel(os.DirFS(p), f); err != nil {
  37. t.Fatal(err)
  38. }
  39. m, err := model.New(f.Name())
  40. if err != nil {
  41. t.Fatal(err)
  42. }
  43. b := m.Backend()
  44. ctx := b.NewContext()
  45. ctx.SetDebug(true)
  46. // Run forward pass
  47. _, err = model.Forward(ctx, m, model.WithCache(cache.NewCausalCache(m.Backend(), 2048, ml.DTypeF32)))
  48. if err != nil {
  49. t.Fatal(err)
  50. }
  51. // Validate the graph layers
  52. data, err := os.ReadFile(filepath.Join("testdata", tt+".json"))
  53. if err != nil {
  54. t.Fatal(err)
  55. }
  56. var expected ml.Graph
  57. if err := json.Unmarshal(data, &expected); err != nil {
  58. t.Fatal(err)
  59. }
  60. result := ctx.GetTrace()
  61. if len(result.Graph) != len(expected.Graph) {
  62. t.Errorf("expected %d layers, got %d", len(expected.Graph), len(result.Graph))
  63. }
  64. for i, layer := range expected.Graph {
  65. if i >= len(result.Graph) {
  66. break
  67. }
  68. actual := result.Graph[i]
  69. if layer.Name != actual.Name {
  70. t.Errorf("layer %d: expected name %s, got %s", i, layer.Name, actual.Name)
  71. }
  72. if !reflect.DeepEqual(layer.Shape, actual.Shape) {
  73. t.Errorf("layer %d: expected shape %v, got %v", i, layer.Shape, actual.Shape)
  74. }
  75. }
  76. })
  77. }
  78. }