gguf_test.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. package llm
  2. import (
  3. "crypto/sha256"
  4. "fmt"
  5. "io"
  6. "math"
  7. "os"
  8. "path/filepath"
  9. "testing"
  10. "github.com/google/go-cmp/cmp"
  11. )
  12. func TestGGUFRewrite(t *testing.T) {
  13. // to test this GGUF Rewrite, add gguf files to /llm/testdata
  14. // add the name of the file to the tests slice
  15. tests := []string{}
  16. for i := range tests {
  17. tt := tests[i]
  18. t.Run(tt, func(t *testing.T) {
  19. t.Parallel()
  20. p := filepath.Join("testdata", tt)
  21. if _, err := os.Stat(p); err != nil {
  22. t.Fatalf("%s not found", p)
  23. }
  24. f, err := os.Open(p)
  25. if err != nil {
  26. t.Fatal(err)
  27. }
  28. defer f.Close()
  29. // decode original gguf
  30. ggml, _, err := decodeGGML(t, f)
  31. if err != nil {
  32. t.Fatal(err)
  33. }
  34. temp, err := os.CreateTemp("testdata", "2"+tt)
  35. if err != nil {
  36. t.Fatal(err)
  37. }
  38. defer temp.Close()
  39. _, ggml2, err := rewriteGGML(t, ggml, temp, f)
  40. if err != nil {
  41. t.Fatal(err)
  42. }
  43. if diff, diff2, ok := compareGGML(ggml2, ggml, temp, f); !ok {
  44. if cmp.Diff(diff, diff2) != "" {
  45. t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2))
  46. }
  47. }
  48. })
  49. }
  50. }
  51. func compareGGML(ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) {
  52. diff := make(map[string]string)
  53. diff2 := make(map[string]string)
  54. kv1 := ggml1.KV()
  55. kv2 := ggml2.KV()
  56. if len(kv1) != len(kv2) {
  57. diff["lenKV"] = fmt.Sprintf("kv1: %d, kv2: %d", len(kv1), len(kv2))
  58. fmt.Println("lenKV", diff["lenKV"])
  59. }
  60. for k, v := range kv1 {
  61. switch t := v.(type) {
  62. case *array:
  63. if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" {
  64. diff[k] = diffy
  65. }
  66. default:
  67. if v != kv2[k] {
  68. diff[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, kv2[k])
  69. }
  70. }
  71. }
  72. t1 := ggml1.Tensors()
  73. t2 := ggml2.Tensors()
  74. if len(t1.Items) != len(t2.Items) {
  75. diff["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(t1.Items), len(t2.Items))
  76. }
  77. for _, tensor := range t1.Items {
  78. sha256sum := sha256.New()
  79. sr := io.NewSectionReader(f, t1.Offset+int64(tensor.Offset), int64(tensor.Size()))
  80. var s int64
  81. s, err := io.Copy(sha256sum, sr)
  82. if err != nil {
  83. fmt.Println(err)
  84. }
  85. diff[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
  86. diff[tensor.Name+" size"] = fmt.Sprintf("%d", s)
  87. diff[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
  88. }
  89. /* sha256Sum2 := sha256.New()
  90. sr1 := io.NewSectionReader(f2, 0, n)
  91. s1, err := io.Copy(sha256Sum2, sr1)
  92. if err != nil {
  93. return nil, nil, true
  94. }
  95. sha256Sum3 := sha256.New()
  96. sr2 := io.NewSectionReader(f, 0, n)
  97. s2, err := io.Copy(sha256Sum3, sr2)
  98. if err != nil {
  99. return nil, nil, true
  100. }
  101. diff["sha"] = fmt.Sprintf("%d", s1)
  102. diff2["sha"] = fmt.Sprintf("%d", s2) */
  103. for _, tensor := range t2.Items {
  104. sha256sum := sha256.New()
  105. var s int64
  106. sr := io.NewSectionReader(f2, t1.Offset+int64(tensor.Offset), int64(tensor.Size()))
  107. s, err := io.Copy(sha256sum, sr)
  108. if err != nil {
  109. fmt.Println(err)
  110. }
  111. diff2[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
  112. diff2[tensor.Name+" size"] = fmt.Sprintf("%d", s)
  113. diff2[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
  114. }
  115. return diff, diff2, len(diff) == 0
  116. }
  117. func decodeGGML(t *testing.T, f *os.File) (*GGML, int64, error) {
  118. ggml, n, err := DecodeGGML(f, math.MaxInt)
  119. if err != nil {
  120. t.Fatal(err)
  121. }
  122. return ggml, n, nil
  123. }
  124. func rewriteGGML(t *testing.T, ggml *GGML, temp *os.File, f *os.File) (int64, *GGML, error) {
  125. var tensors []*Tensor
  126. fmt.Println("11111111111111111111111111111111111111111")
  127. for _, tensor := range ggml.Tensors().Items {
  128. shape := make([]uint64, len(tensor.Shape))
  129. for i := range len(tensor.Shape) {
  130. shape[i] = tensor.Shape[len(tensor.Shape)-i-1]
  131. }
  132. fmt.Println("tensors", tensor.Name, shape, tensor.Kind, tensor.Offset)
  133. fmt.Println(ggml.Tensors().Offset)
  134. tensors = append(tensors, &Tensor{
  135. Name: tensor.Name,
  136. Kind: tensor.Kind,
  137. Shape: shape,
  138. WriterTo: TensorWriter{
  139. Reader: io.NewSectionReader(f, ggml.Tensors().Offset+int64(tensor.Offset), int64(tensor.Size())),
  140. },
  141. })
  142. }
  143. reader := &GGUFWriter{
  144. KV: ggml.KV(),
  145. // Update .Tensors
  146. Tensors: Tensors{
  147. Items: tensors,
  148. Offset: ggml.Tensors().Offset,
  149. },
  150. }
  151. n, err := io.Copy(temp, reader)
  152. if err != nil {
  153. t.Fatal(err)
  154. }
  155. fmt.Println(n, " is my offset")
  156. file, err := os.Open(temp.Name())
  157. if err != nil {
  158. t.Fatal(err)
  159. }
  160. ggml2, _, err := DecodeGGML(file, math.MaxInt)
  161. if err != nil {
  162. t.Fatal(err)
  163. }
  164. return n, ggml2, nil
  165. }