Josh Yan 9 months ago
parent
commit
703ecccc6b
1 changed files with 8 additions and 33 deletions
  1. 8 33
      llm/gguf_test.go

+ 8 - 33
llm/gguf_test.go

@@ -7,17 +7,15 @@ import (
 	"math"
 	"os"
 	"path/filepath"
-	"strings"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
 )
 
 func TestGGUFRewrite(t *testing.T) {
-	tests := []string{
-		"phi3.gguf",
-		"nutiny.gguf",
-	}
+	// to test this GGUF Rewrite, add gguf files to /llm/testdata
+	// add the name of the file to the tests slice
+	tests := []string{}
 
 	for i := range tests {
 		tt := tests[i]
@@ -35,6 +33,7 @@ func TestGGUFRewrite(t *testing.T) {
 			}
 			defer f.Close()
 
+			// decode original gguf
 			ggml, _, err := decodeGGML(t, f)
 			if err != nil {
 				t.Fatal(err)
@@ -46,35 +45,22 @@ func TestGGUFRewrite(t *testing.T) {
 			}
 			defer temp.Close()
 
-			n, ggml2, err := rewriteGGML(t, ggml, temp, f)
-
-			/* if n != m {
-				t.Fatalf("n: %d, m: %d", n, m)
-			} */
+			_, ggml2, err := rewriteGGML(t, ggml, temp, f)
 
 			if err != nil {
 				t.Fatal(err)
 			}
-			//t.Fatal("FULL SIZE JFAKFJJEFJAJFLAEJJAFAJKLFJ", n)
 
-			if diff, diff2, ok := compareGGML(n, ggml2, ggml, temp, f); !ok {
+			if diff, diff2, ok := compareGGML(ggml2, ggml, temp, f); !ok {
 				if cmp.Diff(diff, diff2) != "" {
-					t.Fatalf("\n%s,\n%s\n%s\n%s\ndiff: %s", diff["token_embd.weight"], diff2["token_embd.weight"], diff["token_embd.weight size"], diff["token_embd.weight offset"], cmp.Diff(diff, diff2))
+					t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2))
 				}
 			}
 		})
 	}
 }
 
-func formatDiff(diff map[string]string) string {
-	var builder strings.Builder
-	for k, v := range diff {
-		builder.WriteString(fmt.Sprintf("%s: %s\n", k, v))
-	}
-	return builder.String()
-}
-
-func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) {
+func compareGGML(ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) {
 	diff := make(map[string]string)
 	diff2 := make(map[string]string)
 
@@ -87,15 +73,6 @@ func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[stri
 	}
 
 	for k, v := range kv1 {
-		// if v2, ok := kv2[k]; !ok {
-		// diff[k] = fmt.Sprintf("missing key %s", k)
-		// } else if v != v2 {
-		// diff[fmt.Sprintf("%s type diff", k)] = fmt.Sprintf("kv1 type: %T, kv2 type: %T", v.(*array).values, v2.(*array).values)
-		// diff[k] = fmt.Sprintf("kv1: %d, kv2: %d", len(v.(*array).values), len(v2.(*array).values))
-		// diff[fmt.Sprintf("%s values first 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[0:10], v2.(*array).values[0:10])
-		// diff[fmt.Sprintf("%s values last 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[len(v.(*array).values)-10:], v2.(*array).values[len(v2.(*array).values)-10:])
-		// diff[fmt.Sprintf("%s diff", k)] = cmp.Diff(v.(*array).values, v2.(*array).values)
-
 		switch t := v.(type) {
 		case *array:
 			if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" {
@@ -106,8 +83,6 @@ func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[stri
 				diff[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, kv2[k])
 			}
 		}
-
-		// }
 	}
 
 	t1 := ggml1.Tensors()