1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- package sample
import (
	"fmt"
	"math"
	"slices"

	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/stat/sampleuv"
)
// Sampler is a single stage of a sampling pipeline. Sample receives a slice
// of float64 values and returns the slice to hand to the next stage, or a
// non-nil error describing why the stage could not run.
type Sampler interface {
	Sample(values []float64) ([]float64, error)
}
- type Temperature float64
- func (s Temperature) Sample(t []float64) ([]float64, error) {
- floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
- return t, nil
- }
- type softmax struct{}
- func Softmax() Sampler {
- return softmax{}
- }
- func (softmax) Sample(t []float64) ([]float64, error) {
- return t, nil
- }
// TopK is a Sampler parameterized by an integer count.
//
// NOTE(review): the name suggests it should keep only the K largest values,
// but Sample is currently a pass-through stub. TODO implement and confirm the
// intended semantics.
type TopK int

// Sample currently returns values unchanged (same backing array) and never
// reports an error.
func (TopK) Sample(values []float64) ([]float64, error) {
	return values, nil
}
// TopP is a Sampler parameterized by a float32 threshold.
//
// NOTE(review): the name suggests nucleus (top-p) filtering, but Sample is
// currently a pass-through stub. TODO implement and confirm the intended
// semantics.
type TopP float32

// Sample currently returns values unchanged (same backing array) and never
// reports an error.
func (TopP) Sample(values []float64) ([]float64, error) {
	return values, nil
}
// MinP is a Sampler parameterized by a float32 threshold.
//
// NOTE(review): the name suggests min-p filtering, but Sample is currently a
// pass-through stub. TODO implement and confirm the intended semantics.
type MinP float32

// Sample currently returns values unchanged (same backing array) and never
// reports an error.
func (MinP) Sample(values []float64) ([]float64, error) {
	return values, nil
}
- type weighed struct{}
- func Weighed() Sampler {
- return weighed{}
- }
- func (s weighed) Sample(t []float64) ([]float64, error) {
- w := sampleuv.NewWeighted(t, nil)
- if v, ok := w.Take(); ok {
- return []float64{float64(v)}, nil
- }
- return t, nil
- }
- func Sample(floats []float64, samplers ...Sampler) ([]float64, error) {
- var err error
- for _, sampler := range samplers {
- floats, err = sampler.Sample(floats)
- if err != nil {
- return nil, err
- }
- }
- return floats, nil
- }
|