Explorar o código

Separate Rounding Functions

Roy Han hai 10 meses
pai
achega
c71698426c
Modificáronse 4 ficheiros con 62 adicións e 12 borrados
  1. 2 2
      cmd/cmd.go
  2. 29 7
      format/format.go
  3. 30 2
      format/format_test.go
  4. 1 1
      server/images.go

+ 2 - 2
cmd/cmd.go

@@ -656,7 +656,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 
 	modelData := [][]string{
 		{"arch", arch},
-		{"parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))},
+		{"parameters", format.Parameters(uint64(resp.ModelInfo["general.parameter_count"].(float64)))},
 		{"quantization", resp.Details.QuantizationLevel},
 		{"context length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))},
 		{"embedding length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64))},
@@ -670,7 +670,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 	if resp.ProjectorInfo != nil {
 		projectorData := [][]string{
 			{"arch", "clip"},
-			{"parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))},
+			{"parameters", format.Parameters(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))},
 			{"projector type", resp.ProjectorInfo["clip.projector_type"].(string)},
 			{"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))},
 			{"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))},

+ 29 - 7
format/format.go

@@ -2,16 +2,38 @@ package format
 
 import (
 	"fmt"
+	"math"
 )
 
-func HumanNumber(b uint64) string {
-	const (
-		Thousand = 1000
-		Million  = Thousand * 1000
-		Billion  = Million * 1000
-		Trillion = Billion * 1000
-	)
+const (
+	Thousand = 1000
+	Million  = Thousand * 1000
+	Billion  = Million * 1000
+	Trillion = Billion * 1000
+)
+
+func RoundedParameter(b uint64) string {
+	switch {
+	case b >= Billion:
+		number := float64(b) / Billion
+		if number == math.Floor(number) {
+			return fmt.Sprintf("%.0fB", number) // no decimals if whole number
+		}
+		return fmt.Sprintf("%.1fB", number) // one decimal if not a whole number
+	case b >= Million:
+		number := float64(b) / Million
+		if number == math.Floor(number) {
+			return fmt.Sprintf("%.0fM", number) // no decimals if whole number
+		}
+		return fmt.Sprintf("%.2fM", number) // two decimals if not a whole number
+	case b >= Thousand:
+		return fmt.Sprintf("%.0fK", float64(b)/Thousand)
+	default:
+		return fmt.Sprintf("%d", b)
+	}
+}
 
+func Parameters(b uint64) string {
 	switch {
 	case b >= Trillion:
 		number := float64(b) / Trillion

+ 30 - 2
format/format_test.go

@@ -4,7 +4,35 @@ import (
 	"testing"
 )
 
-func TestHumanNumber(t *testing.T) {
+func TestRoundedParameter(t *testing.T) {
+	type testCase struct {
+		input    uint64
+		expected string
+	}
+
+	testCases := []testCase{
+		{0, "0"},
+		{1000000, "1M"},
+		{125000000, "125M"},
+		{500500000, "500.50M"},
+		{500550000, "500.55M"},
+		{1000000000, "1B"},
+		{2800000000, "2.8B"},
+		{2850000000, "2.9B"},
+		{1000000000000, "1000B"},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.expected, func(t *testing.T) {
+			result := RoundedParameter(tc.input)
+			if result != tc.expected {
+				t.Errorf("Expected %s, got %s", tc.expected, result)
+			}
+		})
+	}
+}
+
+func TestParameters(t *testing.T) {
 	type testCase struct {
 		input    uint64
 		expected string
@@ -23,7 +51,7 @@ func TestHumanNumber(t *testing.T) {
 
 	for _, tc := range testCases {
 		t.Run(tc.expected, func(t *testing.T) {
-			result := HumanNumber(tc.input)
+			result := Parameters(tc.input)
 			if result != tc.expected {
 				t.Errorf("Expected %s, got %s", tc.expected, result)
 			}

+ 1 - 1
server/images.go

@@ -431,7 +431,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 				if baseLayer.GGML != nil {
 					config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
 					config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
-					config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
+					config.ModelType = cmp.Or(config.ModelType, format.RoundedParameter(baseLayer.GGML.KV().ParameterCount()))
 					config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String())
 					config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
 				}