|
@@ -5,6 +5,7 @@ import (
|
|
|
"encoding/binary"
|
|
|
"fmt"
|
|
|
"os"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
)
|
|
|
|
|
@@ -126,15 +127,19 @@ func Dump(t Tensor, opts ...DumpOptions) string {
|
|
|
|
|
|
switch t.DType() {
|
|
|
case DTypeF32:
|
|
|
- return dump[[]float32](t, opts[0])
|
|
|
+ return dump[[]float32](t, opts[0].Items, func(f float32) string {
|
|
|
+ return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
|
|
+ })
|
|
|
case DTypeI32:
|
|
|
- return dump[[]int32](t, opts[0])
|
|
|
+ return dump[[]int32](t, opts[0].Items, func(i int32) string {
|
|
|
+ return strconv.FormatInt(int64(i), 10)
|
|
|
+ })
|
|
|
default:
|
|
|
return "<unsupported>"
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
|
|
|
+func dump[S ~[]E, E number](t Tensor, items int64, fn func(E) string) string {
|
|
|
bts := t.Bytes()
|
|
|
if bts == nil {
|
|
|
return "<nil>"
|
|
@@ -154,10 +159,10 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
|
|
|
fmt.Fprint(&sb, "[")
|
|
|
defer func() { fmt.Fprint(&sb, "]") }()
|
|
|
for i := int64(0); i < dims[0]; i++ {
|
|
|
- if i >= opts.Items && i < dims[0]-opts.Items {
|
|
|
+ if i >= items && i < dims[0]-items {
|
|
|
fmt.Fprint(&sb, "..., ")
|
|
|
// skip to next printable element
|
|
|
- skip := dims[0] - 2*opts.Items
|
|
|
+ skip := dims[0] - 2*items
|
|
|
if len(dims) > 1 {
|
|
|
stride += mul(append(dims[1:], skip)...)
|
|
|
fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
|
|
@@ -170,7 +175,7 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
|
|
|
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
|
|
}
|
|
|
} else {
|
|
|
- fmt.Fprint(&sb, s[stride+i])
|
|
|
+ fmt.Fprint(&sb, fn(s[stride+i]))
|
|
|
if i < dims[0]-1 {
|
|
|
fmt.Fprint(&sb, ", ")
|
|
|
}
|