transforms_test.go 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. package sample
  2. import (
  3. "math"
  4. "math/rand/v2"
  5. "testing"
  6. "github.com/google/go-cmp/cmp"
  7. )
  8. func TestTemperature(t *testing.T) {
  9. got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
  10. want := []float64{-4, -10, 0, -14, -6, -12, -8}
  11. if diff := cmp.Diff(want, got); diff != "" {
  12. t.Errorf("logits mismatch (-want +got):\n%s", diff)
  13. }
  14. }
  15. func TestSoftmax(t *testing.T) {
  16. got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
  17. want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
  18. if diff := cmp.Diff(want, got); diff != "" {
  19. t.Errorf("probs mismatch (-want +got):\n%s", diff)
  20. }
  21. }
  22. func TestTopK(t *testing.T) {
  23. got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  24. want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
  25. if diff := cmp.Diff(want, got); diff != "" {
  26. t.Errorf("logits mismatch (-want +got):\n%s", diff)
  27. }
  28. got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  29. want = []float64{-3, -2, -1, 0, 1, 2, 4}
  30. if diff := cmp.Diff(want, got); diff != "" {
  31. t.Errorf("logits mismatch (-want +got):\n%s", diff)
  32. }
  33. }
  34. func TestTopP(t *testing.T) {
  35. got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
  36. want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
  37. if diff := cmp.Diff(want, got); diff != "" {
  38. t.Errorf("logits mismatch (-want +got):\n%s", diff)
  39. }
  40. }
  41. func TestMinP(t *testing.T) {
  42. got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
  43. want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
  44. if diff := cmp.Diff(want, got); diff != "" {
  45. t.Errorf("logits mismatch (-want +got):\n%s", diff)
  46. }
  47. }
  48. func BenchmarkTransform(b *testing.B) {
  49. transforms := map[string]Transform{
  50. "Temperature": Temperature(0.5),
  51. "TopK": TopK(10),
  52. "TopP": TopP(0.9),
  53. "MinP": MinP(0.2),
  54. }
  55. logits := make([]float64, 1<<16)
  56. for i := range logits {
  57. logits[i] = rand.Float64()
  58. }
  59. for name, transform := range transforms {
  60. b.Run(name, func(b *testing.B) {
  61. b.ResetTimer()
  62. for range b.N {
  63. transform.Apply(logits)
  64. }
  65. })
  66. }
  67. }