Josh Yan преди 10 месеца
родител
ревизия
d7c8d4f3f4
променени са 1 файла, в които са добавени 55 реда и са изтрити 3 реда
  1. 55 3
      llm/gguf.go

+ 55 - 3
llm/gguf.go

@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"log/slog"
 	"slices"
 	"sort"
 	"strings"
@@ -637,7 +638,7 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 
 	for k, v := range kvCheck {
 		if !v {
-			return fmt.Errorf("Didn't know how to write kv %s", k)
+			return fmt.Errorf("didn't know how to write kv %s", k)
 		}
 	}
 
@@ -769,6 +770,57 @@ func ggufWriteTensorInfo(io.Writer, *Tensor) error {
 	return nil
 }
 
-func ggufWriteKV(io.Writer, string, any) error {
-	return nil
+func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
+	slog.Debug(k, "type", fmt.Sprintf("%T", v))
+	if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
+		return err
+	}
+
+	if err := binary.Write(ws, binary.LittleEndian, []byte(k)); err != nil {
+		return err
+	}
+
+	var err error
+	switch v := v.(type) {
+	case uint32:
+		err = writeGGUF(ws, ggufTypeUint32, v)
+	case float32:
+		err = writeGGUF(ws, ggufTypeFloat32, v)
+	case bool:
+		err = writeGGUF(ws, ggufTypeBool, v)
+	case string:
+		err = writeGGUFString(ws, v)
+	case []int32:
+		err = writeGGUFArray(ws, ggufTypeInt32, v)
+	case []uint32:
+		err = writeGGUFArray(ws, ggufTypeUint32, v)
+	case []float32:
+		err = writeGGUFArray(ws, ggufTypeFloat32, v)
+	case []string:
+		if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
+			return err
+		}
+
+		if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
+			return err
+		}
+
+		if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
+			return err
+		}
+
+		for _, e := range v {
+			if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
+				return err
+			}
+
+			if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
+				return err
+			}
+		}
+	default:
+		return fmt.Errorf("improper type for '%s'", k)
+	}
+
+	return err
 }