normalization.go 460 B

12345678910111213141516171819202122
  1. package nn
  2. import (
  3. "github.com/ollama/ollama/ml"
  4. )
  5. type LayerNorm struct {
  6. Weight ml.Tensor `gguf:"weight"`
  7. Bias ml.Tensor `gguf:"bias"`
  8. }
  9. func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
  10. return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
  11. }
  12. type RMSNorm struct {
  13. Weight ml.Tensor `gguf:"weight"`
  14. }
  15. func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
  16. return t.RMSNorm(ctx, m.Weight, eps)
  17. }