sample.go 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. package sample
import (
	"fmt"
	"math"
	"slices"

	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/stat/sampleuv"
)
  7. type Sampler interface {
  8. Sample([]float64) ([]float64, error)
  9. }
  10. type Temperature float64
  11. func (s Temperature) Sample(t []float64) ([]float64, error) {
  12. floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
  13. return t, nil
  14. }
  15. type softmax struct{}
  16. func Softmax() Sampler {
  17. return softmax{}
  18. }
  19. func (softmax) Sample(t []float64) ([]float64, error) {
  20. return t, nil
  21. }
  22. type TopK int
  23. func (s TopK) Sample(t []float64) ([]float64, error) {
  24. return t, nil
  25. }
  26. type TopP float32
  27. func (s TopP) Sample(t []float64) ([]float64, error) {
  28. return t, nil
  29. }
  30. type MinP float32
  31. func (s MinP) Sample(t []float64) ([]float64, error) {
  32. return t, nil
  33. }
  34. type weighed struct{}
  35. func Weighed() Sampler {
  36. return weighed{}
  37. }
  38. func (s weighed) Sample(t []float64) ([]float64, error) {
  39. w := sampleuv.NewWeighted(t, nil)
  40. if v, ok := w.Take(); ok {
  41. return []float64{float64(v)}, nil
  42. }
  43. return t, nil
  44. }
  45. func Sample(floats []float64, samplers ...Sampler) ([]float64, error) {
  46. var err error
  47. for _, sampler := range samplers {
  48. floats, err = sampler.Sample(floats)
  49. if err != nil {
  50. return nil, err
  51. }
  52. }
  53. return floats, nil
  54. }