pushdown_runner.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "runtime"
  6. "time"
  7. "github.com/ollama/ollama/model"
  8. )
  9. // TODO: safety in case of invalid json
  10. // TODO: interfaces to cleanup with return values
  11. type PushdownSampler struct {
  12. // stateful
  13. curNode *PDANode
  14. proc model.TextProcessor
  15. stateToNodeMap map[JSONState]*PDANode
  16. braceStack []rune
  17. stateCounter uint32
  18. }
  19. // graph should be built once and reused per tokenizer
  20. func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
  21. start := time.Now()
  22. var m runtime.MemStats
  23. runtime.ReadMemStats(&m)
  24. before := m.Alloc
  25. fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
  26. startNode, stateToNodeMap, err := BuildGraph(proc)
  27. if err != nil {
  28. panic(err)
  29. }
  30. err = PreComputeValidStates(stateToNodeMap, proc)
  31. if err != nil {
  32. panic(err)
  33. }
  34. runtime.ReadMemStats(&m)
  35. after := m.Alloc
  36. fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
  37. fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  38. fmt.Printf("Graph build time = %v\n", time.Since(start))
  39. return &PushdownSampler{
  40. curNode: startNode,
  41. proc: proc,
  42. stateToNodeMap: stateToNodeMap,
  43. braceStack: []rune{},
  44. stateCounter: 0,
  45. }
  46. }
  47. // TODO: need to add resampling logic if the first sample was not good
  48. func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
  49. // fmt.Println(">>> sample:", s.curNode.State)
  50. switch s.curNode.State {
  51. case StateInString:
  52. return s.maskLogits(logits, s.curNode)
  53. case StateInListEnd:
  54. fmt.Println("in list end", s.braceStack)
  55. // force finish if no braces left
  56. if len(s.braceStack) == 0 {
  57. s.curNode = NewPDANode(StateTerminate)
  58. for i := range logits {
  59. if s.proc.Is(uint32(i), model.SpecialEOS) {
  60. logits[i] = 1.0
  61. } else {
  62. logits[i] = math.NaN()
  63. }
  64. }
  65. return logits, nil
  66. }
  67. logits, err := s.maskLogits(logits, s.curNode)
  68. if err != nil {
  69. return nil, err
  70. }
  71. return logits, nil
  72. case StateInObjectEnd:
  73. // force finish if no braces left
  74. if len(s.braceStack) == 0 {
  75. s.curNode = NewPDANode(StateTerminate)
  76. for i := range logits {
  77. if s.proc.Is(uint32(i), model.SpecialEOS) {
  78. logits[i] = 1.0
  79. } else {
  80. logits[i] = math.NaN()
  81. }
  82. }
  83. return logits, nil
  84. }
  85. peek := s.braceStack[len(s.braceStack)-1]
  86. if peek == rune('[') {
  87. s.curNode = s.stateToNodeMap[StateInListObjectEnd]
  88. // fmt.Println("switching to list object end", s.curNode.State)
  89. }
  90. logits, err := s.maskLogits(logits, s.curNode)
  91. if err != nil {
  92. return nil, err
  93. }
  94. return logits, nil
  95. case StateInComma:
  96. peek := s.braceStack[len(s.braceStack)-1]
  97. if peek == rune('[') {
  98. s.curNode = s.stateToNodeMap[StateInListComma]
  99. // fmt.Println("switching to list comma", s.curNode.State)
  100. }
  101. logits, err := s.maskLogits(logits, s.curNode)
  102. if err != nil {
  103. return nil, err
  104. }
  105. return logits, nil
  106. case StateTerminate:
  107. for i := range logits {
  108. if s.proc.Is(uint32(i), model.SpecialEOS) {
  109. logits[i] = 1.0
  110. } else {
  111. logits[i] = math.NaN()
  112. }
  113. }
  114. return logits, nil
  115. default:
  116. // fmt.Println("masking logits current state", s.curNode.State)
  117. logits, err := s.maskLogits(logits, s.curNode)
  118. if err != nil {
  119. return nil, err
  120. }
  121. return logits, nil
  122. }
  123. }
  124. func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
  125. fmt.Println("update state", s.curNode.State)
  126. mappedString, err := s.proc.Decode(tokenSlice)
  127. if err != nil {
  128. return err
  129. }
  130. fmt.Println("mappedString", mappedString)
  131. // TODO: should force closing for all braces - not doing square yet
  132. for _, r := range mappedString {
  133. if r == rune('{') {
  134. s.braceStack = append(s.braceStack, r)
  135. // fmt.Println("pushing { brace stack", r)
  136. }
  137. if r == rune('[') {
  138. s.braceStack = append(s.braceStack, r)
  139. // fmt.Println("pushing [ brace stack", r)
  140. }
  141. if r == rune('}') {
  142. if len(s.braceStack) == 0 {
  143. return fmt.Errorf("stack is empty, extra closing brace %c", r)
  144. }
  145. top := s.braceStack[len(s.braceStack)-1]
  146. if top != rune('{') {
  147. return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
  148. }
  149. s.braceStack = s.braceStack[:len(s.braceStack)-1]
  150. // fmt.Println("popping { brace stack", top)
  151. }
  152. if r == rune(']') {
  153. if len(s.braceStack) == 0 {
  154. return fmt.Errorf("stack is empty, extra closing brace %c", r)
  155. }
  156. top := s.braceStack[len(s.braceStack)-1]
  157. if top != rune('[') {
  158. return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
  159. }
  160. s.braceStack = s.braceStack[:len(s.braceStack)-1]
  161. // fmt.Println("popping [ brace stack", top)
  162. }
  163. }
  164. for _, tokenID := range tokenSlice {
  165. // transition to the next node
  166. nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
  167. if !ok {
  168. return fmt.Errorf("invalid token: %q", mappedString)
  169. }
  170. // fmt.Println("transitioning to", nextNodeState)
  171. // TODO: add a penalty for staying in the same state too long
  172. if nextNodeState == s.curNode.State {
  173. s.stateCounter++
  174. } else {
  175. s.stateCounter = 0
  176. }
  177. s.curNode = s.stateToNodeMap[nextNodeState]
  178. }
  179. return nil
  180. }
  181. func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
  182. // TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
  183. // Should be possible through bitwise ops as well
  184. for i := range logits {
  185. _, exists := node.MaskTokenIDToNode[int32(i)]
  186. if !exists {
  187. logits[i] = math.NaN()
  188. }
  189. }
  190. return logits, nil
  191. }
  192. // TODO: add penalties for string \n stuff