ggml_test.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. package ggml
  2. import (
  3. "maps"
  4. "slices"
  5. "strings"
  6. "testing"
  7. "github.com/google/go-cmp/cmp"
  8. )
  9. type ggufModel struct {
  10. kv KV
  11. tensors Tensors
  12. }
  13. func (m *ggufModel) KV() KV { return m.kv }
  14. func (m *ggufModel) Tensors() Tensors { return m.tensors }
  15. func TestGraphNoVocab(t *testing.T) {
  16. g := &GGML{
  17. container: &containerGGUF{},
  18. model: &ggufModel{
  19. kv: KV{
  20. "general.architecture": "llama",
  21. "block_count": uint32(1),
  22. },
  23. tensors: Tensors{
  24. items: []*Tensor{},
  25. },
  26. },
  27. }
  28. // This should not panic
  29. _, _, _ = g.GraphSize(1, 1, "f16")
  30. }
  31. func TestTensorLayers(t *testing.T) {
  32. tensors := make(map[string]*Tensor)
  33. for _, name := range []string{
  34. "token_embd.weight",
  35. "blk.0.attn_k.weight",
  36. "blk.0.attn_output.weight",
  37. "blk.0.attn_q.weight",
  38. "blk.0.attn_v.weight",
  39. "blk.0.attn_norm.weight",
  40. "blk.0.ffn_down.weight",
  41. "blk.0.ffn_gate.weight",
  42. "blk.0.ffn_up.weight",
  43. "blk.0.ffn_norm.weight",
  44. "output_norm.weight",
  45. "mm.0.bias",
  46. "mm.0.weight",
  47. "v.blk.0.attn_k.weight",
  48. "v.blk.0.attn_output.weight",
  49. "v.blk.0.attn_q.weight",
  50. "v.blk.0.attn_v.weight",
  51. "v.blk.0.attn_norm.weight",
  52. "v.blk.0.ffn_down.weight",
  53. "v.blk.0.ffn_gate.weight",
  54. "v.blk.0.ffn_up.weight",
  55. "v.blk.0.ffn_norm.weight",
  56. "v.patch_embd.weight",
  57. "v.position_embd.gate",
  58. "v.position_embd.weight",
  59. } {
  60. tensors[name] = &Tensor{Name: name}
  61. }
  62. cases := []struct {
  63. name string
  64. items []*Tensor
  65. want map[string]Layer
  66. }{
  67. {
  68. name: "text",
  69. items: slices.Collect(func(yield func(*Tensor) bool) {
  70. for k, v := range tensors {
  71. if !strings.HasPrefix(k, "mm.") && !strings.HasPrefix(k, "v.") {
  72. if !yield(v) {
  73. return
  74. }
  75. }
  76. }
  77. }),
  78. want: map[string]Layer{
  79. "blk.0": {
  80. "attn_k.weight": tensors["blk.0.attn_k.weight"],
  81. "attn_q.weight": tensors["blk.0.attn_q.weight"],
  82. "attn_v.weight": tensors["blk.0.attn_v.weight"],
  83. "attn_output.weight": tensors["blk.0.attn_output.weight"],
  84. "attn_norm.weight": tensors["blk.0.attn_norm.weight"],
  85. "ffn_down.weight": tensors["blk.0.ffn_down.weight"],
  86. "ffn_gate.weight": tensors["blk.0.ffn_gate.weight"],
  87. "ffn_up.weight": tensors["blk.0.ffn_up.weight"],
  88. "ffn_norm.weight": tensors["blk.0.ffn_norm.weight"],
  89. },
  90. "token_embd": {"weight": tensors["token_embd.weight"]},
  91. "output_norm": {"weight": tensors["output_norm.weight"]},
  92. },
  93. },
  94. {
  95. name: "vision",
  96. items: slices.Collect(func(yield func(*Tensor) bool) {
  97. for k, v := range tensors {
  98. if strings.HasPrefix(k, "mm.") || strings.HasPrefix(k, "v.") {
  99. if !yield(v) {
  100. return
  101. }
  102. }
  103. }
  104. }),
  105. want: map[string]Layer{
  106. "mm.0": {
  107. "bias": tensors["mm.0.bias"],
  108. "weight": tensors["mm.0.weight"],
  109. },
  110. "v.blk.0": {
  111. "attn_k.weight": tensors["v.blk.0.attn_k.weight"],
  112. "attn_q.weight": tensors["v.blk.0.attn_q.weight"],
  113. "attn_v.weight": tensors["v.blk.0.attn_v.weight"],
  114. "attn_output.weight": tensors["v.blk.0.attn_output.weight"],
  115. "attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
  116. "ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
  117. "ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
  118. "ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
  119. "ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
  120. },
  121. "v": {
  122. "patch_embd.weight": tensors["v.patch_embd.weight"],
  123. "position_embd.gate": tensors["v.position_embd.gate"],
  124. "position_embd.weight": tensors["v.position_embd.weight"],
  125. },
  126. },
  127. },
  128. {
  129. name: "vision and text",
  130. items: slices.Collect(maps.Values(tensors)),
  131. want: map[string]Layer{
  132. "blk.0": {
  133. "attn_k.weight": tensors["blk.0.attn_k.weight"],
  134. "attn_q.weight": tensors["blk.0.attn_q.weight"],
  135. "attn_v.weight": tensors["blk.0.attn_v.weight"],
  136. "attn_output.weight": tensors["blk.0.attn_output.weight"],
  137. "attn_norm.weight": tensors["blk.0.attn_norm.weight"],
  138. "ffn_down.weight": tensors["blk.0.ffn_down.weight"],
  139. "ffn_gate.weight": tensors["blk.0.ffn_gate.weight"],
  140. "ffn_up.weight": tensors["blk.0.ffn_up.weight"],
  141. "ffn_norm.weight": tensors["blk.0.ffn_norm.weight"],
  142. },
  143. "token_embd": {"weight": tensors["token_embd.weight"]},
  144. "output_norm": {"weight": tensors["output_norm.weight"]},
  145. "mm.0": {
  146. "bias": tensors["mm.0.bias"],
  147. "weight": tensors["mm.0.weight"],
  148. },
  149. "v.blk.0": {
  150. "attn_k.weight": tensors["v.blk.0.attn_k.weight"],
  151. "attn_q.weight": tensors["v.blk.0.attn_q.weight"],
  152. "attn_v.weight": tensors["v.blk.0.attn_v.weight"],
  153. "attn_output.weight": tensors["v.blk.0.attn_output.weight"],
  154. "attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
  155. "ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
  156. "ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
  157. "ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
  158. "ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
  159. },
  160. "v": {
  161. "patch_embd.weight": tensors["v.patch_embd.weight"],
  162. "position_embd.gate": tensors["v.position_embd.gate"],
  163. "position_embd.weight": tensors["v.position_embd.weight"],
  164. },
  165. },
  166. },
  167. }
  168. for _, tt := range cases {
  169. t.Run(tt.name, func(t *testing.T) {
  170. got := Tensors{items: tt.items}.GroupLayers()
  171. if diff := cmp.Diff(got, tt.want); diff != "" {
  172. t.Errorf("unexpected layers (-got +want):\n%s", diff)
  173. }
  174. })
  175. }
  176. }