pushdown_runner.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "runtime"
  6. "time"
  7. "github.com/ollama/ollama/model"
  8. )
  9. // TODO: safety in case of invalid json
  10. // TODO: partial JSON matching?
  11. // TODO: interfaces to cleanup with return values
  12. // TODO this interface shouldn't be the sampler - should just use Sampler
  13. // TODO: add penalties for string \n stuff
  14. // TODO: minimize number of fwd passes if there is only one match
  15. // TODO: greedy sample initially and then backtrack if no match
  16. type PushdownSampler struct {
  17. PDAGraphBuilder
  18. curNode *PDA
  19. braceStack []rune
  20. stateCounter uint32
  21. }
  22. // graph should be built once and reused per tokenizer
  23. func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
  24. start := time.Now()
  25. fmt.Println("--------------------------------")
  26. fmt.Println("PDA sampler")
  27. fmt.Println("--------------------------------")
  28. var m runtime.MemStats
  29. runtime.ReadMemStats(&m)
  30. before := m.Alloc
  31. fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
  32. vocab := proc.Vocab()
  33. decodedToks := make([]string, len(vocab.Values))
  34. for i := range vocab.Values {
  35. token, err := proc.Decode([]int32{int32(i)})
  36. if err != nil {
  37. return nil, err
  38. }
  39. decodedToks[i] = token
  40. }
  41. gb := &PDAGraphBuilder{
  42. proc: proc,
  43. decodedToks: decodedToks,
  44. }
  45. if err := gb.BuildGraph(); err != nil {
  46. return nil, err
  47. }
  48. runtime.ReadMemStats(&m)
  49. after := m.Alloc
  50. fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
  51. fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  52. fmt.Printf("Graph build time = %v\n", time.Since(start))
  53. // TODO: this can be simplified
  54. return &PushdownSampler{
  55. curNode: gb.stateToNodeMap[StateStart],
  56. PDAGraphBuilder: *gb,
  57. braceStack: []rune{},
  58. stateCounter: 0,
  59. }, nil
  60. }
  61. // TODO: need to add resampling logic if the first sample was not good
  62. // greedy sample + backtrack?
  63. func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) {
  64. switch s.curNode.State {
  65. case StateInString:
  66. return s.maskLogits(logits, s.curNode)
  67. case StateInListEnd:
  68. // force finish if no braces left
  69. if len(s.braceStack) == 0 {
  70. s.curNode = NewPDANode(StateTerminate)
  71. return forceFinish(s, logits)
  72. }
  73. logits, err := s.maskLogits(logits, s.curNode)
  74. if err != nil {
  75. return nil, err
  76. }
  77. return logits, nil
  78. case StateTerminate:
  79. return forceFinish(s, logits)
  80. case StateInObjectEnd:
  81. // force finish if no braces left
  82. if len(s.braceStack) == 0 {
  83. s.curNode = NewPDANode(StateTerminate)
  84. return forceFinish(s, logits)
  85. }
  86. peek := s.braceStack[len(s.braceStack)-1]
  87. if peek == rune('[') {
  88. s.curNode = s.stateToNodeMap[StateInListObjectEnd]
  89. }
  90. logits, err := s.maskLogits(logits, s.curNode)
  91. if err != nil {
  92. return nil, err
  93. }
  94. return logits, nil
  95. case StateInComma:
  96. peek := s.braceStack[len(s.braceStack)-1]
  97. if peek == rune('[') {
  98. s.curNode = s.stateToNodeMap[StateInListComma]
  99. }
  100. logits, err := s.maskLogits(logits, s.curNode)
  101. if err != nil {
  102. return nil, err
  103. }
  104. return logits, nil
  105. default:
  106. fmt.Println("masking logits current state", s.curNode.State)
  107. logits, err := s.maskLogits(logits, s.curNode)
  108. if err != nil {
  109. return nil, err
  110. }
  111. return logits, nil
  112. }
  113. }
  114. func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) {
  115. for i := range logits {
  116. if s.proc.Is(int32(i), model.SpecialEOS) {
  117. logits[i] = 1.0
  118. } else {
  119. logits[i] = float32(math.Inf(-1))
  120. }
  121. }
  122. return logits, nil
  123. }
  124. func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
  125. fmt.Println("current state - updating", s.curNode.State)
  126. mappedString, err := s.proc.Decode(tokenSlice)
  127. if err != nil {
  128. return nil, err
  129. }
  130. fmt.Printf(">>> mappedString: %q\n", mappedString)
  131. // Special handling for EOS token in terminate state
  132. if s.curNode.State == StateTerminate {
  133. for _, tokenID := range tokenSlice {
  134. if s.proc.Is(tokenID, model.SpecialEOS) {
  135. return tokenSlice, nil
  136. }
  137. }
  138. }
  139. // flag := -1
  140. // endBraceRunes := []rune{'}', ']'}
  141. for _, r := range mappedString {
  142. // TODO: if this is enabled again, make sure to appropriately handle the state transitions
  143. // if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
  144. // fmt.Printf("stack is empty, extra closing brace %c\n", r)
  145. // // flag = i
  146. // break
  147. // }
  148. if r == rune('{') {
  149. s.braceStack = append(s.braceStack, r)
  150. }
  151. if r == rune('[') {
  152. s.braceStack = append(s.braceStack, r)
  153. }
  154. if r == rune('}') {
  155. if len(s.braceStack) == 0 {
  156. return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
  157. }
  158. top := s.braceStack[len(s.braceStack)-1]
  159. if top != rune('{') {
  160. return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
  161. }
  162. s.braceStack = s.braceStack[:len(s.braceStack)-1]
  163. }
  164. if r == rune(']') {
  165. if len(s.braceStack) == 0 {
  166. return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
  167. }
  168. top := s.braceStack[len(s.braceStack)-1]
  169. if top != rune('[') {
  170. return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
  171. }
  172. s.braceStack = s.braceStack[:len(s.braceStack)-1]
  173. }
  174. }
  175. // if flag != -1 {
  176. // tokenSlice = tokenSlice[:flag]
  177. // }
  178. // fmt.Println("flag!", flag)
  179. for _, tokenID := range tokenSlice {
  180. // transition to the next node
  181. nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
  182. if !ok {
  183. return nil, fmt.Errorf("invalid token: %q", mappedString)
  184. }
  185. fmt.Println("transitioning to", nextNode.State)
  186. // TODO: add a penalty for staying in the same state too long
  187. if nextNode.State == s.curNode.State {
  188. s.stateCounter++
  189. } else {
  190. s.stateCounter = 0
  191. }
  192. s.curNode = nextNode
  193. fmt.Println("updated curNode state", s.curNode.State)
  194. }
  195. return tokenSlice, nil
  196. }
  197. // greedy sample + backtrack?
  198. func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) {
  199. // Create a new slice with same length as logits, initialized to -Inf
  200. maskedLogits := make([]float32, len(logits))
  201. for i := range maskedLogits {
  202. maskedLogits[i] = float32(math.Inf(-1))
  203. }
  204. // Only update values for valid token IDs from the mask map
  205. for tokenID := range node.MaskTokenIDToNode {
  206. if int(tokenID) < len(logits) {
  207. maskedLogits[tokenID] = logits[tokenID]
  208. }
  209. }
  210. return maskedLogits, nil
  211. }
  212. func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) {
  213. maxLogit := float32(math.Inf(-1))
  214. maxIndex := -1
  215. // Find the maximum logit value among valid tokens
  216. for tokenID := range node.MaskTokenIDToNode {
  217. if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
  218. maxLogit = logits[tokenID]
  219. maxIndex = int(tokenID)
  220. }
  221. }
  222. if maxIndex == -1 {
  223. return nil, fmt.Errorf("no valid tokens found in mask")
  224. }
  225. logits[0] = float32(maxIndex)
  226. return logits, nil
  227. // return maxIndex, nil
  228. }