浏览代码

fix: pad tensor item if ge zero

this produces a nicer output since both positive and negative values
produces the same width
Michael Yang 1 月之前
父节点
当前提交
9926eae015
共有 1 个文件被更改,包括 12 次插入5 次删除
  1. 12 5
      ml/backend.go

+ 12 - 5
ml/backend.go

@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"fmt"
 	"os"
+	"slices"
 	"strconv"
 	"strings"
 )
@@ -241,16 +242,17 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
 	}
 
 	shape := t.Shape()
+	slices.Reverse(shape)
 
 	var sb strings.Builder
 	var f func([]int, int)
 	f = func(dims []int, stride int) {
 		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
-		fmt.Fprint(&sb, "[")
-		defer func() { fmt.Fprint(&sb, "]") }()
+		sb.WriteString("[")
+		defer func() { sb.WriteString("]") }()
 		for i := 0; i < dims[0]; i++ {
 			if i >= items && i < dims[0]-items {
-				fmt.Fprint(&sb, "..., ")
+				sb.WriteString("..., ")
 				// skip to next printable element
 				skip := dims[0] - 2*items
 				if len(dims) > 1 {
@@ -265,9 +267,14 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
 					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
 				}
 			} else {
-				fmt.Fprint(&sb, fn(s[stride+i]))
+				text := fn(s[stride+i])
+				if len(text) > 0 && text[0] != '-' {
+					sb.WriteString(" ")
+				}
+
+				sb.WriteString(text)
 				if i < dims[0]-1 {
-					fmt.Fprint(&sb, ", ")
+					sb.WriteString(", ")
 				}
 			}
 		}