瀏覽代碼

ml: update Dump to handle precision

Michael Yang 2 月之前
父節點
當前提交
3c653195f4
共有 1 個文件被更改,包括 11 次插入6 次删除
  1. 11 6
      ml/backend.go

+ 11 - 6
ml/backend.go

@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"fmt"
 	"os"
+	"strconv"
 	"strings"
 )
 
@@ -126,15 +127,19 @@ func Dump(t Tensor, opts ...DumpOptions) string {
 
 	switch t.DType() {
 	case DTypeF32:
-		return dump[[]float32](t, opts[0])
+		return dump[[]float32](t, opts[0].Items, func(f float32) string {
+			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
+		})
 	case DTypeI32:
-		return dump[[]int32](t, opts[0])
+		return dump[[]int32](t, opts[0].Items, func(i int32) string {
+			return strconv.FormatInt(int64(i), 10)
+		})
 	default:
 		return "<unsupported>"
 	}
 }
 
-func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
+func dump[S ~[]E, E number](t Tensor, items int64, fn func(E) string) string {
 	bts := t.Bytes()
 	if bts == nil {
 		return "<nil>"
@@ -154,10 +159,10 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
 		fmt.Fprint(&sb, "[")
 		defer func() { fmt.Fprint(&sb, "]") }()
 		for i := int64(0); i < dims[0]; i++ {
-			if i >= opts.Items && i < dims[0]-opts.Items {
+			if i >= items && i < dims[0]-items {
 				fmt.Fprint(&sb, "..., ")
 				// skip to next printable element
-				skip := dims[0] - 2*opts.Items
+				skip := dims[0] - 2*items
 				if len(dims) > 1 {
 					stride += mul(append(dims[1:], skip)...)
 					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
@@ -170,7 +175,7 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
 					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
 				}
 			} else {
-				fmt.Fprint(&sb, s[stride+i])
+				fmt.Fprint(&sb, fn(s[stride+i]))
 				if i < dims[0]-1 {
 					fmt.Fprint(&sb, ", ")
 				}