normalization.go 500 B

1234567891011121314151617181920212223
  1. package nn
  2. import (
  3. "github.com/ollama/ollama/ml"
  4. )
  5. type LayerNorm struct {
  6. Weight ml.Tensor `ggml:"weight"`
  7. Bias ml.Tensor `ggml:"bias"`
  8. }
  9. func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
  10. return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
  11. }
  12. type RMSNorm struct {
  13. Weight ml.Tensor `ggml:"weight"`
  14. Bias ml.Tensor `ggml:"bias"`
  15. }
  16. func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
  17. return t.RMSNorm(ctx, m.Weight, m.Bias, eps)
  18. }