Jelajahi Sumber

ml: update Context.Forward interface

update Context.Forward to accept multiple tensors to match
Context.Compute signature

update Context.Forward to return Context such that it can be chained
with Context.Compute
Michael Yang 2 bulan lalu
induk
melakukan
3e8b8a1933
6 mengubah file dengan 18 tambahan dan 14 penghapusan
  1. 4 2
      kvcache/causal.go
  2. 1 3
      kvcache/causal_test.go
  3. 4 2
      kvcache/encoder.go
  4. 2 3
      ml/backend.go
  5. 6 2
      ml/backend/ggml/ggml.go
  6. 1 2
      model/model.go

+ 4 - 2
kvcache/causal.go

@@ -330,8 +330,10 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 		c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
 	}
 
-	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
-	ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
+	ctx.Forward(
+		key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
+		value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
+	)
 }
 
 func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {

+ 1 - 3
kvcache/causal_test.go

@@ -280,9 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 
 			out, _, mask := cache.Get(context)
 
-			context.Forward(out)
-			context.Forward(mask)
-			context.Compute(out, mask)
+			context.Forward(out, mask).Compute(out, mask)
 
 			if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
 				t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)

+ 4 - 2
kvcache/encoder.go

@@ -80,8 +80,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
 		c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
 	}
 
-	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
-	ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
+	ctx.Forward(
+		key.Copy(ctx, c.keys[c.curLayer]),
+		value.Copy(ctx, c.values[c.curLayer]),
+	)
 }
 
 func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {

+ 2 - 3
ml/backend.go

@@ -65,7 +65,7 @@ type Context interface {
 	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
 	FromIntSlice(s []int32, shape ...int) (Tensor, error)
 
-	Forward(Tensor)
+	Forward(...Tensor) Context
 	Compute(...Tensor)
 	MaxTensors() int
 	Close()
@@ -186,8 +186,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
 
 func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
 	if t.Bytes() == nil {
-		ctx.Forward(t)
-		ctx.Compute(t)
+		ctx.Forward(t).Compute(t)
 	}
 
 	s := make(S, mul(t.Shape()...))

+ 6 - 2
ml/backend/ggml/ggml.go

@@ -256,12 +256,16 @@ type Context struct {
 	nodes int
 }
 
-func (c *Context) Forward(t ml.Tensor) {
+func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
 	if c.graph == nil {
 		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
 	}
 
-	C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
+	for _, tensor := range tensors {
+		C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
+	}
+
+	return c
 }
 
 func (c *Context) Compute(tensors ...ml.Tensor) {

+ 1 - 2
model/model.go

@@ -248,8 +248,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	ctx.Forward(t)
-	ctx.Compute(t)
+	ctx.Forward(t).Compute(t)
 
 	return t, nil
 }