1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- package sample
- import (
- "math"
- "math/rand/v2"
- "testing"
- "github.com/google/go-cmp/cmp"
- )
- func TestTemperature(t *testing.T) {
- got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
- want := []float64{-4, -10, 0, -14, -6, -12, -8}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("logits mismatch (-want +got):\n%s", diff)
- }
- }
- func TestSoftmax(t *testing.T) {
- got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
- want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("probs mismatch (-want +got):\n%s", diff)
- }
- }
- func TestTopK(t *testing.T) {
- got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
- want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("logits mismatch (-want +got):\n%s", diff)
- }
- got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
- want = []float64{-3, -2, -1, 0, 1, 2, 4}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("logits mismatch (-want +got):\n%s", diff)
- }
- }
- func TestTopP(t *testing.T) {
- got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
- want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("logits mismatch (-want +got):\n%s", diff)
- }
- }
- func TestMinP(t *testing.T) {
- got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
- want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("logits mismatch (-want +got):\n%s", diff)
- }
- }
- func BenchmarkTransform(b *testing.B) {
- transforms := map[string]Transform{
- "Temperature": Temperature(0.5),
- "TopK": TopK(10),
- "TopP": TopP(0.9),
- "MinP": MinP(0.2),
- }
- logits := make([]float64, 1<<16)
- for i := range logits {
- logits[i] = rand.Float64()
- }
- for name, transform := range transforms {
- b.Run(name, func(b *testing.B) {
- b.ResetTimer()
- for range b.N {
- transform.Apply(logits)
- }
- })
- }
- }
|