Browse Source

Fix tests and drift from main

Jesse Gross 1 month ago
parent
commit
0e886595bf
3 changed files with 6 additions and 2 deletions
  1. 4 0
      kvcache/causal_test.go
  2. 1 1
      model/models/gemma2/model.go
  3. 1 1
      model/models/mllama/model_text.go

+ 4 - 0
kvcache/causal_test.go

@@ -499,6 +499,10 @@ func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
 	panic("not implemented")
 }
 
+func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
+	panic("not implemented")
+}
+
 func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
 	panic("not implemented")
 }

+ 1 - 1
model/models/gemma2/model.go

@@ -179,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
model/models/mllama/model_text.go

@@ -28,7 +28,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim,  ropeType, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)