package ggml

import (
	"maps"
	"slices"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)
  9. func TestTensorLayers(t *testing.T) {
  10. tensors := make(map[string]*Tensor)
  11. for _, name := range []string{
  12. "token_embd.weight",
  13. "blk.0.attn_k.weight",
  14. "blk.0.attn_output.weight",
  15. "blk.0.attn_q.weight",
  16. "blk.0.attn_v.weight",
  17. "blk.0.attn_norm.weight",
  18. "blk.0.ffn_down.weight",
  19. "blk.0.ffn_gate.weight",
  20. "blk.0.ffn_up.weight",
  21. "blk.0.ffn_norm.weight",
  22. "output_norm.weight",
  23. "mm.0.bias",
  24. "mm.0.weight",
  25. "v.blk.0.attn_k.weight",
  26. "v.blk.0.attn_output.weight",
  27. "v.blk.0.attn_q.weight",
  28. "v.blk.0.attn_v.weight",
  29. "v.blk.0.attn_norm.weight",
  30. "v.blk.0.ffn_down.weight",
  31. "v.blk.0.ffn_gate.weight",
  32. "v.blk.0.ffn_up.weight",
  33. "v.blk.0.ffn_norm.weight",
  34. "v.patch_embd.weight",
  35. "v.position_embd.gate",
  36. "v.position_embd.weight",
  37. } {
  38. tensors[name] = &Tensor{Name: name}
  39. }
  40. cases := []struct {
  41. name string
  42. items []*Tensor
  43. want map[string]Layer
  44. }{
  45. {
  46. name: "text",
  47. items: slices.Collect(func(yield func(*Tensor) bool) {
  48. for k, v := range tensors {
  49. if !strings.HasPrefix(k, "mm.") && !strings.HasPrefix(k, "v.") {
  50. if !yield(v) {
  51. return
  52. }
  53. }
  54. }
  55. }),
  56. want: map[string]Layer{
  57. "blk.0": {
  58. "attn_k.weight": tensors["blk.0.attn_k.weight"],
  59. "attn_q.weight": tensors["blk.0.attn_q.weight"],
  60. "attn_v.weight": tensors["blk.0.attn_v.weight"],
  61. "attn_output.weight": tensors["blk.0.attn_output.weight"],
  62. "attn_norm.weight": tensors["blk.0.attn_norm.weight"],
  63. "ffn_down.weight": tensors["blk.0.ffn_down.weight"],
  64. "ffn_gate.weight": tensors["blk.0.ffn_gate.weight"],
  65. "ffn_up.weight": tensors["blk.0.ffn_up.weight"],
  66. "ffn_norm.weight": tensors["blk.0.ffn_norm.weight"],
  67. },
  68. "token_embd": {"weight": tensors["token_embd.weight"]},
  69. "output_norm": {"weight": tensors["output_norm.weight"]},
  70. },
  71. },
  72. {
  73. name: "vision",
  74. items: slices.Collect(func(yield func(*Tensor) bool) {
  75. for k, v := range tensors {
  76. if strings.HasPrefix(k, "mm.") || strings.HasPrefix(k, "v.") {
  77. if !yield(v) {
  78. return
  79. }
  80. }
  81. }
  82. }),
  83. want: map[string]Layer{
  84. "mm.0": {
  85. "bias": tensors["mm.0.bias"],
  86. "weight": tensors["mm.0.weight"],
  87. },
  88. "v.blk.0": {
  89. "attn_k.weight": tensors["v.blk.0.attn_k.weight"],
  90. "attn_q.weight": tensors["v.blk.0.attn_q.weight"],
  91. "attn_v.weight": tensors["v.blk.0.attn_v.weight"],
  92. "attn_output.weight": tensors["v.blk.0.attn_output.weight"],
  93. "attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
  94. "ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
  95. "ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
  96. "ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
  97. "ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
  98. },
  99. "v": {
  100. "patch_embd.weight": tensors["v.patch_embd.weight"],
  101. "position_embd.gate": tensors["v.position_embd.gate"],
  102. "position_embd.weight": tensors["v.position_embd.weight"],
  103. },
  104. },
  105. },
  106. {
  107. name: "vision and text",
  108. items: slices.Collect(maps.Values(tensors)),
  109. want: map[string]Layer{
  110. "blk.0": {
  111. "attn_k.weight": tensors["blk.0.attn_k.weight"],
  112. "attn_q.weight": tensors["blk.0.attn_q.weight"],
  113. "attn_v.weight": tensors["blk.0.attn_v.weight"],
  114. "attn_output.weight": tensors["blk.0.attn_output.weight"],
  115. "attn_norm.weight": tensors["blk.0.attn_norm.weight"],
  116. "ffn_down.weight": tensors["blk.0.ffn_down.weight"],
  117. "ffn_gate.weight": tensors["blk.0.ffn_gate.weight"],
  118. "ffn_up.weight": tensors["blk.0.ffn_up.weight"],
  119. "ffn_norm.weight": tensors["blk.0.ffn_norm.weight"],
  120. },
  121. "token_embd": {"weight": tensors["token_embd.weight"]},
  122. "output_norm": {"weight": tensors["output_norm.weight"]},
  123. "mm.0": {
  124. "bias": tensors["mm.0.bias"],
  125. "weight": tensors["mm.0.weight"],
  126. },
  127. "v.blk.0": {
  128. "attn_k.weight": tensors["v.blk.0.attn_k.weight"],
  129. "attn_q.weight": tensors["v.blk.0.attn_q.weight"],
  130. "attn_v.weight": tensors["v.blk.0.attn_v.weight"],
  131. "attn_output.weight": tensors["v.blk.0.attn_output.weight"],
  132. "attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
  133. "ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
  134. "ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
  135. "ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
  136. "ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
  137. },
  138. "v": {
  139. "patch_embd.weight": tensors["v.patch_embd.weight"],
  140. "position_embd.gate": tensors["v.position_embd.gate"],
  141. "position_embd.weight": tensors["v.position_embd.weight"],
  142. },
  143. },
  144. },
  145. }
  146. for _, tt := range cases {
  147. t.Run(tt.name, func(t *testing.T) {
  148. got := Tensors{items: tt.items}.GroupLayers()
  149. if diff := cmp.Diff(got, tt.want); diff != "" {
  150. t.Errorf("unexpected layers (-got +want):\n%s", diff)
  151. }
  152. })
  153. }
  154. }