samplers.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package sample
import (
	"errors"
	"math"
	"math/rand/v2"
	"slices"
)
  7. // Sampler is not thread-safe. Each goroutine should have its own instance
  8. type Sampler interface {
  9. Sample([]float32) (int32, error)
  10. }
  11. // logit represents information about a single token during sampling
  12. type logit struct {
  13. id int32 // The token's unique identifier
  14. value float32 // The raw logit or probability from the model
  15. }
  16. type weighted struct {
  17. rng *rand.Rand
  18. tokens []logit
  19. topK int
  20. topP float32
  21. minP float32
  22. temperature float32
  23. }
  24. func (s *weighted) Sample(logits []float32) (int32, error) {
  25. if len(s.tokens) < len(logits) {
  26. s.tokens = make([]logit, len(logits))
  27. }
  28. tokens := s.tokens[:len(logits)]
  29. for i, v := range logits {
  30. tokens[i].id = int32(i)
  31. tokens[i].value = v
  32. }
  33. // Tokens are sorted by logits in TopK or SortTokens
  34. if s.topK > 0 {
  35. tokens = topK(tokens, s.topK)
  36. } else {
  37. sortLogits(tokens)
  38. }
  39. tokens = temperature(tokens, s.temperature)
  40. tokens = softmax(tokens)
  41. tokens = topP(tokens, s.topP)
  42. tokens = minP(tokens, s.minP)
  43. if len(tokens) == 0 {
  44. return -1, errors.New("no valid logits found for weighted sampling")
  45. }
  46. var r float32
  47. if s.rng != nil {
  48. r = s.rng.Float32()
  49. } else {
  50. r = rand.Float32()
  51. }
  52. // Calculate cumulative sum of probabilities
  53. var sum float32
  54. for i := range tokens {
  55. sum += tokens[i].value
  56. tokens[i].value = sum
  57. }
  58. r *= tokens[len(tokens)-1].value
  59. idx, _ := slices.BinarySearchFunc(tokens, r, func(token logit, target float32) int {
  60. // Compare cumulative probabilities
  61. if token.value < target {
  62. return -1
  63. }
  64. // First token that exceeds target
  65. return 1
  66. })
  67. if idx >= len(tokens) {
  68. idx = len(tokens) - 1
  69. }
  70. return tokens[idx].id, nil
  71. }
  72. type greedy struct{}
  73. // Greedy sample returns the index of the maximum value in logits.
  74. func (s greedy) Sample(logits []float32) (int32, error) {
  75. if len(logits) == 0 {
  76. return -1, errors.New("no logits provided for greedy sampling")
  77. }
  78. maxIdx := 0
  79. maxVal := logits[0]
  80. for i := 1; i < len(logits); i++ {
  81. if logits[i] > maxVal {
  82. maxVal = logits[i]
  83. maxIdx = i
  84. }
  85. }
  86. return int32(maxIdx), nil
  87. }
  88. // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
  89. func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler {
  90. if temperature == 0 {
  91. return &greedy{}
  92. }
  93. var rng *rand.Rand
  94. if seed != -1 {
  95. // PCG requires two parameters: sequence and stream
  96. // Use original seed for sequence
  97. sequence := uint64(seed)
  98. // Use golden ratio hash to generate statistically independent seeds
  99. rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
  100. }
  101. temperature = max(temperature, 1)
  102. if topP < 0.0 {
  103. topP = 0.0
  104. }
  105. if topP >= 1.0 {
  106. topP = 1.0
  107. }
  108. if minP < 0.0 {
  109. minP = 0.0
  110. }
  111. if minP >= 1.0 {
  112. minP = 1.0
  113. }
  114. return &weighted{
  115. rng: rng,
  116. topK: topK,
  117. topP: topP,
  118. minP: minP,
  119. temperature: temperature,
  120. }
  121. }