Josh Yan 9 months ago
parent
commit
6ee22d5080
2 changed files with 61 additions and 74 deletions
  1. 1 1
      llm/gguf.go
  2. 60 73
      llm/gguf_test.go

+ 1 - 1
llm/gguf.go

@@ -911,7 +911,6 @@ func ggufWriteKV(ws io.Writer, k string, v any) error {
 		}
 
 	default:
-		fmt.Println("type is", v)
 		return fmt.Errorf("improper type for '%s'", k)
 	}
 
@@ -919,5 +918,6 @@ func ggufWriteKV(ws io.Writer, k string, v any) error {
 }
 
 func ggufPadding(offset, align int64) int64 {
+	// the outer mod maps the result to 0 (instead of align) when offset%align == 0
 	return (align - offset%align) % align
 }

+ 60 - 73
llm/gguf_test.go

@@ -12,10 +12,14 @@ import (
 	"github.com/google/go-cmp/cmp"
 )
 
+// TestGGUFRewrite tests the decoding and rewriting of (unsorted) GGUF files.
+// To run, add GGUF files to /llm/testdata and add the name of the file to the tests slice
+// Before running, comment out the sort.Sort(tensors) call in gguf.go (i.e. //sort.Sort(tensors)).
+// This creates a temporary file in /llm/testdata that will be deleted only if the test passes.
 func TestGGUFRewrite(t *testing.T) {
-	// to test this GGUF Rewrite, add gguf files to /llm/testdata
-	// add the name of the file to the tests slice
-	tests := []string{}
+	tests := []string{
+		"nutiny.gguf",
+	}
 
 	for i := range tests {
 		tt := tests[i]
@@ -24,169 +28,152 @@ func TestGGUFRewrite(t *testing.T) {
 			p := filepath.Join("testdata", tt)
 
 			if _, err := os.Stat(p); err != nil {
-				t.Fatalf("%s not found", p)
+				t.Skip("file not found", p)
 			}
 
-			f, err := os.Open(p)
+			wantFile, err := os.Open(p)
 			if err != nil {
 				t.Fatal(err)
 			}
-			defer f.Close()
+			defer wantFile.Close()
 
 			// decode original gguf
-			ggml, _, err := decodeGGML(t, f)
+			_, wantGGML, err := decodeGGML(t, wantFile)
 			if err != nil {
 				t.Fatal(err)
 			}
 
-			temp, err := os.CreateTemp("testdata", "2"+tt)
+			gotFile, err := os.CreateTemp("testdata", tt)
 			if err != nil {
 				t.Fatal(err)
 			}
-			defer temp.Close()
+			defer func() {
+				gotFile.Close()
+				if !t.Failed() {
+					os.Remove(gotFile.Name())
+				}
+			}()
 
-			_, ggml2, err := rewriteGGML(t, ggml, temp, f)
+			_, gotGGML, err := rewriteGGML(t, wantGGML, gotFile, wantFile)
 
 			if err != nil {
 				t.Fatal(err)
 			}
 
-			if diff, diff2, ok := compareGGML(ggml2, ggml, temp, f); !ok {
-				if cmp.Diff(diff, diff2) != "" {
-					t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2))
-				}
+			diff, diff2 := compareGGML(t, gotGGML, wantGGML, gotFile, wantFile) 
+			if cmp.Diff(diff, diff2) != "" {
+				t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2))
 			}
 		})
 	}
 }
 
