|
@@ -12,10 +12,14 @@ import (
|
|
|
"github.com/google/go-cmp/cmp"
|
|
|
)
|
|
|
|
|
|
+// TestGGUFRewrite tests the decoding and rewriting of (unsorted) GGUF files
|
|
|
+// To run, add GGUF files to /llm/testdata and add the name of the file to the tests slice
|
|
|
+// The sort.Sort(tensors) call in gguf.go should be commented out before running
|
|
|
+// This creates a temporary file in /llm/testdata that will be deleted only if the test passes
|
|
|
func TestGGUFRewrite(t *testing.T) {
|
|
|
- // to test this GGUF Rewrite, add gguf files to /llm/testdata
|
|
|
- // add the name of the file to the tests slice
|
|
|
- tests := []string{}
|
|
|
+ tests := []string{
|
|
|
+ "nutiny.gguf",
|
|
|
+ }
|
|
|
|
|
|
for i := range tests {
|
|
|
tt := tests[i]
|
|
@@ -24,169 +28,152 @@ func TestGGUFRewrite(t *testing.T) {
|
|
|
p := filepath.Join("testdata", tt)
|
|
|
|
|
|
if _, err := os.Stat(p); err != nil {
|
|
|
- t.Fatalf("%s not found", p)
|
|
|
+ t.Skip("file not found", p)
|
|
|
}
|
|
|
|
|
|
- f, err := os.Open(p)
|
|
|
+ wantFile, err := os.Open(p)
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
- defer f.Close()
|
|
|
+ defer wantFile.Close()
|
|
|
|
|
|
// decode original gguf
|
|
|
- ggml, _, err := decodeGGML(t, f)
|
|
|
+ _, wantGGML, err := decodeGGML(t, wantFile)
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
|
|
|
- temp, err := os.CreateTemp("testdata", "2"+tt)
|
|
|
+ gotFile, err := os.CreateTemp("testdata", tt)
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
- defer temp.Close()
|
|
|
+ defer func() {
|
|
|
+ gotFile.Close()
|
|
|
+ if !t.Failed() {
|
|
|
+ os.Remove(gotFile.Name())
|
|
|
+ }
|
|
|
+ }()
|
|
|
|
|
|
- _, ggml2, err := rewriteGGML(t, ggml, temp, f)
|
|
|
+ _, gotGGML, err := rewriteGGML(t, wantGGML, gotFile, wantFile)
|
|
|
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
|
|
|
- if diff, diff2, ok := compareGGML(ggml2, ggml, temp, f); !ok {
|
|
|
- if cmp.Diff(diff, diff2) != "" {
|
|
|
- t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2))
|
|
|
- }
|
|
|
+ diff, diff2 := compareGGML(t, gotGGML, wantGGML, gotFile, wantFile)
|
|
|
+ if cmp.Diff(diff, diff2) != "" {
|
|
|
+ t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2))
|
|
|
}
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func compareGGML(ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) {
|
|
|
- diff := make(map[string]string)
|
|
|
- diff2 := make(map[string]string)
|
|
|
+func compareGGML(t *testing.T, gotGGML, wantGGML *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string) {
|
|
|
+ got := make(map[string]string)
|
|
|
+ want := make(map[string]string)
|
|
|
|
|
|
- kv1 := ggml1.KV()
|
|
|
- kv2 := ggml2.KV()
|
|
|
+ gotKV := gotGGML.KV()
|
|
|
+ wantKV := wantGGML.KV()
|
|
|
|
|
|
- if len(kv1) != len(kv2) {
|
|
|
- diff["lenKV"] = fmt.Sprintf("kv1: %d, kv2: %d", len(kv1), len(kv2))
|
|
|
- fmt.Println("lenKV", diff["lenKV"])
|
|
|
+ if len(gotKV) != len(wantKV) {
|
|
|
+ t.Fatalf("got length: %d != want length: %d", len(gotKV), len(wantKV))
|
|
|
}
|
|
|
|
|
|
- for k, v := range kv1 {
|
|
|
+ for k, v := range gotKV {
|
|
|
switch t := v.(type) {
|
|
|
case *array:
|
|
|
- if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" {
|
|
|
- diff[k] = diffy
|
|
|
+ if diffy := cmp.Diff(t.values, wantKV[k].(*array).values); diffy != "" {
|
|
|
+ got[k] = diffy
|
|
|
}
|
|
|
default:
|
|
|
- if v != kv2[k] {
|
|
|
- diff[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, kv2[k])
|
|
|
+ if v != wantKV[k] {
|
|
|
+ got[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, want[k])
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- t1 := ggml1.Tensors()
|
|
|
- t2 := ggml2.Tensors()
|
|
|
+ gotTensors := gotGGML.Tensors().Items
|
|
|
+ gotOffset := gotGGML.Tensors().Offset
|
|
|
+ wantTensors := wantGGML.Tensors().Items
|
|
|
+ wantOffset := wantGGML.Tensors().Offset
|
|
|
|
|
|
- if len(t1.Items) != len(t2.Items) {
|
|
|
- diff["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(t1.Items), len(t2.Items))
|
|
|
+ if len(gotTensors) != len(wantTensors) {
|
|
|
+ got["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(gotTensors), len(wantTensors))
|
|
|
}
|
|
|
|
|
|
- for _, tensor := range t1.Items {
|
|
|
+ for _, tensor := range gotTensors {
|
|
|
sha256sum := sha256.New()
|
|
|
- sr := io.NewSectionReader(f, t1.Offset+int64(tensor.Offset), int64(tensor.Size()))
|
|
|
+ sr := io.NewSectionReader(f, gotOffset+int64(tensor.Offset), int64(tensor.Size()))
|
|
|
var s int64
|
|
|
s, err := io.Copy(sha256sum, sr)
|
|
|
if err != nil {
|
|
|
- fmt.Println(err)
|
|
|
+ t.Fatalf("error: %v", err)
|
|
|
}
|
|
|
|
|
|
- diff[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
|
|
|
- diff[tensor.Name+" size"] = fmt.Sprintf("%d", s)
|
|
|
- diff[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
|
|
|
- }
|
|
|
-
|
|
|
- /* sha256Sum2 := sha256.New()
|
|
|
- sr1 := io.NewSectionReader(f2, 0, n)
|
|
|
- s1, err := io.Copy(sha256Sum2, sr1)
|
|
|
- if err != nil {
|
|
|
- return nil, nil, true
|
|
|
- }
|
|
|
-
|
|
|
- sha256Sum3 := sha256.New()
|
|
|
- sr2 := io.NewSectionReader(f, 0, n)
|
|
|
- s2, err := io.Copy(sha256Sum3, sr2)
|
|
|
- if err != nil {
|
|
|
- return nil, nil, true
|
|
|
+ got[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
|
|
|
+ got[tensor.Name+" size"] = fmt.Sprintf("%d", s)
|
|
|
+ got[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
|
|
|
}
|
|
|
|
|
|
- diff["sha"] = fmt.Sprintf("%d", s1)
|
|
|
- diff2["sha"] = fmt.Sprintf("%d", s2) */
|
|
|
-
|
|
|
- for _, tensor := range t2.Items {
|
|
|
+ for _, tensor := range wantTensors {
|
|
|
sha256sum := sha256.New()
|
|
|
var s int64
|
|
|
- sr := io.NewSectionReader(f2, t1.Offset+int64(tensor.Offset), int64(tensor.Size()))
|
|
|
+ sr := io.NewSectionReader(f2, wantOffset +int64(tensor.Offset), int64(tensor.Size()))
|
|
|
s, err := io.Copy(sha256sum, sr)
|
|
|
if err != nil {
|
|
|
- fmt.Println(err)
|
|
|
+ t.Fatalf("error: %v", err)
|
|
|
}
|
|
|
|
|
|
- diff2[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
|
|
|
- diff2[tensor.Name+" size"] = fmt.Sprintf("%d", s)
|
|
|
- diff2[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
|
|
|
+ want[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
|
|
|
+ want[tensor.Name+" size"] = fmt.Sprintf("%d", s)
|
|
|
+ want[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
|
|
|
}
|
|
|
- return diff, diff2, len(diff) == 0
|
|
|
-
|
|
|
+ return got, want
|
|
|
}
|
|
|
-func decodeGGML(t *testing.T, f *os.File) (*GGML, int64, error) {
|
|
|
+
|
|
|
+func decodeGGML(t *testing.T, f *os.File) (int64, *GGML, error) {
|
|
|
|
|
|
ggml, n, err := DecodeGGML(f, math.MaxInt)
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
- return ggml, n, nil
|
|
|
+ return n, ggml, nil
|
|
|
}
|
|
|
|
|
|
-func rewriteGGML(t *testing.T, ggml *GGML, temp *os.File, f *os.File) (int64, *GGML, error) {
|
|
|
+func rewriteGGML(t *testing.T, ggml *GGML, gotFile *os.File, wantFile *os.File) (int64, *GGML, error) {
|
|
|
var tensors []*Tensor
|
|
|
|
|
|
- fmt.Println("11111111111111111111111111111111111111111")
|
|
|
for _, tensor := range ggml.Tensors().Items {
|
|
|
shape := make([]uint64, len(tensor.Shape))
|
|
|
for i := range len(tensor.Shape) {
|
|
|
shape[i] = tensor.Shape[len(tensor.Shape)-i-1]
|
|
|
}
|
|
|
|
|
|
- fmt.Println("tensors", tensor.Name, shape, tensor.Kind, tensor.Offset)
|
|
|
- fmt.Println(ggml.Tensors().Offset)
|
|
|
tensors = append(tensors, &Tensor{
|
|
|
Name: tensor.Name,
|
|
|
Kind: tensor.Kind,
|
|
|
Shape: shape,
|
|
|
|
|
|
WriterTo: TensorWriter{
|
|
|
- Reader: io.NewSectionReader(f, ggml.Tensors().Offset+int64(tensor.Offset), int64(tensor.Size())),
|
|
|
+ Reader: io.NewSectionReader(wantFile, ggml.Tensors().Offset+int64(tensor.Offset), int64(tensor.Size())),
|
|
|
},
|
|
|
})
|
|
|
}
|
|
|
|
|
|
reader := &GGUFWriter{
|
|
|
KV: ggml.KV(),
|
|
|
- // Update .Tensors
|
|
|
Tensors: Tensors{
|
|
|
Items: tensors,
|
|
|
Offset: ggml.Tensors().Offset,
|
|
|
},
|
|
|
}
|
|
|
|
|
|
- n, err := io.Copy(temp, reader)
|
|
|
+ n, err := io.Copy(gotFile, reader)
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
|
|
|
- fmt.Println(n, " is my offset")
|
|
|
- file, err := os.Open(temp.Name())
|
|
|
+ file, err := os.Open(gotFile.Name())
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|