sample_test.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "slices"
  6. "testing"
  7. "gonum.org/v1/gonum/floats"
  8. )
  9. func TestTemperature(t *testing.T) {
  10. logits, err := Temperature(0.5).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  11. if err != nil {
  12. t.Fatal(err)
  13. }
  14. want := []float64{-14, -12, -10, -8, -6, -4, 0}
  15. if !floats.Equal(logits, want) {
  16. t.Fatalf("got: %v, want: %v", logits, want)
  17. }
  18. if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
  19. t.Fatalf("expected error for temperature=-1, got %v", logits)
  20. }
  21. if _, err := Temperature(2.1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
  22. t.Fatalf("expected error for temperature=2.1, got %v", logits)
  23. }
  24. }
  25. func TestSoftmax(t *testing.T) {
  26. probs, err := Softmax().Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  27. if err != nil {
  28. t.Fatal(err)
  29. }
  30. expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
  31. if !floats.Equal(probs, expectedProbs) {
  32. t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs)
  33. }
  34. }
  35. func TestTopK(t *testing.T) {
  36. logits, err := TopK(3).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  37. if err != nil {
  38. t.Fatal(err)
  39. }
  40. expectedlogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), 1, 2, 4}
  41. if !floats.Same(logits, expectedlogits) {
  42. t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
  43. }
  44. logits, err = TopK(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  45. if err == nil {
  46. t.Fatalf("expected error for k=0, got %v", logits)
  47. }
  48. }
  49. func TestTopP(t *testing.T) {
  50. logits, err := TopP(0.9).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  51. if err != nil {
  52. t.Fatal(err)
  53. }
  54. want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
  55. if !floats.Same(logits, want) {
  56. t.Fatalf("got: %v, want: %v", logits, want)
  57. }
  58. logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  59. if err == nil {
  60. t.Fatalf("expected error for p=1.0, got %v", logits)
  61. }
  62. logits, err = TopP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  63. if err == nil {
  64. t.Fatalf("expected error for p=0.0, got %v", logits)
  65. }
  66. }
  67. func TestMinP(t *testing.T) {
  68. logits, err := MinP(0.2).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
  69. if err != nil {
  70. t.Fatal(err)
  71. }
  72. want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
  73. if !floats.Same(logits, want) {
  74. t.Fatalf("got: %v, want: %v", logits, want)
  75. }
  76. logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
  77. if err == nil {
  78. t.Fatalf("expected error for p=1.0, got %v", logits)
  79. }
  80. logits, err = MinP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
  81. if err == nil {
  82. t.Fatalf("expected error for p=0.0, got %v", logits)
  83. }
  84. }
  85. func TestWeighed(t *testing.T) {
  86. logits, err := Weighed().Sample([]float64{math.NaN(), 2, math.NaN(), math.NaN()})
  87. if err != nil {
  88. t.Fatal(err)
  89. }
  90. want := []float64{1}
  91. if !floats.Equal(logits, want) {
  92. t.Fatalf("got: %v, want: %v", logits, want)
  93. }
  94. logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()})
  95. if err == nil {
  96. t.Fatalf("expected error for no valid tokens, got %v", logits)
  97. }
  98. }
  99. func TestSample(t *testing.T) {
  100. input := []float64{1, 2, 3, 4}
  101. want := []float64{1, 2, 3, 4}
  102. var callOrder []int
  103. mock1 := &testSampler{
  104. id: 1,
  105. callOrder: &callOrder,
  106. returnVals: want,
  107. }
  108. mock2 := &testSampler{
  109. id: 2,
  110. callOrder: &callOrder,
  111. returnVals: want,
  112. }
  113. mock3 := &testSampler{
  114. id: 3,
  115. callOrder: &callOrder,
  116. returnVals: want,
  117. }
  118. got, err := Sample(input, mock1, mock2, mock3)
  119. if err != nil {
  120. t.Fatal(err)
  121. }
  122. if !slices.Equal(callOrder, []int{1, 2, 3}) {
  123. t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3})
  124. }
  125. if !floats.Equal(got, want) {
  126. t.Errorf("got %v, want %v", got, want)
  127. }
  128. errMock := &testSampler{
  129. returnErr: fmt.Errorf("mock error"),
  130. }
  131. _, err = Sample(input, mock1, errMock, mock2)
  132. if err == nil {
  133. t.Error("Expected error from sampler")
  134. }
  135. }
  // testSampler is a scripted stand-in sampler used by TestSample.
  type testSampler struct {
  	id         int       // value appended to *callOrder on every call
  	callOrder  *[]int    // shared call log; nil disables recording
  	returnVals []float64 // slice returned when returnErr is nil
  	returnErr  error     // when non-nil, Sample fails with this error
  }

  // Sample records the call (when callOrder is set) and then either fails with
  // returnErr or returns returnVals. The incoming logits are ignored.
  func (m *testSampler) Sample(logits []float64) ([]float64, error) {
  	if log := m.callOrder; log != nil {
  		*log = append(*log, m.id)
  	}
  	if err := m.returnErr; err != nil {
  		return nil, err
  	}
  	return m.returnVals, nil
  }
  151. func TestSampleTemperatureZero(t *testing.T) {
  152. logits, err := Sample([]float64{1, 2, 3, 4}, Temperature(0))
  153. if err != nil {
  154. t.Fatal(err)
  155. }
  156. want := []float64{3}
  157. if !floats.Equal(logits, want) {
  158. t.Fatalf("got: %v, want: %v", logits, want)
  159. }
  160. }