sample_test.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "slices"
  6. "testing"
  7. "gonum.org/v1/gonum/floats"
  8. )
  9. func TestTemperature(t *testing.T) {
  10. logits, err := Temperature(0.5).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  11. if err != nil {
  12. t.Fatal(err)
  13. }
  14. expectedlogits := []float64{-6, -4, -2, 0, 2, 4, 8}
  15. if !floats.Equal(logits, expectedlogits) {
  16. t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
  17. }
  18. // Only expect the max value returned
  19. logits, err = Temperature(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  20. if err != nil {
  21. t.Fatal(err)
  22. }
  23. expectedlogits = []float64{4}
  24. if !floats.Equal(logits, expectedlogits) {
  25. t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
  26. }
  27. if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
  28. t.Fatalf("expected error for temperature=-1, got %v", logits)
  29. }
  30. }
  31. func TestSoftmax(t *testing.T) {
  32. probs, err := computeSoftmax([]float64{-3, -2, -1, 0, 1, 2, 4})
  33. if err != nil {
  34. t.Fatal(err)
  35. }
  36. expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
  37. if !floats.Equal(probs, expectedProbs) {
  38. t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs)
  39. }
  40. }
  41. func TestTopK(t *testing.T) {
  42. logits, err := TopK(3).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  43. if err != nil {
  44. t.Fatal(err)
  45. }
  46. expectedlogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), 1, 2, 4}
  47. if !floats.Same(logits, expectedlogits) {
  48. t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
  49. }
  50. logits, err = TopK(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  51. if err == nil {
  52. t.Fatalf("expected error for k=0, got %v", logits)
  53. }
  54. }
  55. func TestTopP(t *testing.T) {
  56. logits, err := TopP(0.9).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  57. if err != nil {
  58. t.Fatal(err)
  59. }
  60. expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4}
  61. if !floats.Same(logits, expectedLogits) {
  62. t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
  63. }
  64. logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  65. if err == nil {
  66. t.Fatalf("expected error for p=1.0, got %v", logits)
  67. }
  68. logits, err = TopP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4})
  69. if err == nil {
  70. t.Fatalf("expected error for p=0.0, got %v", logits)
  71. }
  72. }
  73. func TestMinP(t *testing.T) {
  74. logits, err := MinP(0.2).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
  75. if err != nil {
  76. t.Fatal(err)
  77. }
  78. expectedLogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4}
  79. if !floats.Same(logits, expectedLogits) {
  80. t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
  81. }
  82. logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
  83. if err == nil {
  84. t.Fatalf("expected error for p=1.0, got %v", logits)
  85. }
  86. logits, err = MinP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
  87. if err == nil {
  88. t.Fatalf("expected error for p=0.0, got %v", logits)
  89. }
  90. }
  91. func TestWeighed(t *testing.T) {
  92. logits, err := Weighed().Sample([]float64{math.NaN(), 2, math.NaN(), math.NaN()})
  93. if err != nil {
  94. t.Fatal(err)
  95. }
  96. expectedLogits := []float64{1}
  97. if !floats.Equal(logits, expectedLogits) {
  98. t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedLogits)
  99. }
  100. logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()})
  101. if err == nil {
  102. t.Fatalf("expected error for no valid tokens, got %v", logits)
  103. }
  104. }
  105. func TestSample(t *testing.T) {
  106. input := []float64{1, 2, 3, 4}
  107. expectedOutput := []float64{1, 2, 3, 4}
  108. var callOrder []int
  109. mock1 := &mockSampler{
  110. id: 1,
  111. callOrder: &callOrder,
  112. returnVals: expectedOutput,
  113. }
  114. mock2 := &mockSampler{
  115. id: 2,
  116. callOrder: &callOrder,
  117. returnVals: expectedOutput,
  118. }
  119. mock3 := &mockSampler{
  120. id: 3,
  121. callOrder: &callOrder,
  122. returnVals: expectedOutput,
  123. }
  124. result, err := Sample(input, mock1, mock2, mock3)
  125. if err != nil {
  126. t.Fatal(err)
  127. }
  128. if !slices.Equal(callOrder, []int{1, 2, 3}) {
  129. t.Errorf("Expected call order [1,2,3], got %v", callOrder)
  130. }
  131. if !floats.Equal(result, expectedOutput) {
  132. t.Errorf("Expected output %v, got %v", expectedOutput, result)
  133. }
  134. errMock := &mockSampler{
  135. returnErr: fmt.Errorf("mock error"),
  136. }
  137. _, err = Sample(input, mock1, errMock, mock2)
  138. if err == nil {
  139. t.Error("Expected error from sampler")
  140. }
  141. }
  142. type mockSampler struct {
  143. id int
  144. callOrder *[]int
  145. returnVals []float64
  146. returnErr error
  147. }
  148. func (m *mockSampler) Sample(logits []float64) ([]float64, error) {
  149. if m.callOrder != nil {
  150. *m.callOrder = append(*m.callOrder, m.id)
  151. }
  152. if m.returnErr != nil {
  153. return nil, m.returnErr
  154. }
  155. return m.returnVals, nil
  156. }