-func compareGGML(ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) {
-	diff := make(map[string]string)
-	diff2 := make(map[string]string)
+func compareGGML(t *testing.T, gotGGML, wantGGML *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string) {
+	got := make(map[string]string)
+	want := make(map[string]string)
 
-	kv1 := ggml1.KV()
-	kv2 := ggml2.KV()
+	gotKV := gotGGML.KV()
+	wantKV := wantGGML.KV()
 
-	if len(kv1) != len(kv2) {
-		diff["lenKV"] = fmt.Sprintf("kv1: %d, kv2: %d", len(kv1), len(kv2))
-		fmt.Println("lenKV", diff["lenKV"])
+	if len(gotKV) != len(wantKV) {
+		t.Fatalf("got length: %d != want length: %d", len(gotKV), len(wantKV))
 	}
 
-	for k, v := range kv1 {
+	for k, v := range gotKV {
 		switch t := v.(type) {
 		case *array:
-			if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" {
-				diff[k] = diffy
+			if diffy := cmp.Diff(t.values, wantKV[k].(*array).values); diffy != "" {
+				got[k] = diffy
 			}
 		default:
-			if v != kv2[k] {
-				diff[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, kv2[k])
+			if v != wantKV[k] {
+				got[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, want[k])
 			}
 		}
 	}
 
-	t1 := ggml1.Tensors()
-	t2 := ggml2.Tensors()
+	gotTensors := gotGGML.Tensors().Items
+	gotOffset := gotGGML.Tensors().Offset
+	wantTensors := wantGGML.Tensors().Items
+	wantOffset := wantGGML.Tensors().Offset
 
-	if len(t1.Items) != len(t2.Items) {
-		diff["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(t1.Items), len(t2.Items))
+	if len(gotTensors) != len(wantTensors) {
+		got["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(gotTensors), len(wantTensors))
 	}
 
-	for _, tensor := range t1.Items {
+	for _, tensor := range gotTensors {
 		sha256sum := sha256.New()
-		sr := io.NewSectionReader(f, t1.Offset+int64(tensor.Offset), int64(tensor.Size()))
+		sr := io.NewSectionReader(f, gotOffset+int64(tensor.Offset), int64(tensor.Size()))
 		var s int64
 		s, err := io.Copy(sha256sum, sr)
 		if err != nil {
-			fmt.Println(err)
+			t.Fatalf("error: %v", err)
 		}
 
-		diff[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
-		diff[tensor.Name+" size"] = fmt.Sprintf("%d", s)
-		diff[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
-	}
-
-	/* sha256Sum2 := sha256.New()
-	sr1 := io.NewSectionReader(f2, 0, n)
-	s1, err := io.Copy(sha256Sum2, sr1)
-	if err != nil {
-		return nil, nil, true
-	}
-
-	sha256Sum3 := sha256.New()
-	sr2 := io.NewSectionReader(f, 0, n)
-	s2, err := io.Copy(sha256Sum3, sr2)
-	if err != nil {
-		return nil, nil, true
+		got[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
+		got[tensor.Name+" size"] = fmt.Sprintf("%d", s)
+		got[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
 	}
 
-	diff["sha"] = fmt.Sprintf("%d", s1)
-	diff2["sha"] = fmt.Sprintf("%d", s2) */
-
-	for _, tensor := range t2.Items {
+	for _, tensor := range wantTensors {
 		sha256sum := sha256.New()
 		var s int64
-		sr := io.NewSectionReader(f2, t1.Offset+int64(tensor.Offset), int64(tensor.Size()))
+		sr := io.NewSectionReader(f2, wantOffset +int64(tensor.Offset), int64(tensor.Size()))
 		s, err := io.Copy(sha256sum, sr)
 		if err != nil {
-			fmt.Println(err)
+			t.Fatalf("error: %v", err)
 		}
 
-		diff2[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
-		diff2[tensor.Name+" size"] = fmt.Sprintf("%d", s)
-		diff2[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
+		want[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
+		want[tensor.Name+" size"] = fmt.Sprintf("%d", s)
+		want[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
 	}
-	return diff, diff2, len(diff) == 0
-
+	return got, want
 }
-func decodeGGML(t *testing.T, f *os.File) (*GGML, int64, error) {
+
+func decodeGGML(t *testing.T, f *os.File) (int64, *GGML, error) {
 
 	ggml, n, err := DecodeGGML(f, math.MaxInt)
 	if err != nil {
 		t.Fatal(err)
 	}
-	return ggml, n, nil
+	return n, ggml, nil
 }
 
-func rewriteGGML(t *testing.T, ggml *GGML, temp *os.File, f *os.File) (int64, *GGML, error) {
+func rewriteGGML(t *testing.T, ggml *GGML, gotFile *os.File, wantFile *os.File) (int64, *GGML, error) {
 	var tensors []*Tensor
 
-	fmt.Println("11111111111111111111111111111111111111111")
 	for _, tensor := range ggml.Tensors().Items {
 		shape := make([]uint64, len(tensor.Shape))
 		for i := range len(tensor.Shape) {
 			shape[i] = tensor.Shape[len(tensor.Shape)-i-1]
 		}
 
-		fmt.Println("tensors", tensor.Name, shape, tensor.Kind, tensor.Offset)
-		fmt.Println(ggml.Tensors().Offset)
 		tensors = append(tensors, &Tensor{
 			Name:  tensor.Name,
 			Kind:  tensor.Kind,
 			Shape: shape,
 
 			WriterTo: TensorWriter{
-				Reader: io.NewSectionReader(f, ggml.Tensors().Offset+int64(tensor.Offset), int64(tensor.Size())),
+				Reader: io.NewSectionReader(wantFile, ggml.Tensors().Offset+int64(tensor.Offset), int64(tensor.Size())),
 			},
 		})
 	}
 
 	reader := &GGUFWriter{
 		KV: ggml.KV(),
-		// Update .Tensors
 		Tensors: Tensors{
 			Items:  tensors,
 			Offset: ggml.Tensors().Offset,
 		},
 	}
 
-	n, err := io.Copy(temp, reader)
+	n, err := io.Copy(gotFile, reader)
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	fmt.Println(n, " is my offset")
-	file, err := os.Open(temp.Name())
+	file, err := os.Open(gotFile.Name())
 	if err != nil {
 		t.Fatal(err)
 	}