model_test.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. package model
  2. import (
  3. "reflect"
  4. "slices"
  5. "testing"
  6. "github.com/google/go-cmp/cmp"
  7. "github.com/ollama/ollama/ml"
  8. "github.com/ollama/ollama/ml/backend/ggml"
  9. "github.com/ollama/ollama/ml/nn"
  10. )
  11. func TestParseTags(t *testing.T) {
  12. cases := []struct {
  13. value string
  14. want Tag
  15. }{
  16. {
  17. value: "output",
  18. want: Tag{
  19. Name: "output",
  20. },
  21. },
  22. {
  23. value: "output,alt:token_embd",
  24. want: Tag{
  25. Name: "output",
  26. Alternate: []string{
  27. "token_embd",
  28. },
  29. },
  30. },
  31. }
  32. for _, tt := range cases {
  33. t.Run(tt.value, func(t *testing.T) {
  34. got := ParseTags(tt.value)
  35. if diff := cmp.Diff(tt.want, got); diff != "" {
  36. t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
  37. }
  38. })
  39. }
  40. }
  41. type fakeBackend struct {
  42. *ggml.Backend
  43. names []string
  44. }
  45. type fakeTensor struct {
  46. *ggml.Tensor
  47. Name string
  48. }
  49. func (m *fakeBackend) Get(name string) ml.Tensor {
  50. if slices.Contains(m.names, name) {
  51. return &fakeTensor{Name: name}
  52. }
  53. return nil
  54. }
  55. func TestPopulateFields(t *testing.T) {
  56. type fakeLayer struct {
  57. Query *nn.Linear `gguf:"attn_q"`
  58. Key *nn.Linear `gguf:"attn_k"`
  59. Value *nn.Linear `gguf:"attn_v"`
  60. Output *nn.Linear `gguf:"attn_o"`
  61. }
  62. type fakeModel struct {
  63. Input *nn.Embedding `gguf:"input"`
  64. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  65. Output *nn.Linear `gguf:"output"`
  66. Layers [2]fakeLayer `gguf:"blk"`
  67. }
  68. var m fakeModel
  69. v := reflect.ValueOf(&m)
  70. v.Elem().Set(populateFields(Base{b: &fakeBackend{
  71. names: []string{
  72. "input.weight",
  73. "blk.0.attn_q.weight",
  74. "blk.0.attn_k.weight",
  75. "blk.0.attn_v.weight",
  76. "blk.1.attn_q.weight",
  77. "blk.1.attn_k.weight",
  78. "blk.1.attn_v.weight",
  79. "output_norm.weight",
  80. "output.weight",
  81. },
  82. }}, v.Elem()))
  83. if diff := cmp.Diff(fakeModel{
  84. Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
  85. OutputNorm: &nn.RMSNorm{Weight: &fakeTensor{Name: "output_norm.weight"}},
  86. Output: &nn.Linear{Weight: &fakeTensor{Name: "output.weight"}},
  87. Layers: [2]fakeLayer{
  88. {
  89. Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_q.weight"}},
  90. Key: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_k.weight"}},
  91. Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_v.weight"}},
  92. },
  93. {
  94. Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_q.weight"}},
  95. Key: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_k.weight"}},
  96. Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_v.weight"}},
  97. },
  98. },
  99. }, m); diff != "" {
  100. t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
  101. }
  102. }
  103. func TestPopulateFieldsAlternateName(t *testing.T) {
  104. type fakeModel struct {
  105. Input *nn.Embedding `gguf:"input"`
  106. Output *nn.Linear `gguf:"output,alt:input"`
  107. }
  108. m := fakeModel{}
  109. v := reflect.ValueOf(&m)
  110. v.Elem().Set(populateFields(Base{b: &fakeBackend{
  111. names: []string{
  112. "input.weight",
  113. },
  114. }}, v.Elem()))
  115. if diff := cmp.Diff(fakeModel{
  116. Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
  117. Output: &nn.Linear{Weight: &fakeTensor{Name: "input.weight"}},
  118. }, m); diff != "" {
  119. t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
  120. }
  121. }