sample_test.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "math/rand/v2"
  6. "testing"
  7. "github.com/google/go-cmp/cmp"
  8. )
  9. func TestTemperature(t *testing.T) {
  10. logits, err := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
  11. if err != nil {
  12. t.Error(err)
  13. return
  14. }
  15. want := []float64{-4, -10, 0, -14, -6, -12, -8}
  16. if diff := cmp.Diff(want, logits); diff != "" {
  17. t.Errorf("logits mismatch (-want +got):\n%s", diff)
  18. }
  19. logits, err = Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  20. if err == nil {
  21. t.Errorf("expected error for temperature=-1, got %v", logits)
  22. }
  23. logits, err = Temperature(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  24. if err == nil {
  25. t.Errorf("expected error for temperature=0, got %v", logits)
  26. }
  27. logits, err = Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  28. if err == nil {
  29. t.Errorf("expected error for temperature=2.1, got %v", logits)
  30. }
  31. }
  32. func TestSoftmax(t *testing.T) {
  33. probs := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
  34. expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
  35. if diff := cmp.Diff(expectedProbs, probs); diff != "" {
  36. t.Errorf("probs mismatch (-want +got):\n%s", diff)
  37. }
  38. }
  39. func TestTopK(t *testing.T) {
  40. logits, err := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  41. if err != nil {
  42. t.Error(err)
  43. return
  44. }
  45. expectedlogits := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
  46. if diff := cmp.Diff(expectedlogits, logits); diff != "" {
  47. t.Errorf("logits mismatch (-want +got):\n%s", diff)
  48. }
  49. _, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  50. if err == nil {
  51. t.Errorf("expected error for k=0, got %v", err)
  52. }
  53. logits, err = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  54. if err != nil {
  55. t.Error(err)
  56. return
  57. }
  58. expectedlogits = []float64{-3, -2, -1, 0, 1, 2, 4}
  59. if diff := cmp.Diff(expectedlogits, logits); diff != "" {
  60. t.Errorf("logits mismatch (-want +got):\n%s", diff)
  61. }
  62. }
  63. func TestTopP(t *testing.T) {
  64. logits, err := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  65. if err != nil {
  66. t.Error(err)
  67. return
  68. }
  69. want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
  70. if diff := cmp.Diff(want, logits); diff != "" {
  71. t.Errorf("logits mismatch (-want +got):\n%s", diff)
  72. }
  73. _, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  74. if err == nil {
  75. t.Error("expected error for p=1.0")
  76. }
  77. _, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  78. if err == nil {
  79. t.Error("expected error for p=0.0")
  80. }
  81. }
  82. func TestMinP(t *testing.T) {
  83. logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
  84. if err != nil {
  85. t.Error(err)
  86. return
  87. }
  88. want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
  89. if diff := cmp.Diff(want, logits); diff != "" {
  90. t.Errorf("logits mismatch (-want +got):\n%s", diff)
  91. }
  92. _, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
  93. if err == nil {
  94. t.Error("expected error for p=1.0")
  95. }
  96. _, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
  97. if err == nil {
  98. t.Error("expected error for p=0.0")
  99. }
  100. }
  101. func TestWeighed(t *testing.T) {
  102. idx, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
  103. if err != nil {
  104. t.Error(err)
  105. return
  106. }
  107. want := 1
  108. if diff := cmp.Diff(want, idx); diff != "" {
  109. t.Errorf("index mismatch (-want +got):\n%s", diff)
  110. }
  111. idx, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
  112. if err == nil {
  113. t.Error("expected error for no valid tokens, got index", idx)
  114. }
  115. }
  116. func TestSample(t *testing.T) {
  117. input := []float32{1, 2, 3, 4}
  118. var callOrder []int
  119. mock1 := &testTransform{
  120. id: 1,
  121. callOrder: &callOrder,
  122. }
  123. mock2 := &testTransform{
  124. id: 2,
  125. callOrder: &callOrder,
  126. }
  127. mock3 := &testTransform{
  128. id: 3,
  129. callOrder: &callOrder,
  130. }
  131. got, err := Greedy().Sample(input, mock1, mock2, mock3)
  132. if err != nil {
  133. t.Error(err)
  134. return
  135. }
  136. want := 3 // Greedy sampler should pick highest logit
  137. if diff := cmp.Diff(want, got); diff != "" {
  138. t.Errorf("sampled index mismatch (-want +got):\n%s", diff)
  139. }
  140. _, err = Weighted(nil).Sample(input, mock1, mock2, mock3)
  141. if err != nil {
  142. t.Error(err)
  143. return
  144. }
  145. wantOrder := []int{1, 2, 3}
  146. if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
  147. t.Errorf("call order mismatch (-want +got):\n%s", diff)
  148. }
  149. errMock := &testTransform{
  150. returnErr: fmt.Errorf("mock error"),
  151. }
  152. _, err = Weighted(nil).Sample(input, mock1, errMock, mock2)
  153. if err == nil {
  154. t.Error("Expected error from sampler")
  155. }
  156. }
  // testTransform is a test double used as a transform in TestSample. Its Apply
// method records which instance ran and can be configured to fail.
type testTransform struct {
	id        int    // value appended to *callOrder on each Apply call
	callOrder *[]int // shared call log; nil disables recording
	returnErr error  // when non-nil, Apply returns (nil, returnErr)
}

// Apply appends tt.id to the shared call log (when one is set). It then
// returns returnErr if configured, or otherwise hands back logits unchanged
// (the same slice, not a copy).
func (tt *testTransform) Apply(logits []float64) ([]float64, error) {
	if log := tt.callOrder; log != nil {
		*log = append(*log, tt.id)
	}
	if err := tt.returnErr; err != nil {
		return nil, err
	}
	return logits, nil
}
  171. func BenchmarkTransform(b *testing.B) {
  172. transforms := map[string]Transform{
  173. "Temperature": Temperature(0.5),
  174. "TopK": TopK(10),
  175. "TopP": TopP(0.9),
  176. "MinP": MinP(0.2),
  177. }
  178. logits := make([]float64, 1<<16)
  179. for i := range logits {
  180. logits[i] = rand.Float64()
  181. }
  182. for name, transform := range transforms {
  183. b.Run(name, func(b *testing.B) {
  184. b.ResetTimer()
  185. for range b.N {
  186. _, err := transform.Apply(logits)
  187. if err != nil {
  188. b.Error(err)
  189. }
  190. }
  191. })
  192. }
  193. }
  194. func BenchmarkSample(b *testing.B) {
  195. samplers := map[string]Sampler{
  196. "Greedy": Greedy(),
  197. "Weighted": Weighted(nil),
  198. }
  199. logits := make([]float32, 1<<16)
  200. for i := range logits {
  201. logits[i] = rand.Float32()
  202. }
  203. for name, s := range samplers {
  204. b.Run(name, func(b *testing.B) {
  205. b.ResetTimer()
  206. for range b.N {
  207. if _, err := s.Sample(logits); err != nil {
  208. b.Error(err)
  209. }
  210. }
  211. })
  212. }
  213. }