greedy.go 466 B

1234567891011121314151617181920212223242526
  1. package sample
  2. import "gonum.org/v1/gonum/floats"
  3. type greedy struct{}
  4. func Greedy() Sampler {
  5. return greedy{}
  6. }
  7. func (s greedy) Sample(logits []float32, transforms ...Transform) (int, error) {
  8. logits64 := make([]float64, len(logits))
  9. for i, v := range logits {
  10. logits64[i] = float64(v)
  11. }
  12. var err error
  13. for _, t := range transforms {
  14. logits64, err = t.Apply(logits64)
  15. if err != nil {
  16. return -1, err
  17. }
  18. }
  19. return floats.MaxIdx(logits64), nil
  20. }