123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- package llm
- import (
- "crypto/sha256"
- "fmt"
- "io"
- "math"
- "os"
- "path/filepath"
- "testing"
- "github.com/google/go-cmp/cmp"
- )
- // TestGGUFDecode tests the decoding and rewriting of (unsorted) GGUF files
- // To run, add GGUF files to /llm/testdata and add the name of the file to the tests slice
- // This creates a temporary file in /llm/testdata that will deleted only if the test passes
- // Note: map[Tensor.Name + " offset"] is commented since sorting will reorder the tensors
- // Comment out sort.Sort(gguf.Tensors) in gguf.go to test offsets
- func TestGGUFRewrite(t *testing.T) {
- tests := []string{
- "phi3.gguf",
- }
- for i := range tests {
- tt := tests[i]
- t.Run(tt, func(t *testing.T) {
- t.Parallel()
- p := filepath.Join("testdata", tt)
- if _, err := os.Stat(p); err != nil {
- t.Skip("file not found", p)
- }
- wantFile, err := os.Open(p)
- if err != nil {
- t.Fatal(err)
- }
- defer wantFile.Close()
- // decode original gguf
- _, wantGGML, err := decodeGGML(t, wantFile)
- if err != nil {
- t.Fatal(err)
- }
- gotFile, err := os.CreateTemp("testdata", tt)
- if err != nil {
- t.Fatal(err)
- }
- defer func() {
- gotFile.Close()
- if !t.Failed() {
- os.Remove(gotFile.Name())
- }
- }()
- _, gotGGML, err := rewriteGGML(t, wantGGML, gotFile, wantFile)
- if err != nil {
- t.Fatal(err)
- }
- diff, diff2 := compareGGML(t, gotGGML, wantGGML, gotFile, wantFile)
- if cmp.Diff(diff, diff2) != "" {
- t.Fatalf("diff: \n%s", cmp.Diff(diff, diff2))
- }
- })
- }
- }
- func compareGGML(t *testing.T, gotGGML, wantGGML *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string) {
- got := make(map[string]string)
- want := make(map[string]string)
- gotKV := gotGGML.KV()
- wantKV := wantGGML.KV()
- if len(gotKV) != len(wantKV) {
- t.Fatalf("got length: %d != want length: %d", len(gotKV), len(wantKV))
- }
- for k, v := range gotKV {
- switch t := v.(type) {
- case *array:
- if diffy := cmp.Diff(t.values, wantKV[k].(*array).values); diffy != "" {
- got[k] = diffy
- }
- default:
- if v != wantKV[k] {
- got[k] = fmt.Sprintf("kv1: %v, kv2: %v", v, want[k])
- }
- }
- }
- gotTensors := gotGGML.Tensors().Items
- gotOffset := gotGGML.Tensors().Offset
- wantTensors := wantGGML.Tensors().Items
- wantOffset := wantGGML.Tensors().Offset
- if len(gotTensors) != len(wantTensors) {
- got["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(gotTensors), len(wantTensors))
- }
- for _, tensor := range gotTensors {
- sha256sum := sha256.New()
- sr := io.NewSectionReader(f, gotOffset+int64(tensor.Offset), int64(tensor.Size()))
- var s int64
- s, err := io.Copy(sha256sum, sr)
- if err != nil {
- t.Fatalf("error: %v", err)
- }
- got[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
- got[tensor.Name+" size"] = fmt.Sprintf("%d", s)
- // got[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
- }
- for _, tensor := range wantTensors {
- sha256sum := sha256.New()
- var s int64
- sr := io.NewSectionReader(f2, wantOffset +int64(tensor.Offset), int64(tensor.Size()))
- s, err := io.Copy(sha256sum, sr)
- if err != nil {
- t.Fatalf("error: %v", err)
- }
- want[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
- want[tensor.Name+" size"] = fmt.Sprintf("%d", s)
- // want[tensor.Name+" offset"] = fmt.Sprintf("%v", tensor.Offset)
- }
- return got, want
- }
- func decodeGGML(t *testing.T, f *os.File) (int64, *GGML, error) {
- ggml, n, err := DecodeGGML(f, math.MaxInt)
- if err != nil {
- t.Fatal(err)
- }
- return n, ggml, nil
- }
- func rewriteGGML(t *testing.T, ggml *GGML, gotFile *os.File, wantFile *os.File) (int64, *GGML, error) {
- var tensors []*Tensor
- for _, tensor := range ggml.Tensors().Items {
- shape := make([]uint64, len(tensor.Shape))
- for i := range len(tensor.Shape) {
- shape[i] = tensor.Shape[len(tensor.Shape)-i-1]
- }
- tensors = append(tensors, &Tensor{
- Name: tensor.Name,
- Kind: tensor.Kind,
- Shape: shape,
- WriterTo: TensorWriter{
- Reader: io.NewSectionReader(wantFile, ggml.Tensors().Offset+int64(tensor.Offset), int64(tensor.Size())),
- },
- })
- }
- reader := &GGUFWriter{
- KV: ggml.KV(),
- Tensors: Tensors{
- Items: tensors,
- Offset: ggml.Tensors().Offset,
- },
- }
- n, err := io.Copy(gotFile, reader)
- if err != nil {
- t.Fatal(err)
- }
- file, err := os.Open(gotFile.Name())
- if err != nil {
- t.Fatal(err)
- }
- ggml2, _, err := DecodeGGML(file, math.MaxInt)
- if err != nil {
- t.Fatal(err)
- }
- return n, ggml2, nil
- }
|