pushdown_runner.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "runtime"
  6. "time"
  7. "github.com/ollama/ollama/model"
  8. )
  9. // TODO: safety in case of invalid json
  10. // TODO: interfaces to cleanup with return values
  11. // TODO this interface shouldn't be the sampler - should just use Sampler
  12. // TODO: add penalties for string \n stuff
  13. type PushdownSampler struct {
  14. PDAGraphBuilder
  15. curNode *PDA
  16. braceStack []rune
  17. stateCounter uint32
  18. }
  19. // graph should be built once and reused per tokenizer
  20. func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
  21. start := time.Now()
  22. fmt.Println("--------------------------------")
  23. fmt.Println("PDA sampler")
  24. fmt.Println("--------------------------------")
  25. var m runtime.MemStats
  26. runtime.ReadMemStats(&m)
  27. before := m.Alloc
  28. fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
  29. vocab := proc.GetVocabulary()
  30. decodedToks := make([]string, len(vocab.Values))
  31. for i := range vocab.Values {
  32. token, err := proc.Decode([]int32{int32(i)})
  33. if err != nil {
  34. return nil, err
  35. }
  36. decodedToks[i] = token
  37. }
  38. gb := &PDAGraphBuilder{
  39. proc: proc,
  40. decodedToks: decodedToks,
  41. }
  42. if err := gb.BuildGraph(); err != nil {
  43. return nil, err
  44. }
  45. runtime.ReadMemStats(&m)
  46. after := m.Alloc
  47. fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
  48. fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  49. fmt.Printf("Graph build time = %v\n", time.Since(start))
  50. // TODO: this can be simplified
  51. return &PushdownSampler{
  52. curNode: gb.stateToNodeMap[StateStart],
  53. PDAGraphBuilder: *gb,
  54. braceStack: []rune{},
  55. stateCounter: 0,
  56. }, nil
  57. }
  58. // TODO: need to add resampling logic if the first sample was not good
  59. // greedy sample + backtrack?
  60. func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
  61. switch s.curNode.State {
  62. case StateInString:
  63. return s.maskLogits(logits, s.curNode)
  64. case StateInListEnd:
  65. // force finish if no braces left
  66. if len(s.braceStack) == 0 {
  67. s.curNode = NewPDANode(StateTerminate)
  68. return forceFinish(s, logits)
  69. }
  70. logits, err := s.maskLogits(logits, s.curNode)
  71. if err != nil {
  72. return nil, err
  73. }
  74. return logits, nil
  75. case StateTerminate:
  76. return forceFinish(s, logits)
  77. case StateInObjectEnd:
  78. // force finish if no braces left
  79. if len(s.braceStack) == 0 {
  80. s.curNode = NewPDANode(StateTerminate)
  81. return forceFinish(s, logits)
  82. }
  83. peek := s.braceStack[len(s.braceStack)-1]
  84. if peek == rune('[') {
  85. s.curNode = s.stateToNodeMap[StateInListObjectEnd]
  86. }
  87. logits, err := s.maskLogits(logits, s.curNode)
  88. if err != nil {
  89. return nil, err
  90. }
  91. return logits, nil
  92. case StateInComma:
  93. peek := s.braceStack[len(s.braceStack)-1]
  94. if peek == rune('[') {
  95. s.curNode = s.stateToNodeMap[StateInListComma]
  96. }
  97. logits, err := s.maskLogits(logits, s.curNode)
  98. if err != nil {
  99. return nil, err
  100. }
  101. return logits, nil
  102. default:
  103. fmt.Println("masking logits current state", s.curNode.State)
  104. logits, err := s.maskLogits(logits, s.curNode)
  105. if err != nil {
  106. return nil, err
  107. }
  108. return logits, nil
  109. }
  110. }
  111. func forceFinish(s *PushdownSampler, logits []float64) ([]float64, error) {
  112. for i := range logits {
  113. if s.proc.Is(uint32(i), model.SpecialEOS) {
  114. logits[i] = 1.0
  115. } else {
  116. logits[i] = math.Inf(-1)
  117. }
  118. }
  119. return logits, nil
  120. }
  121. func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
  122. fmt.Println("current state - updating", s.curNode.State)
  123. mappedString, err := s.proc.Decode(tokenSlice)
  124. if err != nil {
  125. return err
  126. }
  127. fmt.Printf(">>> mappedString: %q\n", mappedString)
  128. // TODO: should force closing for all braces - not doing square yet
  129. for _, r := range mappedString {
  130. if r == rune('{') {
  131. s.braceStack = append(s.braceStack, r)
  132. }
  133. if r == rune('[') {
  134. s.braceStack = append(s.braceStack, r)
  135. }
  136. if r == rune('}') {
  137. if len(s.braceStack) == 0 {
  138. return fmt.Errorf("stack is empty, extra closing brace %c", r)
  139. }
  140. top := s.braceStack[len(s.braceStack)-1]
  141. if top != rune('{') {
  142. return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
  143. }
  144. s.braceStack = s.braceStack[:len(s.braceStack)-1]
  145. }
  146. if r == rune(']') {
  147. if len(s.braceStack) == 0 {
  148. return fmt.Errorf("stack is empty, extra closing brace %c", r)
  149. }
  150. top := s.braceStack[len(s.braceStack)-1]
  151. if top != rune('[') {
  152. return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
  153. }
  154. s.braceStack = s.braceStack[:len(s.braceStack)-1]
  155. }
  156. }
  157. for _, tokenID := range tokenSlice {
  158. // transition to the next node
  159. nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
  160. if !ok {
  161. return fmt.Errorf("invalid token: %q", mappedString)
  162. }
  163. fmt.Println("transitioning to", nextNode.State)
  164. // TODO: add a penalty for staying in the same state too long
  165. if nextNode.State == s.curNode.State {
  166. s.stateCounter++
  167. } else {
  168. s.stateCounter = 0
  169. }
  170. s.curNode = nextNode
  171. fmt.Println("updated curNode state", s.curNode.State)
  172. }
  173. return nil
  174. }
  175. // greedy sample + backtrack?
  176. func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) {
  177. // Create a new slice with same length as logits, initialized to -Inf
  178. maskedLogits := make([]float64, len(logits))
  179. for i := range maskedLogits {
  180. maskedLogits[i] = math.Inf(-1)
  181. }
  182. // Only update values for valid token IDs from the mask map
  183. for tokenID := range node.MaskTokenIDToNode {
  184. if int(tokenID) < len(logits) {
  185. maskedLogits[tokenID] = logits[tokenID]
  186. }
  187. }
  188. return maskedLogits, nil
  189. }
  190. func (s *PushdownSampler) fastMaskLogits(logits []float64, node *PDA) ([]float64, error) {
  191. maxLogit := math.Inf(-1)
  192. maxIndex := -1
  193. // Find the maximum logit value among valid tokens
  194. for tokenID := range node.MaskTokenIDToNode {
  195. if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
  196. maxLogit = logits[tokenID]
  197. maxIndex = int(tokenID)
  198. }
  199. }
  200. if maxIndex == -1 {
  201. return nil, fmt.Errorf("no valid tokens found in mask")
  202. }
  203. logits[0] = float64(maxIndex)
  204. return logits, nil
  205. // return maxIndex, nil
  206. }