samplers_benchmark_test.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. package sample
  2. import (
  3. "fmt"
  4. "math/rand"
  5. "testing"
  6. )
  7. func BenchmarkWeightedSampler(b *testing.B) {
  8. sizes := []int{10, 100, 1000, 10000}
  9. for _, size := range sizes {
  10. b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
  11. logits := make([]float32, size)
  12. for i := range logits {
  13. logits[i] = float32(rand.Float64()*10 - 5)
  14. }
  15. sampler := NewSampler(0.8, 0, 0, 0, 42)
  16. b.ResetTimer()
  17. for b.Loop() {
  18. _, err := sampler.Sample(logits)
  19. if err != nil {
  20. b.Fatalf("Sampling failed: %v", err)
  21. }
  22. }
  23. })
  24. }
  25. configs := []struct {
  26. name string
  27. temperature float32
  28. topK int
  29. topP float32
  30. minP float32
  31. seed int
  32. }{
  33. {"Greedy", 0, -1, 0, 0, -1},
  34. {"Temperature", 0.8, -1, 0, 0, -1},
  35. {"TopK", 0.8, 50, 0, 0, -1},
  36. {"TopP", 0.8, -1, 0.9, 0, -1},
  37. {"MinP", 0.8, -1, 0, 0.05, -1},
  38. {"WithSeed", 0.8, 50, 0, 0, 42},
  39. }
  40. // Fixed size for common vocab size
  41. size := 128000
  42. logits := make([]float32, size)
  43. for i := range logits {
  44. logits[i] = float32(rand.Float64()*10 - 5)
  45. }
  46. for _, tc := range configs {
  47. b.Run("Config"+tc.name, func(b *testing.B) {
  48. sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed)
  49. sampler.Sample(logits)
  50. b.ResetTimer()
  51. for b.Loop() {
  52. _, err := sampler.Sample(logits)
  53. if err != nil {
  54. b.Fatalf("Sampling failed: %v", err)
  55. }
  56. }
  57. })
  58. }
  59. // Test with combined transforms separately - topK influences performance greatly
  60. b.Run("TransformCombined", func(b *testing.B) {
  61. sampler := NewSampler(0.8, 50, 0.9, 0.05, 42)
  62. b.ResetTimer()
  63. for b.Loop() {
  64. _, err := sampler.Sample(logits)
  65. if err != nil {
  66. b.Fatalf("Sampling failed: %v", err)
  67. }
  68. }
  69. })
  70. }
  71. func BenchmarkGreedySampler(b *testing.B) {
  72. sizes := []int{10, 100, 1000, 10000, 100000}
  73. for _, size := range sizes {
  74. b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
  75. logits := make([]float32, size)
  76. for i := range logits {
  77. logits[i] = float32(rand.Float64()*10 - 5)
  78. }
  79. sampler := NewSampler(0, -1, 0, 0, -1)
  80. b.ResetTimer()
  81. for b.Loop() {
  82. _, err := sampler.Sample(logits)
  83. if err != nil {
  84. b.Fatalf("Sampling failed: %v", err)
  85. }
  86. }
  87. })
  88. }
  89. }