sample_test.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "math/rand"
  6. "os"
  7. "runtime"
  8. "slices"
  9. "testing"
  10. "runtime/trace"
  11. "gonum.org/v1/gonum/floats"
  12. )
  13. func TestTemperature(t *testing.T) {
  14. logits, err := Temperature(0.5).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  15. if err != nil {
  16. t.Fatal(err)
  17. }
  18. want := []float64{-14, -12, -10, -8, -6, -4, 0}
  19. if !floats.Equal(logits, want) {
  20. t.Fatalf("got: %v, want: %v", logits, want)
  21. }
  22. if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
  23. t.Fatalf("expected error for temperature=-1, got %v", logits)
  24. }
  25. if _, err := Temperature(2.1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
  26. t.Fatalf("expected error for temperature=2.1, got %v", logits)
  27. }
  28. }
  29. func TestSoftmax(t *testing.T) {
  30. probs, err := Softmax().Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  31. if err != nil {
  32. t.Fatal(err)
  33. }
  34. expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
  35. if !floats.Equal(probs, expectedProbs) {
  36. t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs)
  37. }
  38. }
  39. func TestTopK(t *testing.T) {
  40. logits, err := TopK(3).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  41. if err != nil {
  42. t.Fatal(err)
  43. }
  44. expectedlogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), 1, 2, 4}
  45. if !floats.Same(logits, expectedlogits) {
  46. t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
  47. }
  48. logits, err = TopK(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  49. if err == nil {
  50. t.Fatalf("expected error for k=0, got %v", logits)
  51. }
  52. }
  53. func TestTopP(t *testing.T) {
  54. logits, err := TopP(0.9).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  55. if err != nil {
  56. t.Fatal(err)
  57. }
  58. want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
  59. if !floats.Same(logits, want) {
  60. t.Fatalf("got: %v, want: %v", logits, want)
  61. }
  62. logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  63. if err == nil {
  64. t.Fatalf("expected error for p=1.0, got %v", logits)
  65. }
  66. logits, err = TopP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  67. if err == nil {
  68. t.Fatalf("expected error for p=0.0, got %v", logits)
  69. }
  70. }
  71. func TestMinP(t *testing.T) {
  72. logits, err := MinP(0.2).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
  73. if err != nil {
  74. t.Fatal(err)
  75. }
  76. want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
  77. if !floats.Same(logits, want) {
  78. t.Fatalf("got: %v, want: %v", logits, want)
  79. }
  80. logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
  81. if err == nil {
  82. t.Fatalf("expected error for p=1.0, got %v", logits)
  83. }
  84. logits, err = MinP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
  85. if err == nil {
  86. t.Fatalf("expected error for p=0.0, got %v", logits)
  87. }
  88. }
  89. func TestWeighed(t *testing.T) {
  90. logits, err := Weighed().Sample([]float64{math.NaN(), 2, math.NaN(), math.NaN()})
  91. if err != nil {
  92. t.Fatal(err)
  93. }
  94. want := []float64{1}
  95. if !floats.Equal(logits, want) {
  96. t.Fatalf("got: %v, want: %v", logits, want)
  97. }
  98. logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()})
  99. if err == nil {
  100. t.Fatalf("expected error for no valid tokens, got %v", logits)
  101. }
  102. }
  103. func TestSample(t *testing.T) {
  104. input := []float64{1, 2, 3, 4}
  105. want := []float64{1, 2, 3, 4}
  106. var callOrder []int
  107. mock1 := &testSampler{
  108. id: 1,
  109. callOrder: &callOrder,
  110. returnVals: want,
  111. }
  112. mock2 := &testSampler{
  113. id: 2,
  114. callOrder: &callOrder,
  115. returnVals: want,
  116. }
  117. mock3 := &testSampler{
  118. id: 3,
  119. callOrder: &callOrder,
  120. returnVals: want,
  121. }
  122. got, err := Sample(input, mock1, mock2, mock3)
  123. if err != nil {
  124. t.Fatal(err)
  125. }
  126. if !slices.Equal(callOrder, []int{1, 2, 3}) {
  127. t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3})
  128. }
  129. if !floats.Equal(got, want) {
  130. t.Errorf("got %v, want %v", got, want)
  131. }
  132. errMock := &testSampler{
  133. returnErr: fmt.Errorf("mock error"),
  134. }
  135. _, err = Sample(input, mock1, errMock, mock2)
  136. if err == nil {
  137. t.Error("Expected error from sampler")
  138. }
  139. }
  // testSampler is a scripted sampler for TestSample. Each Sample call appends
// id to *callOrder (when callOrder is non-nil), then returns returnErr if it
// is set, or returnVals otherwise.
type testSampler struct {
	id         int       // value appended to *callOrder on every call
	callOrder  *[]int    // shared invocation log; nil disables logging
	returnVals []float64 // result handed back when returnErr is nil
	returnErr  error     // when non-nil, Sample fails with this error
}
  146. func (ts *testSampler) Sample(logits []float64) ([]float64, error) {
  147. if ts.callOrder != nil {
  148. *ts.callOrder = append(*ts.callOrder, ts.id)
  149. }
  150. if ts.returnErr != nil {
  151. return nil, ts.returnErr
  152. }
  153. return ts.returnVals, nil
  154. }
  155. func TestSampleTemperatureZero(t *testing.T) {
  156. logits, err := Sample([]float64{1, 2, 3, 4}, Temperature(0))
  157. if err != nil {
  158. t.Fatal(err)
  159. }
  160. want := []float64{3}
  161. if !floats.Equal(logits, want) {
  162. t.Fatalf("got: %v, want: %v", logits, want)
  163. }
  164. }