|
@@ -2,12 +2,12 @@ package llm
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
- "cmp"
|
|
|
"encoding/binary"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"slices"
|
|
|
+ "sort"
|
|
|
"strings"
|
|
|
|
|
|
"golang.org/x/exp/maps"
|
|
@@ -702,8 +702,8 @@ func (gguf) padding(offset, align int64) int64 {
|
|
|
|
|
|
// Reader and WriterTo
|
|
|
type GGUFWriter struct {
|
|
|
- KV KV
|
|
|
- T []*Tensor
|
|
|
+ KV
|
|
|
+ Tensors
|
|
|
}
|
|
|
|
|
|
var _ io.Reader = (*GGUFWriter)(nil)
|
|
@@ -740,19 +740,10 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- slices.SortFunc(gguf.T, func(a, b *Tensor) int {
|
|
|
- var i, j int
|
|
|
- if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 {
|
|
|
- return cmp.Compare(a.Name, b.Name)
|
|
|
- } else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 {
|
|
|
- return cmp.Compare(a.Name, b.Name)
|
|
|
- }
|
|
|
-
|
|
|
- return cmp.Compare(i, j)
|
|
|
- })
|
|
|
+ sort.Sort(gguf.Tensors)
|
|
|
|
|
|
var s uint64
|
|
|
- for _, t := range gguf.T {
|
|
|
+ for _, t := range gguf.Tensors {
|
|
|
t.Offset = s
|
|
|
if err := ggufWriteTensorInfo(w, t); err != nil {
|
|
|
return 0, err
|
|
@@ -761,7 +752,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
|
|
|
}
|
|
|
|
|
|
var alignment int64 = 32
|
|
|
- for _, t := range gguf.T {
|
|
|
+ for _, t := range gguf.Tensors {
|
|
|
if err := ggufWriteTensor(w, t, alignment); err != nil {
|
|
|
return 0, err
|
|
|
}
|