sample.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package sample
  2. import (
  3. "cmp"
  4. "errors"
  5. "math"
  6. "slices"
  7. "gonum.org/v1/gonum/floats"
  8. "gonum.org/v1/gonum/stat/sampleuv"
  9. )
  // Sampler is one stage of a sampling pipeline. Sample receives the current
  // values (raw logits, or the previous stage's output) and returns the values
  // handed to the next stage; a final stage such as Weighed returns the chosen
  // token ID encoded as a float64.
  type Sampler interface {
  	Sample(logits []float64) ([]float64, error)
  }
  13. type Temperature float64
  14. func (s Temperature) Sample(logits []float64) ([]float64, error) {
  15. if s < 0 || s > 1 {
  16. return nil, errors.New("temperature must be between 0 and 1")
  17. }
  18. // greedy sampling
  19. if s == 0 {
  20. return []float64{floats.Max(logits)}, nil
  21. }
  22. floats.Scale(1.0/float64(s), logits)
  23. return logits, nil
  24. }
  25. type softmax struct{}
  26. func Softmax() Sampler {
  27. return softmax{}
  28. }
  29. func (softmax) Sample(logits []float64) ([]float64, error) {
  30. return computeSoftmax(logits)
  31. }
  // computeSoftmax returns softmax(logits) in a newly allocated slice; logits
  // itself is not modified.
  //
  // NaN entries are treated as masked tokens (TopK, TopP and MinP mark discarded
  // tokens with NaN). They are excluded from the normalization and stay NaN in
  // the result, so the remaining entries still form a proper distribution.
  //
  // The maximum logit is subtracted before exponentiating, so large logits
  // cannot overflow math.Exp into Inf/Inf = NaN. If the maximum is +Inf, the
  // +Inf entries share all of the probability mass.
  //
  // An error is returned when there is no usable token: empty input, all NaN,
  // or all -Inf.
  func computeSoftmax(logits []float64) ([]float64, error) {
  	maxLogit := math.Inf(-1)
  	for _, l := range logits {
  		if l > maxLogit { // NaN compares false, so masked entries are skipped
  			maxLogit = l
  		}
  	}
  	if math.IsInf(maxLogit, -1) {
  		return nil, errors.New("no valid tokens found")
  	}

  	probs := make([]float64, len(logits))
  	var sum float64
  	for i, l := range logits {
  		switch {
  		case math.IsNaN(l):
  			probs[i] = math.NaN()
  			continue
  		case l == maxLogit:
  			// exp(0); also covers maxLogit == +Inf, where l-maxLogit is NaN.
  			probs[i] = 1
  		default:
  			probs[i] = math.Exp(l - maxLogit)
  		}
  		sum += probs[i]
  	}
  	// sum >= 1 because at least one entry equals maxLogit; NaN entries stay NaN.
  	for i := range probs {
  		probs[i] /= sum
  	}
  	return probs, nil
  }
  // TopK keeps only the K highest logits; every other entry is masked by
  // overwriting it with NaN.
  type TopK int

  // Sample masks, in place, every logit that is not among the k largest and
  // returns the same slice. A non-positive k is an error; if k is at least
  // len(logits) the logits are returned unchanged. Because cmp.Compare orders
  // NaN below every number, already-masked entries are never kept in place of
  // a real logit. Ties at the cut-off are resolved by the sort order.
  func (k TopK) Sample(logits []float64) ([]float64, error) {
  	switch {
  	case k <= 0:
  		return nil, errors.New("k must be positive")
  	case len(logits) <= int(k):
  		return logits, nil
  	}

  	order := make([]int, 0, len(logits))
  	for i := range logits {
  		order = append(order, i)
  	}
  	// Highest logit first.
  	slices.SortFunc(order, func(a, b int) int {
  		return cmp.Compare(logits[b], logits[a])
  	})

  	for _, dropped := range order[int(k):] {
  		logits[dropped] = math.NaN()
  	}
  	return logits, nil
  }
  66. type TopP float32
  67. func (p TopP) Sample(logits []float64) ([]float64, error) {
  68. if p <= 0 || p >= 1 {
  69. return nil, errors.New("p must be between 0 and 1")
  70. }
  71. probs, err := computeSoftmax(logits)
  72. if err != nil {
  73. return nil, err
  74. }
  75. indices := make([]int, len(probs))
  76. for i := range indices {
  77. indices[i] = i
  78. }
  79. // sort in descending order
  80. slices.SortFunc(indices, func(i, j int) int {
  81. return cmp.Compare(probs[j], probs[i])
  82. })
  83. cumSum := 0.0
  84. for i, idx := range indices {
  85. cumSum += probs[idx]
  86. if cumSum > float64(p) {
  87. for _, idx := range indices[i+1:] {
  88. logits[idx] = math.NaN()
  89. }
  90. break
  91. }
  92. }
  93. return logits, nil
  94. }
  95. type MinP float32
  96. func (p MinP) Sample(logits []float64) ([]float64, error) {
  97. if p <= 0 || p >= 1 {
  98. return nil, errors.New("p must be between 0 and 1")
  99. }
  100. probs, err := computeSoftmax(logits)
  101. if err != nil {
  102. return nil, err
  103. }
  104. copiedProbs := make([]float64, len(probs))
  105. copy(copiedProbs, probs)
  106. slices.Sort(copiedProbs)
  107. maxProb := copiedProbs[len(copiedProbs)-1]
  108. probThreshold := float64(p) * maxProb
  109. for i := range probs {
  110. if probs[i] < probThreshold {
  111. logits[i] = math.NaN()
  112. }
  113. }
  114. return logits, nil
  115. }
  116. type weighed struct{}
  117. func Weighed() Sampler {
  118. return weighed{}
  119. }
  120. func (s weighed) Sample(logits []float64) ([]float64, error) {
  121. logitsCopy := make([]float64, 0, len(logits))
  122. indices := make([]int, 0, len(logits))
  123. // the uv sampler does not support NaN values
  124. for i, logit := range logits {
  125. if !math.IsNaN(logit) {
  126. logitsCopy = append(logitsCopy, logit)
  127. indices = append(indices, i)
  128. }
  129. }
  130. if len(logitsCopy) == 0 {
  131. return nil, errors.New("no valid tokens found")
  132. }
  133. // usually, a softmax is applied to sample from the logits
  134. // in this case the uv sampler normalizes the logits so that the sum of the weights is 1
  135. w := sampleuv.NewWeighted(logitsCopy, nil)
  136. if v, ok := w.Take(); ok {
  137. // returns the token ID
  138. return []float64{float64(indices[v])}, nil
  139. }
  140. return nil, errors.New("weighed sampler failed")
  141. }
  142. func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
  143. var err error
  144. for _, sampler := range samplers {
  145. logits, err = sampler.Sample(logits)
  146. if err != nil {
  147. return nil, err
  148. }
  149. }
  150. return logits, nil
  151. }