Ver código fonte

instead of static number of parameters for each model family, get the real number from the tensors (#1022)

* parse tensor info

* refactor decoder

* return actual parameter count

* explicit rounding

* s/Human/HumanNumber/
Michael Yang 1 ano atrás
pai
commit
c5e1bbabda
2 arquivos alterados com 72 adições e 18 exclusões
  1. 25 0
      format/format.go
  2. 47 18
      llm/gguf.go

+ 25 - 0
format/format.go

@@ -0,0 +1,25 @@
+package format
+
+import (
+	"fmt"
+	"math"
+)
+
+const (
+	Thousand = 1000
+	Million  = Thousand * 1000
+	Billion  = Million * 1000
+)
+
+func HumanNumber(b uint64) string {
+	switch {
+	case b > Billion:
+		return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion))
+	case b > Million:
+		return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million))
+	case b > Thousand:
+		return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand))
+	default:
+		return fmt.Sprintf("%d", b)
+	}
+}

+ 47 - 18
llm/gguf.go

@@ -5,6 +5,8 @@ import (
 	"encoding/binary"
 	"encoding/binary"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+
+	"github.com/jmorganca/ollama/format"
 )
 )
 
 
 type containerGGUF struct {
 type containerGGUF struct {
@@ -21,6 +23,8 @@ type containerGGUF struct {
 		NumTensor uint64
 		NumTensor uint64
 		NumKV     uint64
 		NumKV     uint64
 	}
 	}
+
+	parameters uint64
 }
 }
 
 
 func (c *containerGGUF) Name() string {
 func (c *containerGGUF) Name() string {
@@ -75,6 +79,14 @@ func newGGUFModel(container *containerGGUF) *ggufModel {
 	}
 	}
 }
 }
 
 
+func (llm *ggufModel) NumTensor() uint64 {
+	if llm.Version == 1 {
+		return uint64(llm.V1.NumTensor)
+	}
+
+	return llm.V2.NumTensor
+}
+
 func (llm *ggufModel) NumKV() uint64 {
 func (llm *ggufModel) NumKV() uint64 {
 	if llm.Version == 1 {
 	if llm.Version == 1 {
 		return uint64(llm.V1.NumKV)
 		return uint64(llm.V1.NumKV)
@@ -93,6 +105,10 @@ func (llm *ggufModel) ModelFamily() string {
 }
 }
 
 
 func (llm *ggufModel) ModelType() string {
 func (llm *ggufModel) ModelType() string {
+	if llm.parameters > 0 {
+		return format.HumanNumber(llm.parameters)
+	}
+
 	switch llm.ModelFamily() {
 	switch llm.ModelFamily() {
 	case "llama":
 	case "llama":
 		if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
 		if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
@@ -127,13 +143,9 @@ func (llm *ggufModel) FileType() string {
 }
 }
 
 
 func (llm *ggufModel) Decode(r io.Reader) error {
 func (llm *ggufModel) Decode(r io.Reader) error {
-	read := llm.readString
-	if llm.Version == 1 {
-		read = llm.readStringV1
-	}
-
+	// decode key-values
 	for i := 0; uint64(i) < llm.NumKV(); i++ {
 	for i := 0; uint64(i) < llm.NumKV(); i++ {
-		k, err := read(r)
+		k, err := llm.readString(r)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -165,24 +177,14 @@ func (llm *ggufModel) Decode(r io.Reader) error {
 		case ggufTypeBool:
 		case ggufTypeBool:
 			v = llm.readBool(r)
 			v = llm.readBool(r)
 		case ggufTypeString:
 		case ggufTypeString:
-			fn := llm.readString
-			if llm.Version == 1 {
-				fn = llm.readStringV1
-			}
-
-			s, err := fn(r)
+			s, err := llm.readString(r)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
 
 
 			v = s
 			v = s
 		case ggufTypeArray:
 		case ggufTypeArray:
-			fn := llm.readArray
-			if llm.Version == 1 {
-				fn = llm.readArrayV1
-			}
-
-			a, err := fn(r)
+			a, err := llm.readArray(r)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
@@ -195,6 +197,25 @@ func (llm *ggufModel) Decode(r io.Reader) error {
 		llm.kv[k] = v
 		llm.kv[k] = v
 	}
 	}
 
 
+	// decode tensors
+	for i := 0; uint64(i) < llm.NumTensor(); i++ {
+		if _, err := llm.readString(r); err != nil {
+			return err
+		}
+
+		dimensions := llm.readU32(r)
+
+		var elements uint64 = 1
+		for i := 0; uint32(i) < dimensions; i++ {
+			elements *= llm.readU64(r)
+		}
+
+		llm.readU32(r) // type
+		llm.readU64(r) // offset
+
+		llm.parameters += elements
+	}
+
 	return nil
 	return nil
 }
 }
 
 
@@ -290,6 +311,10 @@ func (llm ggufModel) readStringV1(r io.Reader) (string, error) {
 }
 }
 
 
 func (llm ggufModel) readString(r io.Reader) (string, error) {
 func (llm ggufModel) readString(r io.Reader) (string, error) {
+	if llm.Version == 1 {
+		return llm.readStringV1(r)
+	}
+
 	var nameLength uint64
 	var nameLength uint64
 	binary.Read(r, llm.bo, &nameLength)
 	binary.Read(r, llm.bo, &nameLength)
 
 
@@ -339,6 +364,10 @@ func (llm *ggufModel) readArrayV1(r io.Reader) (arr []any, err error) {
 }
 }
 
 
 func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) {
 func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) {
+	if llm.Version == 1 {
+		return llm.readArrayV1(r)
+	}
+
 	atype := llm.readU32(r)
 	atype := llm.readU32(r)
 	n := llm.readU64(r)
 	n := llm.readU64(r)