Josh Yan vor 9 Monaten
Ursprung
Commit
e75fb73839
4 geänderte Dateien mit 27 neuen und 43 gelöschten Zeilen
  1. 15 14
      llm/ggml.go
  2. 6 15
      llm/gguf.go
  3. 2 0
      server/layer.go
  4. 4 14
      server/model.go

+ 15 - 14
llm/ggml.go

@@ -1,7 +1,6 @@
 package llm
 
 import (
-	"cmp"
 	"encoding/binary"
 	"errors"
 	"fmt"
@@ -113,26 +112,28 @@ func (kv KV) ChatTemplate() string {
 	return s
 }
 
+// Tensors type as a slice of pointers to Tensor
 type Tensors []*Tensor
 
-func (ts Tensors) Less(i, j int) bool {
-	var x, y int
-	if n, err := fmt.Sscanf(ts[i].Name, "blk.%d", &x); err != nil || n != 1 {
-		return cmp.Less(ts[i].Name, ts[j].Name)
-	} else if n, err := fmt.Sscanf(ts[j].Name, "blk.%d", &y); err != nil || n != 1 {
-		return cmp.Less(ts[i].Name, ts[j].Name)
-	}
-
-	return cmp.Less(x, y)
-}
-
+// Implement the Len method
 func (ts Tensors) Len() int {
 	return len(ts)
 }
 
+// Implement the Swap method
 func (ts Tensors) Swap(i, j int) {
-	var temp Tensor
-	
+	ts[i], ts[j] = ts[j], ts[i]
+}
+
+// Implement the Less method
+func (ts Tensors) Less(i, j int) bool {
+	var x, y int
+	if n, err := fmt.Sscanf(ts[i].Name, "blk.%d", &x); err != nil || n != 1 {
+		return ts[i].Name < ts[j].Name
+	} else if n, err := fmt.Sscanf(ts[j].Name, "blk.%d", &y); err != nil || n != 1 {
+		return ts[i].Name < ts[j].Name
+	}
+	return x < y
 }
 
 func (ts Tensors) Layers() map[string]Layer {

+ 6 - 15
llm/gguf.go

@@ -2,12 +2,12 @@ package llm
 
 import (
 	"bytes"
-	"cmp"
 	"encoding/binary"
 	"encoding/json"
 	"fmt"
 	"io"
 	"slices"
+	"sort"
 	"strings"
 
 	"golang.org/x/exp/maps"
@@ -702,8 +702,8 @@ func (gguf) padding(offset, align int64) int64 {
 
 // Reader and WriterTo
 type GGUFWriter struct {
-	KV KV
-	T  []*Tensor
+	KV
+	Tensors
 }
 
 var _ io.Reader = (*GGUFWriter)(nil)
@@ -740,19 +740,10 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
 		}
 	}
 
-	slices.SortFunc(gguf.T, func(a, b *Tensor) int {
-		var i, j int
-		if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 {
-			return cmp.Compare(a.Name, b.Name)
-		} else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 {
-			return cmp.Compare(a.Name, b.Name)
-		}
-
-		return cmp.Compare(i, j)
-	})
+	sort.Sort(gguf.Tensors)
 
 	var s uint64
-	for _, t := range gguf.T {
+	for _, t := range gguf.Tensors {
 		t.Offset = s
 		if err := ggufWriteTensorInfo(w, t); err != nil {
 			return 0, err
@@ -761,7 +752,7 @@ func (gguf GGUFWriter) WriteTo(w io.Writer) (int64, error) {
 	}
 
 	var alignment int64 = 32
-	for _, t := range gguf.T {
+	for _, t := range gguf.Tensors {
 		if err := ggufWriteTensor(w, t, alignment); err != nil {
 			return 0, err
 		}

+ 2 - 0
server/layer.go

@@ -29,6 +29,8 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
 	defer os.Remove(temp.Name())
 
 	sha256sum := sha256.New()
+	if 
+
 	n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
 	if err != nil {
 		return nil, err

+ 4 - 14
server/model.go

@@ -3,7 +3,6 @@ package server
 import (
 	"archive/zip"
 	"bytes"
-	"cmp"
 	"context"
 	"errors"
 	"fmt"
@@ -12,7 +11,7 @@ import (
 	"net/http"
 	"os"
 	"path/filepath"
-	"slices"
+	"sort"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
@@ -244,19 +243,10 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
 		}
 
 		var reader io.Reader = io.NewSectionReader(file, offset, n)
-		if !slices.IsSortedFunc(ggml.Tensors(), func(a, b *llm.Tensor) int {
-			var i, j int
-			if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 {
-				return cmp.Compare(a.Name, b.Name)
-			} else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 {
-				return cmp.Compare(a.Name, b.Name)
-			}
-
-			return cmp.Compare(i, j)
-		}) {
+		if !sort.IsSorted(ggml.Tensors()) {
 			reader = &llm.GGUFWriter{
-				KV: ggml.KV(),
-				T:  ggml.Tensors(),
+				KV:      ggml.KV(),
+				Tensors: ggml.Tensors(),
 			}
 		}