pushdown_runner.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "runtime"
  6. "github.com/ollama/ollama/model"
  7. )
  8. // TODO: safety in case of invalid json
  9. // TODO: interfaces to cleanup with return values
  10. type PushdownSampler struct {
  11. // stateful
  12. curNode *PDANode
  13. proc model.TextProcessor
  14. stateToNodeMap map[JSONState]*PDANode
  15. braceStack []rune
  16. stateCounter uint32
  17. }
  18. // graph should be built once and reused per tokenizer
  19. func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
  20. // start := time.Now()
  21. // fmt.Println("--------------------------------")
  22. // fmt.Println("PDA sampler")
  23. // fmt.Println("--------------------------------")
  24. var m runtime.MemStats
  25. runtime.ReadMemStats(&m)
  26. // before := m.Alloc
  27. // fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
  28. startNode, stateToNodeMap, err := BuildGraph(proc)
  29. if err != nil {
  30. panic(err)
  31. }
  32. err = PreComputeValidStates(stateToNodeMap, proc)
  33. if err != nil {
  34. panic(err)
  35. }
  36. runtime.ReadMemStats(&m)
  37. // after := m.Alloc
  38. // fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
  39. // fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  40. // fmt.Printf("Graph build time = %v\n", time.Since(start))
  41. return &PushdownSampler{
  42. curNode: startNode,
  43. proc: proc,
  44. stateToNodeMap: stateToNodeMap,
  45. braceStack: []rune{},
  46. stateCounter: 0,
  47. }
  48. }
  49. // TODO: need to add resampling logic if the first sample was not good
  50. // greedy sample + backtrack?
  51. func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
  52. // fmt.Println(">>> sample:", s.curNode.State)
  53. switch s.curNode.State {
  54. case StateInString:
  55. return s.maskLogits(logits, s.curNode)
  56. case StateInListEnd:
  57. // fmt.Println("in list end", s.braceStack)
  58. // force finish if no braces left
  59. if len(s.braceStack) == 0 {
  60. s.curNode = NewPDANode(StateTerminate)
  61. for i := range logits {
  62. if s.proc.Is(uint32(i), model.SpecialEOS) {
  63. logits[i] = 1.0
  64. } else {
  65. logits[i] = math.NaN()
  66. }
  67. }
  68. return logits, nil
  69. }
  70. logits, err := s.maskLogits(logits, s.curNode)
  71. if err != nil {
  72. return nil, err
  73. }
  74. return logits, nil
  75. case StateInObjectEnd:
  76. // force finish if no braces left
  77. if len(s.braceStack) == 0 {
  78. s.curNode = NewPDANode(StateTerminate)
  79. for i := range logits {
  80. if s.proc.Is(uint32(i), model.SpecialEOS) {
  81. logits[i] = 1.0
  82. } else {
  83. logits[i] = math.NaN()
  84. }
  85. }
  86. return logits, nil
  87. }
  88. peek := s.braceStack[len(s.braceStack)-1]
  89. if peek == rune('[') {
  90. s.curNode = s.stateToNodeMap[StateInListObjectEnd]
  91. // fmt.Println("switching to list object end", s.curNode.State)
  92. }
  93. logits, err := s.maskLogits(logits, s.curNode)
  94. if err != nil {
  95. return nil, err
  96. }
  97. return logits, nil
  98. case StateInComma:
  99. peek := s.braceStack[len(s.braceStack)-1]
  100. if peek == rune('[') {
  101. s.curNode = s.stateToNodeMap[StateInListComma]
  102. // fmt.Println("switching to list comma", s.curNode.State)
  103. }
  104. logits, err := s.maskLogits(logits, s.curNode)
  105. if err != nil {
  106. return nil, err
  107. }
  108. return logits, nil
  109. case StateTerminate:
  110. for i := range logits {
  111. if s.proc.Is(uint32(i), model.SpecialEOS) {
  112. logits[i] = 1.0
  113. } else {
  114. logits[i] = math.NaN()
  115. }
  116. }
  117. return logits, nil
  118. default:
  119. // fmt.Println("masking logits current state", s.curNode.State)
  120. logits, err := s.maskLogits(logits, s.curNode)
  121. if err != nil {
  122. return nil, err
  123. }
  124. return logits, nil
  125. }
  126. }
  127. func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
  128. // fmt.Println("current state - updating", s.curNode.State)
  129. mappedString, err := s.proc.Decode(tokenSlice)
  130. if err != nil {
  131. return err
  132. }
  133. // fmt.Println("mappedString", mappedString)
  134. // TODO: should force closing for all braces - not doing square yet
  135. for _, r := range mappedString {
  136. if r == rune('{') {
  137. s.braceStack = append(s.braceStack, r)
  138. // fmt.Println("pushing { brace stack", r)
  139. }
  140. if r == rune('[') {
  141. s.braceStack = append(s.braceStack, r)
  142. // fmt.Println("pushing [ brace stack", r)
  143. }
  144. if r == rune('}') {
  145. if len(s.braceStack) == 0 {
  146. return fmt.Errorf("stack is empty, extra closing brace %c", r)
  147. }
  148. top := s.braceStack[len(s.braceStack)-1]
  149. if top != rune('{') {
  150. return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
  151. }
  152. s.braceStack = s.braceStack[:len(s.braceStack)-1]
  153. // fmt.Println("popping { brace stack", top)
  154. }
  155. if r == rune(']') {
  156. if len(s.braceStack) == 0 {
  157. return fmt.Errorf("stack is empty, extra closing brace %c", r)
  158. }
  159. top := s.braceStack[len(s.braceStack)-1]
  160. if top != rune('[') {
  161. return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
  162. }
  163. s.braceStack = s.braceStack[:len(s.braceStack)-1]
  164. // fmt.Println("popping [ brace stack", top)
  165. }
  166. }
  167. for _, tokenID := range tokenSlice {
  168. // transition to the next node
  169. nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
  170. if !ok {
  171. return fmt.Errorf("invalid token: %q", mappedString)
  172. }
  173. // fmt.Println("transitioning to", nextNodeState)
  174. // TODO: add a penalty for staying in the same state too long
  175. if nextNode.State == s.curNode.State {
  176. s.stateCounter++
  177. } else {
  178. s.stateCounter = 0
  179. }
  180. s.curNode = nextNode
  181. // fmt.Println("updated curNode state", s.curNode.State)
  182. }
  183. return nil
  184. }
  185. // greedy sample + backtrack?
  186. func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
  187. // TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
  188. // Should be possible through bitwise ops as well
  189. for i := range logits {
  190. _, exists := node.MaskTokenIDToNode[int32(i)]
  191. if !exists {
  192. logits[i] = math.NaN()
  193. }
  194. }
  195. return logits, nil
  196. }
  197. // TODO: add penalties for string \n stuff