samplers_benchmark_test.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. package sample
  2. import (
  3. "fmt"
  4. "math/rand"
  5. "testing"
  6. )
  7. func BenchmarkWeightedSampler(b *testing.B) {
  8. sizes := []int{10, 100, 1000, 10000}
  9. for _, size := range sizes {
  10. b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
  11. logits := make([]float32, size)
  12. for i := range logits {
  13. logits[i] = float32(rand.Float64()*10 - 5)
  14. }
  15. sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
  16. b.ResetTimer()
  17. for b.Loop() {
  18. sampler.Sample(logits)
  19. }
  20. })
  21. }
  22. configs := []struct {
  23. name string
  24. temperature float32
  25. topK int
  26. topP float32
  27. minP float32
  28. seed int
  29. }{
  30. {"Greedy", 0, -1, 0, 0, -1},
  31. {"Temperature", 0.8, -1, 0, 0, -1},
  32. {"TopK", 0.8, 50, 0, 0, -1},
  33. {"TopP", 0.8, -1, 0.9, 0, -1},
  34. {"MinP", 0.8, -1, 0, 0.05, -1},
  35. {"WithSeed", 0.8, 50, 0, 0, 42},
  36. }
  37. // Fixed size for common vocab size
  38. size := 128000
  39. logits := make([]float32, size)
  40. for i := range logits {
  41. logits[i] = float32(rand.Float64()*10 - 5)
  42. }
  43. for _, tc := range configs {
  44. b.Run("Config"+tc.name, func(b *testing.B) {
  45. sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
  46. sampler.Sample(logits)
  47. b.ResetTimer()
  48. for b.Loop() {
  49. sampler.Sample(logits)
  50. }
  51. })
  52. }
  53. // Test with combined transforms separately - topK influences performance greatly
  54. b.Run("TransformCombined", func(b *testing.B) {
  55. sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
  56. b.ResetTimer()
  57. for b.Loop() {
  58. sampler.Sample(logits)
  59. }
  60. })
  61. }
  62. func BenchmarkGreedySampler(b *testing.B) {
  63. sizes := []int{10, 100, 1000, 10000, 100000}
  64. for _, size := range sizes {
  65. b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
  66. logits := make([]float32, size)
  67. for i := range logits {
  68. logits[i] = float32(rand.Float64()*10 - 5)
  69. }
  70. sampler := NewSampler(0, -1, 0, 0, -1, nil)
  71. b.ResetTimer()
  72. for b.Loop() {
  73. sampler.Sample(logits)
  74. }
  75. })
  76. }
  77. }