|
@@ -5,6 +5,7 @@ import (
|
|
"encoding/binary"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"fmt"
|
|
"os"
|
|
"os"
|
|
|
|
+ "slices"
|
|
"strconv"
|
|
"strconv"
|
|
"strings"
|
|
"strings"
|
|
)
|
|
)
|
|
@@ -241,16 +242,17 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
|
|
}
|
|
}
|
|
|
|
|
|
shape := t.Shape()
|
|
shape := t.Shape()
|
|
|
|
+ slices.Reverse(shape)
|
|
|
|
|
|
var sb strings.Builder
|
|
var sb strings.Builder
|
|
var f func([]int, int)
|
|
var f func([]int, int)
|
|
f = func(dims []int, stride int) {
|
|
f = func(dims []int, stride int) {
|
|
prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
|
|
prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
|
|
- fmt.Fprint(&sb, "[")
|
|
|
|
- defer func() { fmt.Fprint(&sb, "]") }()
|
|
|
|
|
|
+ sb.WriteString("[")
|
|
|
|
+ defer func() { sb.WriteString("]") }()
|
|
for i := 0; i < dims[0]; i++ {
|
|
for i := 0; i < dims[0]; i++ {
|
|
if i >= items && i < dims[0]-items {
|
|
if i >= items && i < dims[0]-items {
|
|
- fmt.Fprint(&sb, "..., ")
|
|
|
|
|
|
+ sb.WriteString("..., ")
|
|
// skip to next printable element
|
|
// skip to next printable element
|
|
skip := dims[0] - 2*items
|
|
skip := dims[0] - 2*items
|
|
if len(dims) > 1 {
|
|
if len(dims) > 1 {
|
|
@@ -265,9 +267,14 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
|
|
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
|
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
|
}
|
|
}
|
|
} else {
|
|
} else {
|
|
- fmt.Fprint(&sb, fn(s[stride+i]))
|
|
|
|
|
|
+ text := fn(s[stride+i])
|
|
|
|
+ if len(text) > 0 && text[0] != '-' {
|
|
|
|
+ sb.WriteString(" ")
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ sb.WriteString(text)
|
|
if i < dims[0]-1 {
|
|
if i < dims[0]-1 {
|
|
- fmt.Fprint(&sb, ", ")
|
|
|
|
|
|
+ sb.WriteString(", ")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|