pushdown_automata.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. package sample
import (
	"fmt"
	"slices"

	"github.com/ollama/ollama/model"
)
  var (
	// stringInvalidRunes are structural/control runes. isRuneValid rejects them
	// inside a string value or object key, and allows each of them to be
	// consumed as a graph transition at most once per token.
	//
	// TODO: '/' is legal unescaped inside JSON strings ("\/" is only an optional
	// escape), so it probably should not be rejected there.
	stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}

	// intInvalidRunes lists runes that cannot appear inside an integer literal.
	intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}

	// validIntRunes are the runes an integer literal may contain.
	validIntRunes = []rune("0123456789-")

	// validNumberRunes are the runes a JSON number may contain (digits, sign,
	// decimal point, exponent marker).
	validNumberRunes = []rune("0123456789.-+eE")

	// validBoolRunes are the letters of "true" followed by "false"; 'e'
	// therefore appears twice.
	validBoolRunes = []rune("truefalse")

	// validNullRunes are the letters of "null".
	validNullRunes = []rune("null")
)
  13. type PDANode struct {
  14. State JSONState
  15. TransitionEdges map[rune]*PDANode
  16. MaskTokenIDToNode map[int32]JSONState
  17. }
  18. func NewPDANode(state JSONState) *PDANode {
  19. return &PDANode{
  20. State: state,
  21. TransitionEdges: make(map[rune]*PDANode),
  22. MaskTokenIDToNode: make(map[int32]JSONState),
  23. }
  24. }
  25. func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
  26. stateToNodeMap := make(map[JSONState]*PDANode)
  27. // TODO: make this a loop
  28. for _, state := range JSONStates {
  29. stateToNodeMap[state] = NewPDANode(state)
  30. }
  31. // TODO:
  32. // consider adding a node to just point to values, could be good to compute that
  33. // mask rather than many different nodes
  34. stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  35. stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
  36. stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
  37. stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
  38. stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
  39. //new line
  40. stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
  41. stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
  42. stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
  43. stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey]
  44. stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd]
  45. stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon]
  46. stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
  47. stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
  48. // where values should be
  49. // this could be combined but the probl might change, we're alr doing a skip ahead
  50. stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
  51. stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
  52. stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  53. addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
  54. // Leads to a value
  55. stateToNodeMap[StateInSpace].TransitionEdges['['] = stateToNodeMap[StateInList]
  56. stateToNodeMap[StateInSpace].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  57. addValueConnections(stateToNodeMap[StateInSpace], stateToNodeMap)
  58. // Values
  59. // string node
  60. stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString]
  61. stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
  62. // String end node
  63. addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
  64. // TODO: add counters for allowable number of decimals, e, E, etc
  65. // number node
  66. for _, r := range validNumberRunes {
  67. stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
  68. }
  69. addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
  70. // bool node
  71. for _, r := range validBoolRunes {
  72. stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
  73. }
  74. addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
  75. // list node
  76. stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
  77. stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  78. stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
  79. stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
  80. addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
  81. // null node
  82. for _, r := range validNullRunes {
  83. stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
  84. }
  85. addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
  86. // list comma
  87. // should point to values
  88. stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
  89. stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  90. stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
  91. addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
  92. // list object end
  93. stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
  94. stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
  95. // bool node
  96. for _, r := range validBoolRunes {
  97. stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
  98. }
  99. addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
  100. stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
  101. stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
  102. stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  103. stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
  104. stateToNodeMap[StateInComma].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
  105. stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
  106. stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
  107. stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
  108. stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
  109. return stateToNodeMap[StateStart], stateToNodeMap, nil
  110. }
  111. func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
  112. node.TransitionEdges[','] = stateToNodeMap[StateInComma]
  113. node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
  114. node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
  115. }
  116. func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
  117. node.TransitionEdges['"'] = stateToNodeMap[StateInString]
  118. for _, r := range validNumberRunes {
  119. node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
  120. }
  121. node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
  122. node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
  123. node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
  124. }
  125. // TODO: tough life fr. plz fix.
  126. func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
  127. vocab := proc.GetVocabulary()
  128. decodedToks := make([]string, len(vocab.Values))
  129. for i := range vocab.Values {
  130. token, err := proc.Decode([]int32{int32(i)})
  131. if err != nil {
  132. return err
  133. }
  134. decodedToks[i] = token
  135. }
  136. var err error
  137. for _, node := range stateToNodeMap {
  138. err = createMask(node, proc, decodedToks, vocab)
  139. if err != nil {
  140. return err
  141. }
  142. }
  143. return nil
  144. }
  145. func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
  146. for i := range vocab.Values {
  147. token := decodedToks[i]
  148. // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
  149. if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
  150. continue
  151. }
  152. valid := true
  153. curNode := node
  154. consumedSpecialRunes := make(map[rune]bool)
  155. var err error
  156. for _, r := range token {
  157. valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
  158. if err != nil {
  159. return err
  160. }
  161. if !valid {
  162. break
  163. }
  164. }
  165. if valid {
  166. node.MaskTokenIDToNode[int32(i)] = curNode.State
  167. }
  168. }
  169. return nil
  170. }
  // isRuneValid decides whether rune r may be consumed from curNode. If it may,
// it also returns the node the automaton moves to.
//
// The rules are applied in this order:
//  1. A rune already recorded in consumedSpecialRunes is rejected.
//  2. A rune from stringInvalidRunes is rejected while in StateInString or
//     StateInObjectKey.
//  3. If curNode has an edge for r, that edge is taken. When r is in
//     stringInvalidRunes, the edge must lead to a different state (a self-loop
//     is rejected), and r is then recorded in consumedSpecialRunes.
//  4. Otherwise the wildcard edge rune(-1) is taken, if present.
//  5. Otherwise r is rejected.
//
// TODO: simplify the signature. The error result is always nil, and a
// rejected rune is always reported as (false, nil).
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
	if consumedSpecialRunes[r] {
		return false, nil, nil
	}

	isSpecial := slices.Contains(stringInvalidRunes, r)
	if isSpecial && (curNode.State == StateInString || curNode.State == StateInObjectKey) {
		return false, nil, nil
	}

	next, hasEdge := curNode.TransitionEdges[r]
	if !hasEdge {
		// No edge for this exact rune: fall back to the wildcard, if any.
		if wildcard, ok := curNode.TransitionEdges[rune(-1)]; ok {
			return true, wildcard, nil
		}
		return false, nil, nil
	}

	if isSpecial {
		if next.State == curNode.State {
			return false, nil, nil
		}
		consumedSpecialRunes[r] = true
	}
	return true, next, nil
}