ggml_test.go

package ggml

import (
	"maps"
	"slices"
	"strconv"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

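// TestTensorLayers verifies that Tensors.GroupLayers groups a flat list of
// tensors into layers keyed by name prefix: "blk.0" and "v.blk.0" for the
// text and vision transformer blocks, and "token_embd", "output_norm",
// "mm.0", and "v" for the remaining components.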
func TestTensorLayers(t *testing.T) {
	tensors := make(map[string]*Tensor)
	for _, name := range []string{
		"token_embd.weight",
		"blk.0.attn_k.weight",
		"blk.0.attn_output.weight",
		"blk.0.attn_q.weight",
		"blk.0.attn_v.weight",
		"blk.0.attn_norm.weight",
		"blk.0.ffn_down.weight",
		"blk.0.ffn_gate.weight",
		"blk.0.ffn_up.weight",
		"blk.0.ffn_norm.weight",
		"output_norm.weight",
		"mm.0.bias",
		"mm.0.weight",
		"v.blk.0.attn_k.weight",
		"v.blk.0.attn_output.weight",
		"v.blk.0.attn_q.weight",
		"v.blk.0.attn_v.weight",
		"v.blk.0.attn_norm.weight",
		"v.blk.0.ffn_down.weight",
		"v.blk.0.ffn_gate.weight",
		"v.blk.0.ffn_up.weight",
		"v.blk.0.ffn_norm.weight",
		"v.patch_embd.weight",
		"v.position_embd.gate",
		"v.position_embd.weight",
	} {
		tensors[name] = &Tensor{Name: name}
	}
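
	// Each case feeds GroupLayers a subset of the tensors above: text-only
	// tensors, vision-only tensors (the "mm." and "v." prefixes), and both
	// together.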
	cases := []struct {
		name  string
		items []*Tensor
		want  map[string]Layer
	}{
		{
			name: "text",
			items: slices.Collect(func(yield func(*Tensor) bool) {
				for k, v := range tensors {
					if !strings.HasPrefix(k, "mm.") && !strings.HasPrefix(k, "v.") {
						if !yield(v) {
							return
						}
					}
				}
			}),
			want: map[string]Layer{
				"blk.0": {
					"attn_k.weight":      tensors["blk.0.attn_k.weight"],
					"attn_q.weight":      tensors["blk.0.attn_q.weight"],
					"attn_v.weight":      tensors["blk.0.attn_v.weight"],
					"attn_output.weight": tensors["blk.0.attn_output.weight"],
					"attn_norm.weight":   tensors["blk.0.attn_norm.weight"],
					"ffn_down.weight":    tensors["blk.0.ffn_down.weight"],
					"ffn_gate.weight":    tensors["blk.0.ffn_gate.weight"],
					"ffn_up.weight":      tensors["blk.0.ffn_up.weight"],
					"ffn_norm.weight":    tensors["blk.0.ffn_norm.weight"],
				},
				"token_embd":  {"weight": tensors["token_embd.weight"]},
				"output_norm": {"weight": tensors["output_norm.weight"]},
			},
		},
		{
			name: "vision",
			items: slices.Collect(func(yield func(*Tensor) bool) {
				for k, v := range tensors {
					if strings.HasPrefix(k, "mm.") || strings.HasPrefix(k, "v.") {
						if !yield(v) {
							return
						}
					}
				}
			}),
			want: map[string]Layer{
				"mm.0": {
					"bias":   tensors["mm.0.bias"],
					"weight": tensors["mm.0.weight"],
				},
				"v.blk.0": {
					"attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
					"attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
					"attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
					"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
					"attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
					"ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
					"ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
					"ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
					"ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
				},
				"v": {
					"patch_embd.weight":    tensors["v.patch_embd.weight"],
					"position_embd.gate":   tensors["v.position_embd.gate"],
					"position_embd.weight": tensors["v.position_embd.weight"],
				},
			},
		},
		{
			name:  "vision and text",
			items: slices.Collect(maps.Values(tensors)),
			want: map[string]Layer{
				"blk.0": {
					"attn_k.weight":      tensors["blk.0.attn_k.weight"],
					"attn_q.weight":      tensors["blk.0.attn_q.weight"],
					"attn_v.weight":      tensors["blk.0.attn_v.weight"],
					"attn_output.weight": tensors["blk.0.attn_output.weight"],
					"attn_norm.weight":   tensors["blk.0.attn_norm.weight"],
					"ffn_down.weight":    tensors["blk.0.ffn_down.weight"],
					"ffn_gate.weight":    tensors["blk.0.ffn_gate.weight"],
					"ffn_up.weight":      tensors["blk.0.ffn_up.weight"],
					"ffn_norm.weight":    tensors["blk.0.ffn_norm.weight"],
				},
				"token_embd":  {"weight": tensors["token_embd.weight"]},
				"output_norm": {"weight": tensors["output_norm.weight"]},
				"mm.0": {
					"bias":   tensors["mm.0.bias"],
					"weight": tensors["mm.0.weight"],
				},
				"v.blk.0": {
					"attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
					"attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
					"attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
					"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
					"attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
					"ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
					"ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
					"ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
					"ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
				},
				"v": {
					"patch_embd.weight":    tensors["v.patch_embd.weight"],
					"position_embd.gate":   tensors["v.position_embd.gate"],
					"position_embd.weight": tensors["v.position_embd.weight"],
				},
			},
		},
	}
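
	// Run each case as a subtest and diff the grouped result against the
	// expected layers with go-cmp.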
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			got := Tensors{items: tt.items}.GroupLayers()
			if diff := cmp.Diff(got, tt.want); diff != "" {
				t.Errorf("unexpected layers (-got +want):\n%s", diff)
			}
		})
	}
}

// ref: https://github.com/ggml-org/llama.cpp/blob/a82c9e7c23ef6db48cebfa194dc9cebbc4ac3552/ggml/src/ggml.c#L572
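//
// TestTensorTypes checks each tensor kind's block size and type size against
// the type traits table in ggml.c (see the ref above). For example, kind 2
// (Q4_0) packs a block of 32 weights as a 2-byte f16 scale followed by 16
// bytes of packed 4-bit quants, for a type size of 18 bytes.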
func TestTensorTypes(t *testing.T) {
	cases := []struct {
		kind      uint32
		blockSize uint64
		typeSize  uint64
	}{
		{0, 1, 4},      // F32
		{1, 1, 2},      // F16
		{2, 32, 18},    // Q4_0
		{3, 32, 20},    // Q4_1
		{6, 32, 22},    // Q5_0
		{7, 32, 24},    // Q5_1
		{8, 32, 34},    // Q8_0
		{9, 32, 36},    // Q8_1
		{10, 256, 84},  // Q2_K
		{11, 256, 110}, // Q3_K
		{12, 256, 144}, // Q4_K
		{13, 256, 176}, // Q5_K
		{14, 256, 210}, // Q6_K
		{15, 256, 292}, // Q8_K
		{16, 256, 66},  // IQ2_XXS
		{17, 256, 74},  // IQ2_XS
		{18, 256, 98},  // IQ3_XXS
		{19, 256, 50},  // IQ1_S
		{20, 32, 18},   // IQ4_NL
		{21, 256, 110}, // IQ3_S
		{22, 256, 82},  // IQ2_S
		{23, 256, 136}, // IQ4_XS
		{24, 1, 1},     // I8
		{25, 1, 2},     // I16
		{26, 1, 4},     // I32
		{27, 1, 8},     // I64
		{28, 1, 8},     // F64
		{29, 256, 56},  // IQ1_M
		{30, 1, 2},     // BF16
	}
	for _, tt := range cases {
		t.Run(strconv.Itoa(int(tt.kind)), func(t *testing.T) {
			tensor := Tensor{Kind: tt.kind}
			if tensor.blockSize() != tt.blockSize {
				t.Errorf("unexpected block size: got=%d want=%d", tensor.blockSize(), tt.blockSize)
			}
			if tensor.typeSize() != tt.typeSize {
				t.Errorf("unexpected type size: got=%d want=%d", tensor.typeSize(), tt.typeSize)
			}
		})
	}
}