gguf_test.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package llm
  2. import (
  3. "crypto/sha256"
  4. "fmt"
  5. "io"
  6. "math"
  7. "os"
  8. "path/filepath"
  9. "testing"
  10. "github.com/google/go-cmp/cmp"
  11. )
  12. // TestGGUFDecode tests the decoding and rewriting of (unsorted) GGUF files
  13. // To run, add GGUF files to /llm/testdata and add the name of the file to the tests slice
  14. // This creates a temporary file in /llm/testdata that will deleted only if the test passes
  15. // Note: map[Tensor.Name + " offset"] is commented since sorting will reorder the tensors
  16. // Comment out sort.Sort(gguf.Tensors) in gguf.go to test offsets
  17. func TestGGUFRewrite(t *testing.T) {
  18. tests := []string{
  19. "phi3.gguf",
  20. }
  21. for i := range tests {
  22. tt := tests[i]
  23. t.Run(tt, func(t *testing.T) {
  24. t.Parallel()
  25. p := filepath.Join("testdata", tt)
  26. if _, err := os.Stat(p); err != nil {
  27. t.Skip("file not found", p)
  28. }
  29. wantFile, err := os.Open(p)
  30. if err != nil {
  31. t.Fatal(err)
  32. }
  33. defer wantFile.Close()
  34. // decode original gguf
  35. _, wantGGML, err := decodeGGML(t, wantFile)
  36. if err != nil {
  37. t.Fatal(err)
  38. }
  39. gotFile, err := os.CreateTemp("testdata", tt)
  40. if err != nil {
  41. t.Fatal(err)
  42. }
  43. defer func() {
  44. gotFile.Close()
  45. if !t.Failed() {
  46. os.Remove(gotFile.Name())
  47. }
  48. }()
  49. _, gotGGML, err := rewriteGGML(t, wantGGML, gotFile, wantFile)
  50. if err != nil {
  51. t.Fatal(err)
  52. }
  53. diff, diff2 := compareGGML(t, gotGGML, wantGGML, gotFile, wantFile)
  54. if cmp.Diff(diff, diff2) != "" {
  55. t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2))
  56. }
  57. })
  58. }
  59. }
  60. func compareGGML(t *testing.T, gotGGML, wantGGML *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string) {
  61. got := make(map[string]string)
  62. want := make(map[string]string)
  63. gotKV := gotGGML.KV()
  64. wantKV := wantGGML.KV()
  65. if len(gotKV) != len(wantKV) {
  66. t.Fatalf("got length: %d != want length: %d", len(gotKV), len(wantKV))
  67. }
  68. for k, v := range gotKV {
  69. switch t := v.(type) {
  70. case *array:
  71. if diffy := cmp.Diff(t.values, wantKV[k].(*array).values); diffy != "" {
  72. got[k] = diffy
  73. }
  74. default:
  75. if v != wantKV[k] {
  76. got[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, want[k])
  77. }
  78. }
  79. }
  80. gotTensors := gotGGML.Tensors().Items
  81. gotOffset := gotGGML.Tensors().Offset
  82. wantTensors := wantGGML.Tensors().Items
  83. wantOffset := wantGGML.Tensors().Offset
  84. if len(gotTensors) != len(wantTensors) {
  85. got["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(gotTensors), len(wantTensors))
  86. }
  87. for _, tensor := range gotTensors {
  88. sha256sum := sha256.New()
  89. sr := io.NewSectionReader(f, gotOffset+int64(tensor.Offset), int64(tensor.Size()))
  90. var s int64
  91. s, err := io.Copy(sha256sum, sr)
  92. if err != nil {
  93. t.Fatalf("error: %v", err)
  94. }
  95. got[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
  96. got[tensor.Name+" size"] = fmt.Sprintf("%d", s)
  97. // got[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
  98. }
  99. for _, tensor := range wantTensors {
  100. sha256sum := sha256.New()
  101. var s int64
  102. sr := io.NewSectionReader(f2, wantOffset +int64(tensor.Offset), int64(tensor.Size()))
  103. s, err := io.Copy(sha256sum, sr)
  104. if err != nil {
  105. t.Fatalf("error: %v", err)
  106. }
  107. want[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
  108. want[tensor.Name+" size"] = fmt.Sprintf("%d", s)
  109. // want[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
  110. }
  111. return got, want
  112. }
  113. func decodeGGML(t *testing.T, f *os.File) (int64, *GGML, error) {
  114. ggml, n, err := DecodeGGML(f, math.MaxInt)
  115. if err != nil {
  116. t.Fatal(err)
  117. }
  118. return n, ggml, nil
  119. }
  120. func rewriteGGML(t *testing.T, ggml *GGML, gotFile *os.File, wantFile *os.File) (int64, *GGML, error) {
  121. var tensors []*Tensor
  122. for _, tensor := range ggml.Tensors().Items {
  123. shape := make([]uint64, len(tensor.Shape))
  124. for i := range len(tensor.Shape) {
  125. shape[i] = tensor.Shape[len(tensor.Shape)-i-1]
  126. }
  127. tensors = append(tensors, &Tensor{
  128. Name: tensor.Name,
  129. Kind: tensor.Kind,
  130. Shape: shape,
  131. WriterTo: TensorWriter{
  132. Reader: io.NewSectionReader(wantFile, ggml.Tensors().Offset+int64(tensor.Offset), int64(tensor.Size())),
  133. },
  134. })
  135. }
  136. reader := &GGUFWriter{
  137. KV: ggml.KV(),
  138. Tensors: Tensors{
  139. Items: tensors,
  140. Offset: ggml.Tensors().Offset,
  141. },
  142. }
  143. n, err := io.Copy(gotFile, reader)
  144. if err != nil {
  145. t.Fatal(err)
  146. }
  147. file, err := os.Open(gotFile.Name())
  148. if err != nil {
  149. t.Fatal(err)
  150. }
  151. ggml2, _, err := DecodeGGML(file, math.MaxInt)
  152. if err != nil {
  153. t.Fatal(err)
  154. }
  155. return n, ggml2, nil
  156. }