Browse Source

ml/backend/ggml: fix rms norm

Michael Yang 2 months ago
parent
commit
2192a28eed
1 changed files with 1 additions and 1 deletions
  1. 1 1
      ml/backend/ggml/ggml.go

+ 1 - 1
ml/backend/ggml/ggml.go

@@ -485,7 +485,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
 }
 }
 
 
 func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
 func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
-	return (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+	return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
 }
 }
 
 
 func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
 func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {