Преглед изворни кода

ml/backend/ggml: fix rms norm

Michael Yang пре 2 месеци
родитељ
комит
2192a28eed
1 измењених фајлова са 1 додато и 1 уклоњено
  1. 1 1
      ml/backend/ggml/ggml.go

+ 1 - 1
ml/backend/ggml/ggml.go

@@ -485,7 +485,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
 }
 
 func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
-	return (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+	return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
 }
 
 func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {