123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- package sample
- import (
- "fmt"
- "math/rand"
- "testing"
- )
- func BenchmarkWeightedSampler(b *testing.B) {
- sizes := []int{10, 100, 1000, 10000}
- for _, size := range sizes {
- b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
- logits := make([]float32, size)
- for i := range logits {
- logits[i] = float32(rand.Float64()*10 - 5)
- }
- sampler := NewSampler(0.8, 0, 0, 0, 42)
- b.ResetTimer()
- for b.Loop() {
- _, err := sampler.Sample(logits)
- if err != nil {
- b.Fatalf("Sampling failed: %v", err)
- }
- }
- })
- }
- configs := []struct {
- name string
- temperature float32
- topK int
- topP float32
- minP float32
- seed int
- }{
- {"Greedy", 0, -1, 0, 0, -1},
- {"Temperature", 0.8, -1, 0, 0, -1},
- {"TopK", 0.8, 50, 0, 0, -1},
- {"TopP", 0.8, -1, 0.9, 0, -1},
- {"MinP", 0.8, -1, 0, 0.05, -1},
- {"WithSeed", 0.8, 50, 0, 0, 42},
- }
- // Fixed size for common vocab size
- size := 128000
- logits := make([]float32, size)
- for i := range logits {
- logits[i] = float32(rand.Float64()*10 - 5)
- }
- for _, tc := range configs {
- b.Run("Config"+tc.name, func(b *testing.B) {
- sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed)
- sampler.Sample(logits)
- b.ResetTimer()
- for b.Loop() {
- _, err := sampler.Sample(logits)
- if err != nil {
- b.Fatalf("Sampling failed: %v", err)
- }
- }
- })
- }
- // Test with combined transforms separately - topK influences performance greatly
- b.Run("TransformCombined", func(b *testing.B) {
- sampler := NewSampler(0.8, 50, 0.9, 0.05, 42)
- b.ResetTimer()
- for b.Loop() {
- _, err := sampler.Sample(logits)
- if err != nil {
- b.Fatalf("Sampling failed: %v", err)
- }
- }
- })
- }
- func BenchmarkGreedySampler(b *testing.B) {
- sizes := []int{10, 100, 1000, 10000, 100000}
- for _, size := range sizes {
- b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
- logits := make([]float32, size)
- for i := range logits {
- logits[i] = float32(rand.Float64()*10 - 5)
- }
- sampler := NewSampler(0, -1, 0, 0, -1)
- b.ResetTimer()
- for b.Loop() {
- _, err := sampler.Sample(logits)
- if err != nil {
- b.Fatalf("Sampling failed: %v", err)
- }
- }
- })
- }
- }
|