model_test.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. package model
  2. import (
  3. "reflect"
  4. "slices"
  5. "strings"
  6. "testing"
  7. "github.com/google/go-cmp/cmp"
  8. fs "github.com/ollama/ollama/fs/ggml"
  9. "github.com/ollama/ollama/ml"
  10. "github.com/ollama/ollama/ml/backend/ggml"
  11. "github.com/ollama/ollama/ml/nn"
  12. )
  13. func TestParseTags(t *testing.T) {
  14. cases := []struct {
  15. value string
  16. want Tag
  17. }{
  18. {
  19. value: "output",
  20. want: Tag{
  21. Name: "output",
  22. },
  23. },
  24. {
  25. value: "output,alt:token_embd",
  26. want: Tag{
  27. Name: "output",
  28. Alternate: []string{
  29. "token_embd",
  30. },
  31. },
  32. },
  33. }
  34. for _, tt := range cases {
  35. t.Run(tt.value, func(t *testing.T) {
  36. got := ParseTags(tt.value)
  37. if diff := cmp.Diff(tt.want, got); diff != "" {
  38. t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
  39. }
  40. })
  41. }
  42. }
  43. type fakeBackend struct {
  44. *ggml.Backend
  45. names []string
  46. }
  47. type fakeTensor struct {
  48. *ggml.Tensor
  49. Name string
  50. }
  51. func (m *fakeBackend) Get(name string) ml.Tensor {
  52. if slices.Contains(m.names, name) {
  53. return &fakeTensor{Name: name}
  54. }
  55. return nil
  56. }
  57. func TestPopulateFields(t *testing.T) {
  58. type fakeLayer struct {
  59. Query *nn.Linear `gguf:"attn_q"`
  60. Key *nn.Linear `gguf:"attn_k"`
  61. Value *nn.Linear `gguf:"attn_v"`
  62. Output *nn.Linear `gguf:"attn_o"`
  63. }
  64. type fakeModel struct {
  65. Input *nn.Embedding `gguf:"input"`
  66. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  67. Output *nn.Linear `gguf:"output"`
  68. Layers [2]fakeLayer `gguf:"blk"`
  69. }
  70. var m fakeModel
  71. v := reflect.ValueOf(&m)
  72. v.Elem().Set(populateFields(Base{b: &fakeBackend{
  73. names: []string{
  74. "input.weight",
  75. "blk.0.attn_q.weight",
  76. "blk.0.attn_k.weight",
  77. "blk.0.attn_v.weight",
  78. "blk.1.attn_q.weight",
  79. "blk.1.attn_k.weight",
  80. "blk.1.attn_v.weight",
  81. "output_norm.weight",
  82. "output.weight",
  83. },
  84. }}, v.Elem()))
  85. if diff := cmp.Diff(fakeModel{
  86. Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
  87. OutputNorm: &nn.RMSNorm{Weight: &fakeTensor{Name: "output_norm.weight"}},
  88. Output: &nn.Linear{Weight: &fakeTensor{Name: "output.weight"}},
  89. Layers: [2]fakeLayer{
  90. {
  91. Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_q.weight"}},
  92. Key: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_k.weight"}},
  93. Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_v.weight"}},
  94. },
  95. {
  96. Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_q.weight"}},
  97. Key: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_k.weight"}},
  98. Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_v.weight"}},
  99. },
  100. },
  101. }, m); diff != "" {
  102. t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
  103. }
  104. }
  105. func TestPopulateFieldsAlternateName(t *testing.T) {
  106. type fakeModel struct {
  107. Input *nn.Embedding `gguf:"input"`
  108. Output *nn.Linear `gguf:"output,alt:input"`
  109. }
  110. m := fakeModel{}
  111. v := reflect.ValueOf(&m)
  112. v.Elem().Set(populateFields(Base{b: &fakeBackend{
  113. names: []string{
  114. "input.weight",
  115. },
  116. }}, v.Elem()))
  117. if diff := cmp.Diff(fakeModel{
  118. Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
  119. Output: &nn.Linear{Weight: &fakeTensor{Name: "input.weight"}},
  120. }, m); diff != "" {
  121. t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
  122. }
  123. }
  124. func TestGetTextProcessor(t *testing.T) {
  125. tp, err := getTextProcessor(fs.KV{})
  126. if err == nil {
  127. t.Error("expected error")
  128. } else if !strings.Contains(err.Error(), "unsupported model architecture") {
  129. t.Errorf("unexpected error: %v", err)
  130. } else if tp != nil {
  131. t.Error("expected nil tp")
  132. }
  133. models["dummy"] = func(ml.Config) (Model, error) {
  134. return notTextProcessorModel{}, nil
  135. }
  136. tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
  137. if err == nil {
  138. t.Error("expected error")
  139. } else if !strings.Contains(err.Error(), "not a TextProcessor") {
  140. t.Errorf("unexpected error: %v", err)
  141. } else if tp != nil {
  142. t.Error("expected nil tp")
  143. }
  144. }
  145. type notTextProcessorModel struct{}
  146. func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) {
  147. panic("unimplemented")
  148. }
  149. func (notTextProcessorModel) Backend() ml.Backend {
  150. panic("unimplemented")
  151. }
  152. func (notTextProcessorModel) Config() config {
  153. panic("unimplemented")
  154. }