// pushdown_automata.go (8.7 KB)

  1. package sample
import (
	"fmt"
	"slices"

	"github.com/ollama/ollama/model"
)
  6. // TODO: / should be valid but an escape character
  7. var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
  8. var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
  9. var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
  10. var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
  11. var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
  12. var validNullRunes = []rune{'n', 'u', 'l', 'l'}
  13. type PDANode struct {
  14. State JSONState
  15. TransitionEdges map[rune]*PDANode
  16. MaskTokenIDToNode map[int32]*PDANode
  17. }
  18. func NewPDANode(state JSONState) *PDANode {
  19. return &PDANode{
  20. State: state,
  21. TransitionEdges: make(map[rune]*PDANode),
  22. MaskTokenIDToNode: make(map[int32]*PDANode),
  23. }
  24. }
  25. func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
  26. stateToNodeMap := make(map[JSONState]*PDANode)
  27. // TODO: make this a loop
  28. for _, state := range JSONStates {
  29. stateToNodeMap[state] = NewPDANode(state)
  30. }
  31. // TODO:
  32. // consider adding a node to just point to values, could be good to compute that
  33. // mask rather than many different nodes
  34. stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  35. stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
  36. stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
  37. stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
  38. stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
  39. //new line
  40. stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
  41. stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
  42. stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
  43. stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey]
  44. stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd]
  45. stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon]
  46. stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
  47. stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
  48. // where values should be
  49. // this could be combined but the probl might change, we're alr doing a skip ahead
  50. stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
  51. stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
  52. stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  53. addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
  54. // Leads to a value
  55. stateToNodeMap[StateInSpace].TransitionEdges['['] = stateToNodeMap[StateInList]
  56. stateToNodeMap[StateInSpace].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  57. addValueConnections(stateToNodeMap[StateInSpace], stateToNodeMap)
  58. // Values
  59. // string node
  60. stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString]
  61. stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
  62. // String end node
  63. addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
  64. // TODO: add counters for allowable number of decimals, e, E, etc
  65. // number node
  66. for _, r := range validNumberRunes {
  67. stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
  68. }
  69. addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
  70. // bool node
  71. for _, r := range validBoolRunes {
  72. stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
  73. }
  74. addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
  75. // list node
  76. stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
  77. stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  78. stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
  79. stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
  80. // empty list
  81. stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
  82. addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
  83. // null node
  84. for _, r := range validNullRunes {
  85. stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
  86. }
  87. addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
  88. // list comma
  89. // should point to values
  90. stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
  91. stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  92. stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
  93. addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
  94. // list object end
  95. stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
  96. stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
  97. // bool node
  98. for _, r := range validBoolRunes {
  99. stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
  100. }
  101. addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
  102. stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
  103. stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
  104. stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
  105. stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
  106. stateToNodeMap[StateInComma].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
  107. stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
  108. stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
  109. stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
  110. stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
  111. return stateToNodeMap[StateStart], stateToNodeMap, nil
  112. }
  113. func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
  114. node.TransitionEdges[','] = stateToNodeMap[StateInComma]
  115. node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
  116. node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
  117. }
  118. func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
  119. node.TransitionEdges['"'] = stateToNodeMap[StateInString]
  120. for _, r := range validNumberRunes {
  121. node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
  122. }
  123. node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
  124. node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
  125. node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
  126. }
  127. // TODO: tough life fr. plz fix.
  128. func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
  129. // TODO; should come from top level
  130. vocab := proc.GetVocabulary()
  131. decodedToks := make([]string, len(vocab.Values))
  132. for i := range vocab.Values {
  133. token, err := proc.Decode([]int32{int32(i)})
  134. if err != nil {
  135. return err
  136. }
  137. decodedToks[i] = token
  138. }
  139. var err error
  140. for _, node := range stateToNodeMap {
  141. err = CreateMask(node, proc, decodedToks, vocab)
  142. if err != nil {
  143. return err
  144. }
  145. }
  146. return nil
  147. }
  148. func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
  149. for i := range vocab.Values {
  150. token := decodedToks[i]
  151. // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
  152. if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
  153. continue
  154. }
  155. valid := true
  156. curNode := node
  157. consumedSpecialRunes := make(map[rune]bool)
  158. var err error
  159. for _, r := range token {
  160. valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
  161. if err != nil {
  162. return err
  163. }
  164. if !valid {
  165. break
  166. }
  167. }
  168. if valid {
  169. // cur node allows skipping
  170. node.MaskTokenIDToNode[int32(i)] = curNode
  171. }
  172. }
  173. return nil
  174. }
  175. // TODO: garbage interface plz fix
  176. func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
  177. if consumedSpecialRunes[r] {
  178. return false, nil, nil
  179. }
  180. specialRune := slices.Contains(stringInvalidRunes, r)
  181. if specialRune {
  182. if curNode.State == StateInString || curNode.State == StateInObjectKey {
  183. return false, nil, nil
  184. }
  185. }
  186. // Check for specific rune transition
  187. if nextNode, ok := curNode.TransitionEdges[r]; ok {
  188. if specialRune {
  189. if curNode.State == nextNode.State {
  190. return false, nil, nil
  191. }
  192. consumedSpecialRunes[r] = true
  193. }
  194. return true, nextNode, nil
  195. }
  196. // Check for sentinel value - if present, any rune is valid
  197. if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
  198. return true, nextNode, nil
  199. }
  200. return false, nil, nil
  201. }