package llm import ( "crypto/sha256" "fmt" "io" "math" "os" "path/filepath" "strings" "testing" "github.com/google/go-cmp/cmp" ) func TestGGUFRewrite(t *testing.T) { tests := []string{ "phi3.gguf", } for i := range tests { tt := tests[i] t.Run(tt, func(t *testing.T) { t.Parallel() p := filepath.Join("testdata", tt) if _, err := os.Stat(p); err != nil { t.Fatalf("%s not found", p) } f, err := os.Open(p) if err != nil { t.Fatal(err) } defer f.Close() ggml, _, err := decodeGGML(t, f) if err != nil { t.Fatal(err) } temp, err := os.CreateTemp("testdata", "2"+tt) if err != nil { t.Fatal(err) } defer temp.Close() n, ggml2, err := rewriteGGML(t, ggml, temp, f) /* if n != m { t.Fatalf("n: %d, m: %d", n, m) } */ if err != nil { t.Fatal(err) } if diff, diff2, ok := compareGGML(n, ggml2, ggml, temp, f); !ok { if cmp.Diff(diff, diff2) != "" { t.Fatalf("\n%s,\n%s\n%s\n%s\ndiff: %s", diff["token_embd.weight"], diff2["token_embd.weight"], diff["token_embd.weight size"], diff["token_embd.weight offset"], cmp.Diff(diff, diff2)) } } }) } } func formatDiff(diff map[string]string) string { var builder strings.Builder for k, v := range diff { builder.WriteString(fmt.Sprintf("%s: %s\n", k, v)) } return builder.String() } func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) { diff := make(map[string]string) diff2 := make(map[string]string) kv1 := ggml1.KV() kv2 := ggml2.KV() if len(kv1) != len(kv2) { diff["lenKV"] = fmt.Sprintf("kv1: %d, kv2: %d", len(kv1), len(kv2)) fmt.Println("lenKV", diff["lenKV"]) } for k, v := range kv1 { // if v2, ok := kv2[k]; !ok { // diff[k] = fmt.Sprintf("missing key %s", k) // } else if v != v2 { // diff[fmt.Sprintf("%s type diff", k)] = fmt.Sprintf("kv1 type: %T, kv2 type: %T", v.(*array).values, v2.(*array).values) // diff[k] = fmt.Sprintf("kv1: %d, kv2: %d", len(v.(*array).values), len(v2.(*array).values)) // diff[fmt.Sprintf("%s values first 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[0:10], v2.(*array).values[0:10]) // diff[fmt.Sprintf("%s values last 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[len(v.(*array).values)-10:], v2.(*array).values[len(v2.(*array).values)-10:]) // diff[fmt.Sprintf("%s diff", k)] = cmp.Diff(v.(*array).values, v2.(*array).values) switch t := v.(type) { case *array: if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" { diff[k] = diffy } default: if v != kv2[k] { diff[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, kv2[k]) } } // } } t1 := ggml1.Tensors() t2 := ggml2.Tensors() if len(t1) != len(t2) { diff["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(t1), len(t2)) } for _, tensor := range t1 { sha256sum := sha256.New() sr := io.NewSectionReader(f, n+int64(tensor.Offset), int64(tensor.Size())) var s int64 s, err := io.Copy(sha256sum, sr) if err != nil { fmt.Println(err) } diff[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil)) diff[tensor.Name+" size"] = fmt.Sprintf("%d", s) diff[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset) } /* sha256Sum2 := sha256.New() sr1 := io.NewSectionReader(f2, 0, n) s1, err := io.Copy(sha256Sum2, sr1) if err != nil { return nil, nil, true } sha256Sum3 := sha256.New() sr2 := io.NewSectionReader(f, 0, n) s2, err := io.Copy(sha256Sum3, sr2) if err != nil { return nil, nil, true } diff["sha"] = fmt.Sprintf("%d", s1) diff2["sha"] = fmt.Sprintf("%d", s2) */ for _, tensor := range t2 { sha256sum := sha256.New() var s int64 sr := io.NewSectionReader(f2, n+int64(tensor.Offset), int64(tensor.Size())) s, err := io.Copy(sha256sum, sr) if err != nil { fmt.Println(err) } diff2[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil)) diff2[tensor.Name+" size"] = fmt.Sprintf("%d", s) diff2[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset) } return diff, diff2, len(diff) == 0 } func decodeGGML(t *testing.T, f *os.File) (*GGML, int64, error) { ggml, n, err := DecodeGGML(f, math.MaxInt) if err != nil { t.Fatal(err) } return ggml, n, nil } func rewriteGGML(t *testing.T, ggml *GGML, temp *os.File, f *os.File) (int64, *GGML, error) { var tensors Tensors fmt.Println("11111111111111111111111111111111111111111") for _, tensor := range ggml.Tensors() { shape := make([]uint64, len(tensor.Shape)) for i := range len(tensor.Shape) { shape[i] = tensor.Shape[len(tensor.Shape)-i-1] } fmt.Println("tensors", tensor.Name, shape, tensor.Kind, 737414+int64(tensor.Offset)) tensors = append(tensors, &Tensor{ Name: tensor.Name, Kind: tensor.Kind, Shape: shape, WriterTo: TensorWriter{ Reader: io.NewSectionReader(f, 737414+int64(tensor.Offset), int64(tensor.Size())), }, }) } reader := &GGUFWriter{ KV: ggml.KV(), // Update .Tensors Tensors: tensors, } n, err := io.Copy(temp, reader) if err != nil { t.Fatal(err) } fmt.Println(n, " is my offset") file, err := os.Open(temp.Name()) if err != nil { t.Fatal(err) } ggml2, _, err := DecodeGGML(file, math.MaxInt) if err != nil { t.Fatal(err) } return n, ggml2, nil }