gguf_test.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. package llm
  2. import (
  3. "crypto/sha256"
  4. "fmt"
  5. "io"
  6. "math"
  7. "os"
  8. "path/filepath"
  9. "strings"
  10. "testing"
  11. "github.com/google/go-cmp/cmp"
  12. )
  13. func TestGGUFRewrite(t *testing.T) {
  14. tests := []string{
  15. "phi3.gguf",
  16. }
  17. for i := range tests {
  18. tt := tests[i]
  19. t.Run(tt, func(t *testing.T) {
  20. t.Parallel()
  21. p := filepath.Join("testdata", tt)
  22. if _, err := os.Stat(p); err != nil {
  23. t.Fatalf("%s not found", p)
  24. }
  25. f, err := os.Open(p)
  26. if err != nil {
  27. t.Fatal(err)
  28. }
  29. defer f.Close()
  30. ggml, _, err := decodeGGML(t, f)
  31. if err != nil {
  32. t.Fatal(err)
  33. }
  34. temp, err := os.CreateTemp("testdata", "2"+tt)
  35. if err != nil {
  36. t.Fatal(err)
  37. }
  38. defer temp.Close()
  39. n, ggml2, err := rewriteGGML(t, ggml, temp, f)
  40. /* if n != m {
  41. t.Fatalf("n: %d, m: %d", n, m)
  42. } */
  43. if err != nil {
  44. t.Fatal(err)
  45. }
  46. if diff, diff2, ok := compareGGML(n, ggml2, ggml, temp, f); !ok {
  47. if cmp.Diff(diff, diff2) != "" {
  48. t.Fatalf("\n%s,\n%s\n%s\n%s\ndiff: %s", diff["token_embd.weight"], diff2["token_embd.weight"], diff["token_embd.weight size"], diff["token_embd.weight offset"], cmp.Diff(diff, diff2))
  49. }
  50. }
  51. })
  52. }
  53. }
  54. func formatDiff(diff map[string]string) string {
  55. var builder strings.Builder
  56. for k, v := range diff {
  57. builder.WriteString(fmt.Sprintf("%s: %s\n", k, v))
  58. }
  59. return builder.String()
  60. }
  61. func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) {
  62. diff := make(map[string]string)
  63. diff2 := make(map[string]string)
  64. kv1 := ggml1.KV()
  65. kv2 := ggml2.KV()
  66. if len(kv1) != len(kv2) {
  67. diff["lenKV"] = fmt.Sprintf("kv1: %d, kv2: %d", len(kv1), len(kv2))
  68. fmt.Println("lenKV", diff["lenKV"])
  69. }
  70. for k, v := range kv1 {
  71. // if v2, ok := kv2[k]; !ok {
  72. // diff[k] = fmt.Sprintf("missing key %s", k)
  73. // } else if v != v2 {
  74. // diff[fmt.Sprintf("%s type diff", k)] = fmt.Sprintf("kv1 type: %T, kv2 type: %T", v.(*array).values, v2.(*array).values)
  75. // diff[k] = fmt.Sprintf("kv1: %d, kv2: %d", len(v.(*array).values), len(v2.(*array).values))
  76. // diff[fmt.Sprintf("%s values first 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[0:10], v2.(*array).values[0:10])
  77. // diff[fmt.Sprintf("%s values last 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[len(v.(*array).values)-10:], v2.(*array).values[len(v2.(*array).values)-10:])
  78. // diff[fmt.Sprintf("%s diff", k)] = cmp.Diff(v.(*array).values, v2.(*array).values)
  79. switch t := v.(type) {
  80. case *array:
  81. if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" {
  82. diff[k] = diffy
  83. }
  84. default:
  85. if v != kv2[k] {
  86. diff[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, kv2[k])
  87. }
  88. }
  89. // }
  90. }
  91. t1 := ggml1.Tensors()
  92. t2 := ggml2.Tensors()
  93. if len(t1) != len(t2) {
  94. diff["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(t1), len(t2))
  95. }
  96. for _, tensor := range t1 {
  97. sha256sum := sha256.New()
  98. sr := io.NewSectionReader(f, n+int64(tensor.Offset), int64(tensor.Size()))
  99. var s int64
  100. s, err := io.Copy(sha256sum, sr)
  101. if err != nil {
  102. fmt.Println(err)
  103. }
  104. diff[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
  105. diff[tensor.Name+" size"] = fmt.Sprintf("%d", s)
  106. diff[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
  107. }
  108. /* sha256Sum2 := sha256.New()
  109. sr1 := io.NewSectionReader(f2, 0, n)
  110. s1, err := io.Copy(sha256Sum2, sr1)
  111. if err != nil {
  112. return nil, nil, true
  113. }
  114. sha256Sum3 := sha256.New()
  115. sr2 := io.NewSectionReader(f, 0, n)
  116. s2, err := io.Copy(sha256Sum3, sr2)
  117. if err != nil {
  118. return nil, nil, true
  119. }
  120. diff["sha"] = fmt.Sprintf("%d", s1)
  121. diff2["sha"] = fmt.Sprintf("%d", s2) */
  122. for _, tensor := range t2 {
  123. sha256sum := sha256.New()
  124. var s int64
  125. sr := io.NewSectionReader(f2, n+int64(tensor.Offset), int64(tensor.Size()))
  126. s, err := io.Copy(sha256sum, sr)
  127. if err != nil {
  128. fmt.Println(err)
  129. }
  130. diff2[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
  131. diff2[tensor.Name+" size"] = fmt.Sprintf("%d", s)
  132. diff2[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
  133. }
  134. return diff, diff2, len(diff) == 0
  135. }
  136. func decodeGGML(t *testing.T, f *os.File) (*GGML, int64, error) {
  137. ggml, n, err := DecodeGGML(f, math.MaxInt)
  138. if err != nil {
  139. t.Fatal(err)
  140. }
  141. return ggml, n, nil
  142. }
  143. func rewriteGGML(t *testing.T, ggml *GGML, temp *os.File, f *os.File) (int64, *GGML, error) {
  144. var tensors Tensors
  145. fmt.Println("11111111111111111111111111111111111111111")
  146. for _, tensor := range ggml.Tensors() {
  147. shape := make([]uint64, len(tensor.Shape))
  148. for i := range len(tensor.Shape) {
  149. shape[i] = tensor.Shape[len(tensor.Shape)-i-1]
  150. }
  151. fmt.Println("tensors", tensor.Name, shape, tensor.Kind, 737414+int64(tensor.Offset))
  152. tensors = append(tensors, &Tensor{
  153. Name: tensor.Name,
  154. Kind: tensor.Kind,
  155. Shape: shape,
  156. WriterTo: TensorWriter{
  157. Reader: io.NewSectionReader(f, 737414+int64(tensor.Offset), int64(tensor.Size())),
  158. },
  159. })
  160. }
  161. reader := &GGUFWriter{
  162. KV: ggml.KV(),
  163. // Update .Tensors
  164. Tensors: tensors,
  165. }
  166. n, err := io.Copy(temp, reader)
  167. if err != nil {
  168. t.Fatal(err)
  169. }
  170. fmt.Println(n, " is my offset")
  171. file, err := os.Open(temp.Name())
  172. if err != nil {
  173. t.Fatal(err)
  174. }
  175. ggml2, _, err := DecodeGGML(file, math.MaxInt)
  176. if err != nil {
  177. t.Fatal(err)
  178. }
  179. return n, ggml2, nil
  180. }