浏览代码

ml/backend/ggml: fix rms norm

Michael Yang 3 月之前
父节点
当前提交
2192a28eed
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      ml/backend/ggml/ggml.go

+ 1 - 1
ml/backend/ggml/ggml.go

@@ -485,7 +485,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
 }
 
 func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
-	return (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+	return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
 }
 
 func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {