|
@@ -304,7 +304,7 @@ func shapeToGGML(shape []int) *C.int64_t {
|
|
return &sh[0]
|
|
return &sh[0]
|
|
}
|
|
}
|
|
|
|
|
|
-func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
|
|
|
|
|
+func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
|
|
if len(shape) < 1 || len(shape) > 4 {
|
|
if len(shape) < 1 || len(shape) > 4 {
|
|
panic("unsupported number of dimensions")
|
|
panic("unsupported number of dimensions")
|
|
}
|
|
}
|
|
@@ -318,19 +318,29 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
|
var t *C.struct_ggml_tensor
|
|
var t *C.struct_ggml_tensor
|
|
switch dtype {
|
|
switch dtype {
|
|
case ml.DTypeF32:
|
|
case ml.DTypeF32:
|
|
- t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
|
|
|
|
|
|
+ t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
|
|
case ml.DTypeF16:
|
|
case ml.DTypeF16:
|
|
- t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
|
|
|
|
|
|
+ t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
|
|
case ml.DTypeI32:
|
|
case ml.DTypeI32:
|
|
- t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
|
|
|
|
|
|
+ t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
|
|
default:
|
|
default:
|
|
panic("unsupported dtype")
|
|
panic("unsupported dtype")
|
|
}
|
|
}
|
|
|
|
|
|
- b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
|
|
|
|
|
|
+ b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
|
|
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
|
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
|
- C.ggml_set_zero(t)
|
|
|
|
- return &Tensor{b: c.b, t: t}
|
|
|
|
|
|
+ if zero {
|
|
|
|
+ C.ggml_set_zero(t)
|
|
|
|
+ }
|
|
|
|
+ return &Tensor{b: ctx.b, t: t}
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
|
|
|
+ return newTensor(c, dtype, false, shape)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
|
|
|
+ return newTensor(c, dtype, true, shape)
|
|
}
|
|
}
|
|
|
|
|
|
func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
|
|
func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
|