Josh Yan 9 月之前
父节点
当前提交
25be20949c
共有 3 个文件被更改,包括 140 次插入29 次删除
  1. 7 0
      llm/ggml.go
  2. 4 8
      llm/gguf.go
  3. 129 21
      llm/gguf_test.go

+ 7 - 0
llm/ggml.go

@@ -312,10 +312,12 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
 		maxArraySize = 1024
 	}
 
+	fmt.Println("errored 1")
 	rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
 
 	var magic uint32
 	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
+		fmt.Println("errored 2")
 		return nil, 0, err
 	}
 
@@ -330,19 +332,24 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
 	case FILE_MAGIC_GGUF_BE:
 		c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
 	default:
+		fmt.Println("errored 3")
 		return nil, 0, errors.New("invalid file magic")
 	}
 
+	fmt.Println("valid magic")
 	model, err := c.Decode(rs)
 	if err != nil {
 		return nil, 0, err
 	}
 
+	fmt.Println("valid decode")
 	offset, err := rs.Seek(0, io.SeekCurrent)
 	if err != nil {
+		fmt.Println("invalid seek")
 		return nil, 0, err
 	}
 
+	fmt.Println("valid seek")
 	// final model type
 	return &GGML{
 		container: c,

+ 4 - 8
llm/gguf.go

@@ -8,7 +8,6 @@ import (
 	"io"
 	"log/slog"
 	"slices"
-	"sort"
 	"strings"
 
 	"golang.org/x/exp/maps"
@@ -141,13 +140,11 @@ func (llm *gguf) numKV() uint64 {
 
 func (llm *gguf) Decode(rs io.ReadSeeker) error {
 	// decode key-values
-	fmt.Println(llm.numKV())
 	for i := 0; uint64(i) < llm.numKV(); i++ {
 		k, err := readGGUFString(llm, rs)
 		if err != nil {
 			return err
 		}
-		fmt.Printf("k: %#v\n", k)
 
 		t, err := readGGUF[uint32](llm, rs)
 		if err != nil {
@@ -214,7 +211,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 			}
 			shape = append(shape, shapeVal)
 		}
-		fmt.Println("tensor ", name, " shape ", shape)
 
 		kind, err := readGGUF[uint32](llm, rs)
 		if err != nil {
@@ -226,6 +222,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 			return fmt.Errorf("failed to read tensor offset: %w", err)
 		}
 
+		fmt.Println("tensor", name, shape, kind, offset)
 		tensor := Tensor{
 			Name:   name,
 			Kind:   kind,
@@ -764,8 +761,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
 			}
 		}
 	}
-
-	sort.Sort(gguf.Tensors)
+	//sort.Sort(gguf.Tensors)
 
 	var s uint64
 	for _, t := range gguf.Tensors {
@@ -775,6 +771,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
 		}
 		s += t.Size()
 	}
+	tensorOffset := wo.offset
 
 	for _, t := range gguf.Tensors {
 		if err := ggufWriteTensor(wo, t, wo.offset); err != nil {
@@ -782,7 +779,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
 		}
 	}
 
-	return 0, nil
+	return int64(tensorOffset), nil
 }
 
 func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error {
@@ -797,7 +794,6 @@ func ggufWriteTensorInfo(ws io.Writer, t *Tensor) error {
 	if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil {
 		return err
 	}
-	fmt.Println("tensor ", t.Name, " shape ", t.Shape)
 
 	for i := range len(t.Shape) {
 		if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {

+ 129 - 21
llm/gguf_test.go

@@ -1,10 +1,13 @@
 package llm
 
 import (
+	"crypto/sha256"
+	"fmt"
 	"io"
 	"math"
 	"os"
 	"path/filepath"
+	"strings"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -13,7 +16,6 @@ import (
 func TestGGUFRewrite(t *testing.T) {
 	tests := []string{
 		"glm2.gguf",
-		"nutiny.gguf",
 	}
 
 	for i := range tests {
@@ -26,44 +28,144 @@ func TestGGUFRewrite(t *testing.T) {
 				t.Fatalf("%s not found", p)
 			}
 
-			ggml, err := decodeGGML(t, p)
+			f, err := os.Open(p)
 			if err != nil {
 				t.Fatal(err)
 			}
+			defer f.Close()
 
-			ggml2, err := rewriteGGML(t, ggml, p)
+			ggml, m, err := decodeGGML(t, f)
 			if err != nil {
 				t.Fatal(err)
 			}
 
-			if cmp.Diff(ggml, ggml2) != "" {
-				t.Fatal(cmp.Diff(ggml, ggml2))
+			temp, err := os.CreateTemp("testdata", "2"+tt)
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer temp.Close()
+
+			n, ggml2, err := rewriteGGML(t, ggml, temp)
+
+			if n != m {
+				t.Fatalf("n: %d, m: %d", n, m)
+			}
+
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if diff, diff2, ok := compareGGML(n, ggml2, ggml, temp, f); !ok {
+				if cmp.Diff(diff, diff2) != "" {
+					t.Fatalf("\n%s,\n%s\ndiff: %s", diff["token_embd.weight"], diff2["token_embd.weight"], cmp.Diff(diff, diff2))
+				}
+			}
+
+			/* // Reset the file offset to the beginning
+			if _, err := f.Seek(0, io.SeekStart); err != nil {
+				t.Fatal(err)
 			}
+			if _, err := temp.Seek(0, io.SeekStart); err != nil {
+				t.Fatal(err)
+			}
+
+			content1, err := io.ReadAll(f)
+			if err != nil {
+				t.Fatalf("failed to read file1: %v", err)
+			}
+
+			content2, err := io.ReadAll(temp)
+			if err != nil {
+				t.Fatalf("failed to read file1: %v", err)
+			}
+
+			if byteCmp := cmp.Diff(content1, content2); byteCmp != "" {
+				t.Fatalf("content diff: %s", byteCmp)
+			} */
 		})
 	}
 }
 
-func decodeGGML(t *testing.T, p string) (*GGML, error) {
-	f, err := os.Open(p)
-	if err != nil {
-		t.Fatal(err)
+func formatDiff(diff map[string]string) string {
+	var builder strings.Builder
+	for k, v := range diff {
+		builder.WriteString(fmt.Sprintf("%s: %s\n", k, v))
 	}
-	defer f.Close()
+	return builder.String()
+}
 
-	ggml, _, err := DecodeGGML(f, math.MaxInt)
-	if err != nil {
-		t.Fatal(err)
+func compareGGML(n int64, ggml1, ggml2 *GGML, f *os.File, f2 *os.File) (map[string]string, map[string]string, bool) {
+	diff := make(map[string]string)
+	diff2 := make(map[string]string)
+
+	kv1 := ggml1.KV()
+	kv2 := ggml2.KV()
+
+	if len(kv1) != len(kv2) {
+		diff["lenKV"] = fmt.Sprintf("kv1: %d, kv2: %d", len(kv1), len(kv2))
+		fmt.Println("lenKV", diff["lenKV"])
+	}
+
+	for k, v := range kv1 {
+		// if v2, ok := kv2[k]; !ok {
+		// diff[k] = fmt.Sprintf("missing key %s", k)
+		// } else if v != v2 {
+		// diff[fmt.Sprintf("%s type diff", k)] = fmt.Sprintf("kv1 type: %T, kv2 type: %T", v.(*array).values, v2.(*array).values)
+		// diff[k] = fmt.Sprintf("kv1: %d, kv2: %d", len(v.(*array).values), len(v2.(*array).values))
+		// diff[fmt.Sprintf("%s values first 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[0:10], v2.(*array).values[0:10])
+		// diff[fmt.Sprintf("%s values last 10", k)] = fmt.Sprintf("\nkv1: %#v, \nkv2: %#v", v.(*array).values[len(v.(*array).values)-10:], v2.(*array).values[len(v2.(*array).values)-10:])
+		// diff[fmt.Sprintf("%s diff", k)] = cmp.Diff(v.(*array).values, v2.(*array).values)
+
+		switch t := v.(type) {
+		case *array:
+			if diffy := cmp.Diff(t.values, kv2[k].(*array).values); diffy != "" {
+				diff[k] = diffy
+			}
+		}
+
+		// }
+	}
+
+	t1 := ggml1.Tensors()
+	t2 := ggml2.Tensors()
+
+	if len(t1) != len(t2) {
+		diff["lenTensors"] = fmt.Sprintf("t1: %d, t2: %d", len(t1), len(t2))
 	}
-	return ggml, nil
+
+	for _, tensor := range t1 {
+		sha256sum := sha256.New()
+		sr := io.NewSectionReader(f, n+int64(tensor.Offset), int64(tensor.Size()))
+		if _, err := io.Copy(sha256sum, sr); err != nil {
+			fmt.Println(err)
+		}
+
+		diff[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
+	}
+
+	for _, tensor := range t2 {
+		sha256sum := sha256.New()
+		sr := io.NewSectionReader(f2, n+int64(tensor.Offset), int64(tensor.Size()))
+		if _, err := io.Copy(sha256sum, sr); err != nil {
+			fmt.Println(err)
+		}
+
+		diff2[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
+	}
+	return diff, diff2, len(diff) == 0
+
 }
+func decodeGGML(t *testing.T, f *os.File) (*GGML, int64, error) {
 
-func rewriteGGML(t *testing.T, ggml *GGML, path string) (*GGML, error) {
-	var tensors Tensors
-	temp, err := os.Create(path)
+	ggml, n, err := DecodeGGML(f, math.MaxInt)
 	if err != nil {
 		t.Fatal(err)
 	}
-	defer temp.Close()
+	return ggml, n, nil
+}
+
+func rewriteGGML(t *testing.T, ggml *GGML, temp *os.File) (int64, *GGML, error) {
+	var tensors Tensors
 
 	for _, tensor := range ggml.Tensors() {
 		shape := make([]uint64, len(tensor.Shape))
@@ -88,15 +190,21 @@ func rewriteGGML(t *testing.T, ggml *GGML, path string) (*GGML, error) {
 		Tensors: tensors,
 	}
 
-	_, err = io.Copy(temp, reader)
+	n, err := io.Copy(temp, reader)
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	ggml2, _, err := DecodeGGML(temp, -1)
+	fmt.Println(n)
+	temp.Seek(0, io.SeekStart)
+	file, err := os.Open(temp.Name())
+	if err != nil {
+		t.Fatal(err)
+	}
+	ggml2, n, err := DecodeGGML(file, math.MaxInt)
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	return ggml2, nil
+	return n, ggml2, nil
 }