model_test.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. package model
  2. import (
  3. "reflect"
  4. "slices"
  5. "strings"
  6. "testing"
  7. "github.com/google/go-cmp/cmp"
  8. fs "github.com/ollama/ollama/fs/ggml"
  9. "github.com/ollama/ollama/ml"
  10. "github.com/ollama/ollama/ml/backend/ggml"
  11. "github.com/ollama/ollama/ml/nn"
  12. "github.com/ollama/ollama/model/input"
  13. )
  14. func TestParseTags(t *testing.T) {
  15. cases := []struct {
  16. value string
  17. want Tag
  18. }{
  19. {
  20. value: "output",
  21. want: Tag{
  22. Name: "output",
  23. },
  24. },
  25. {
  26. value: "output,alt:token_embd",
  27. want: Tag{
  28. Name: "output",
  29. Alternate: []string{
  30. "token_embd",
  31. },
  32. },
  33. },
  34. }
  35. for _, tt := range cases {
  36. t.Run(tt.value, func(t *testing.T) {
  37. got := ParseTags(tt.value)
  38. if diff := cmp.Diff(tt.want, got); diff != "" {
  39. t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
  40. }
  41. })
  42. }
  43. }
  44. type fakeBackend struct {
  45. *ggml.Backend
  46. names []string
  47. }
  48. type fakeTensor struct {
  49. *ggml.Tensor
  50. Name string
  51. }
  52. func (m *fakeBackend) Get(name string) ml.Tensor {
  53. if slices.Contains(m.names, name) {
  54. return &fakeTensor{Name: name}
  55. }
  56. return nil
  57. }
  58. func TestPopulateFields(t *testing.T) {
  59. type fakeLayer struct {
  60. Query *nn.Linear `gguf:"attn_q"`
  61. Key *nn.Linear `gguf:"attn_k"`
  62. Value *nn.Linear `gguf:"attn_v"`
  63. Output *nn.Linear `gguf:"attn_o"`
  64. }
  65. type fakeModel struct {
  66. Input *nn.Embedding `gguf:"input"`
  67. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  68. Output *nn.Linear `gguf:"output"`
  69. Layers [2]fakeLayer `gguf:"blk"`
  70. }
  71. var m fakeModel
  72. v := reflect.ValueOf(&m)
  73. v.Elem().Set(populateFields(Base{b: &fakeBackend{
  74. names: []string{
  75. "input.weight",
  76. "blk.0.attn_q.weight",
  77. "blk.0.attn_k.weight",
  78. "blk.0.attn_v.weight",
  79. "blk.1.attn_q.weight",
  80. "blk.1.attn_k.weight",
  81. "blk.1.attn_v.weight",
  82. "output_norm.weight",
  83. "output.weight",
  84. },
  85. }}, v.Elem()))
  86. if diff := cmp.Diff(fakeModel{
  87. Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
  88. OutputNorm: &nn.RMSNorm{Weight: &fakeTensor{Name: "output_norm.weight"}},
  89. Output: &nn.Linear{Weight: &fakeTensor{Name: "output.weight"}},
  90. Layers: [2]fakeLayer{
  91. {
  92. Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_q.weight"}},
  93. Key: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_k.weight"}},
  94. Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_v.weight"}},
  95. },
  96. {
  97. Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_q.weight"}},
  98. Key: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_k.weight"}},
  99. Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_v.weight"}},
  100. },
  101. },
  102. }, m); diff != "" {
  103. t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
  104. }
  105. }
  106. func TestPopulateFieldsAlternateName(t *testing.T) {
  107. type fakeModel struct {
  108. Input *nn.Embedding `gguf:"input"`
  109. Output *nn.Linear `gguf:"output,alt:input"`
  110. }
  111. m := fakeModel{}
  112. v := reflect.ValueOf(&m)
  113. v.Elem().Set(populateFields(Base{b: &fakeBackend{
  114. names: []string{
  115. "input.weight",
  116. },
  117. }}, v.Elem()))
  118. if diff := cmp.Diff(fakeModel{
  119. Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
  120. Output: &nn.Linear{Weight: &fakeTensor{Name: "input.weight"}},
  121. }, m); diff != "" {
  122. t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
  123. }
  124. }
  125. func TestGetTextProcessor(t *testing.T) {
  126. tp, err := getTextProcessor(fs.KV{})
  127. if err == nil {
  128. t.Error("expected error")
  129. } else if !strings.Contains(err.Error(), "unsupported model architecture") {
  130. t.Errorf("unexpected error: %v", err)
  131. } else if tp != nil {
  132. t.Error("expected nil tp")
  133. }
  134. models["dummy"] = func(ml.Config) (Model, error) {
  135. return notTextProcessorModel{}, nil
  136. }
  137. tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
  138. if err == nil {
  139. t.Error("expected error")
  140. } else if !strings.Contains(err.Error(), "not a TextProcessor") {
  141. t.Errorf("unexpected error: %v", err)
  142. } else if tp != nil {
  143. t.Error("expected nil tp")
  144. }
  145. }
  146. type notTextProcessorModel struct{}
  147. func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
  148. panic("unimplemented")
  149. }
  150. func (notTextProcessorModel) Backend() ml.Backend {
  151. panic("unimplemented")
  152. }
  153. func (notTextProcessorModel) Config() config {
  154. panic("unimplemented")
  155. }