pushdown_runner.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "runtime"
  6. "time"
  7. "github.com/ollama/ollama/model"
  8. )
  9. type PushdownSampler struct {
  10. // stateful
  11. curNode *PDANode
  12. proc model.TextProcessor
  13. stateToNodeMap map[JSONState]*PDANode
  14. braceStack []rune
  15. stateCounter uint32
  16. }
  17. func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
  18. start := time.Now()
  19. var m runtime.MemStats
  20. runtime.ReadMemStats(&m)
  21. before := m.Alloc
  22. fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
  23. startNode, stateToNodeMap, err := BuildGraph(proc)
  24. if err != nil {
  25. panic(err)
  26. }
  27. err = PreComputeValidStates(stateToNodeMap, proc)
  28. if err != nil {
  29. panic(err)
  30. }
  31. runtime.ReadMemStats(&m)
  32. after := m.Alloc
  33. fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
  34. fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  35. fmt.Printf("Graph build time = %v\n", time.Since(start))
  36. // for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
  37. // token, err := proc.Decode([]int32{int32(id)})
  38. // if err != nil {
  39. // panic(err)
  40. // }
  41. // fmt.Println("id", id, "node", node, "token", token)
  42. // }
  43. // time.Sleep(10 * time.Second)
  44. return &PushdownSampler{
  45. curNode: startNode,
  46. proc: proc,
  47. stateToNodeMap: stateToNodeMap,
  48. braceStack: []rune{},
  49. stateCounter: 0,
  50. }
  51. }
  52. func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
  53. fmt.Println("sample:", s.curNode.State)
  54. switch s.curNode.State {
  55. case StateInObjectEnd:
  56. // force finish if no braces left
  57. if len(s.braceStack) == 0 {
  58. s.curNode = NewPDANode(StateTerminate)
  59. for i := range logits {
  60. if s.proc.Is(uint32(i), model.SpecialEOS) {
  61. logits[i] = 1.0
  62. } else {
  63. logits[i] = math.NaN()
  64. }
  65. }
  66. return logits, nil
  67. }
  68. valid, err := s.proc.Encode("}")
  69. if err != nil {
  70. return nil, err
  71. }
  72. for i := range logits {
  73. for _, token := range valid {
  74. if i != int(token) {
  75. logits[i] = math.NaN()
  76. }
  77. }
  78. }
  79. return logits, nil
  80. case StateInComma:
  81. peek := s.braceStack[len(s.braceStack)-1]
  82. if peek == rune('[') {
  83. s.curNode = s.stateToNodeMap[StateInListComma]
  84. fmt.Println("switching to list comma", s.curNode.State)
  85. }
  86. logits, err := s.maskLogits(logits, s.curNode)
  87. if err != nil {
  88. return nil, err
  89. }
  90. return logits, nil
  91. case StateTerminate:
  92. for i := range logits {
  93. if s.proc.Is(uint32(i), model.SpecialEOS) {
  94. logits[i] = 1.0
  95. } else {
  96. logits[i] = math.NaN()
  97. }
  98. }
  99. return logits, nil
  100. default:
  101. fmt.Println("masking logits current state", s.curNode.State)
  102. logits, err := s.maskLogits(logits, s.curNode)
  103. if err != nil {
  104. return nil, err
  105. }
  106. return logits, nil
  107. }
  108. }
  109. func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
  110. fmt.Println("update state", s.curNode.State)
  111. // TODO: need to handle end states and entering object case, and list case
  112. if s.curNode.State == StateInObjectEnd {
  113. fmt.Println("in object end")
  114. if len(s.braceStack) > 0 {
  115. s.braceStack = s.braceStack[:len(s.braceStack)-1]
  116. return nil
  117. }
  118. s.curNode = NewPDANode(StateTerminate)
  119. // TODO: return here?
  120. }
  121. // need this cause there could be multiple transitions
  122. mappedString, err := s.proc.Decode(tokenSlice)
  123. if err != nil {
  124. return err
  125. }
  126. // TODO: should force closing for all braces
  127. for _, r := range mappedString {
  128. if r == rune('{') {
  129. s.braceStack = append(s.braceStack, r)
  130. }
  131. if r == rune('[') {
  132. s.braceStack = append(s.braceStack, r)
  133. }
  134. if r == rune('}') {
  135. if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
  136. return fmt.Errorf("unmatched closing brace")
  137. }
  138. s.braceStack = s.braceStack[:len(s.braceStack)-1]
  139. fmt.Println("popping brace stack", s.braceStack)
  140. }
  141. if r == rune(']') {
  142. if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') {
  143. return fmt.Errorf("unmatched closing brace")
  144. }
  145. s.braceStack = s.braceStack[:len(s.braceStack)-1]
  146. fmt.Println("popping brace stack", s.braceStack)
  147. }
  148. }
  149. for _, tokenID := range tokenSlice {
  150. // transition to the next node
  151. nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
  152. if !ok {
  153. return fmt.Errorf("invalid token: %q", mappedString)
  154. }
  155. fmt.Println("transitioning to", nextNodeState)
  156. // TODO: add a penalty for staying in the same state too long
  157. if nextNodeState == s.curNode.State {
  158. s.stateCounter++
  159. } else {
  160. s.stateCounter = 0
  161. }
  162. s.curNode = s.stateToNodeMap[nextNodeState]
  163. }
  164. return nil
  165. }
  166. func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
  167. for i := range logits {
  168. _, exists := node.MaskTokenIDToNode[int32(i)]
  169. if !exists {
  170. logits[i] = math.NaN()
  171. }
  172. }
  173. return logits, nil
  174. }
  175. // TODO: add penalties for string \n stuff