// pushdown_automata.go (4.6 KB)

  1. package sample
  2. import (
  3. "slices"
  4. "github.com/ollama/ollama/model"
  5. )
  6. var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
  7. type PDANode struct {
  8. State JSONState
  9. TransitionEdges map[rune]*PDANode
  10. MaskTokenIDToNode map[int32]JSONState
  11. }
  12. func NewPDANode(state JSONState) *PDANode {
  13. return &PDANode{
  14. State: state,
  15. TransitionEdges: make(map[rune]*PDANode),
  16. MaskTokenIDToNode: make(map[int32]JSONState),
  17. }
  18. }
  19. func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
  20. stateToNodeMap := make(map[JSONState]*PDANode)
  21. startNode := NewPDANode(StateStart)
  22. stateToNodeMap[StateStart] = startNode
  23. objNode := NewPDANode(StateInObject)
  24. stateToNodeMap[StateInObject] = objNode
  25. objEndNode := NewPDANode(StateInObjectEnd)
  26. stateToNodeMap[StateInObjectEnd] = objEndNode
  27. objKeyNode := NewPDANode(StateInObjectKey)
  28. stateToNodeMap[StateInObjectKey] = objKeyNode
  29. objKeyEndNode := NewPDANode(StateInObjectKeyEnd)
  30. stateToNodeMap[StateInObjectKeyEnd] = objKeyEndNode
  31. colonNode := NewPDANode(StateInColon)
  32. stateToNodeMap[StateInColon] = colonNode
  33. commaNode := NewPDANode(StateInComma)
  34. stateToNodeMap[StateInComma] = commaNode
  35. newlineNode := NewPDANode(StateInNewline)
  36. stateToNodeMap[StateInNewline] = newlineNode
  37. spaceNode := NewPDANode(StateInSpace)
  38. stateToNodeMap[StateInSpace] = spaceNode
  39. tabNode := NewPDANode(StateInTab)
  40. stateToNodeMap[StateInTab] = tabNode
  41. stringNode := NewPDANode(StateInString)
  42. stateToNodeMap[StateInString] = stringNode
  43. stringEndNode := NewPDANode(StateInStringEnd)
  44. stateToNodeMap[StateInStringEnd] = stringEndNode
  45. // terminateNode := NewNode(StateTerminate)
  46. // Connect nodes
  47. // TODO: if all are single tokens then this can just be connected instead of defining the token
  48. startNode.TransitionEdges['{'] = objNode
  49. objNode.TransitionEdges['"'] = objKeyNode
  50. objNode.TransitionEdges['\n'] = newlineNode
  51. newlineNode.TransitionEdges['"'] = objKeyNode
  52. newlineNode.TransitionEdges['\t'] = tabNode
  53. tabNode.TransitionEdges['"'] = objKeyNode
  54. spaceNode.TransitionEdges['"'] = stringNode
  55. objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
  56. objKeyNode.TransitionEdges['"'] = objKeyEndNode
  57. objKeyNode.TransitionEdges[' '] = spaceNode
  58. // objKeyNode.TransitionEdges['\t'] = tabNode
  59. objKeyEndNode.TransitionEdges[':'] = colonNode
  60. colonNode.TransitionEdges['"'] = stringNode
  61. colonNode.TransitionEdges[' '] = spaceNode
  62. stringNode.TransitionEdges[rune(-1)] = stringNode
  63. stringNode.TransitionEdges['"'] = stringEndNode
  64. stringEndNode.TransitionEdges[','] = commaNode
  65. stringEndNode.TransitionEdges['}'] = objEndNode
  66. commaNode.TransitionEdges['{'] = objNode
  67. commaNode.TransitionEdges['\n'] = newlineNode
  68. commaNode.TransitionEdges['\t'] = tabNode
  69. commaNode.TransitionEdges['"'] = objKeyNode
  70. return startNode, stateToNodeMap, nil
  71. }
  72. func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
  73. vocab := proc.GetVocabulary()
  74. decodedToks := make([]string, len(vocab.Values))
  75. for i := range vocab.Values {
  76. token, err := proc.Decode([]int32{int32(i)})
  77. if err != nil {
  78. return err
  79. }
  80. decodedToks[i] = token
  81. }
  82. var err error
  83. for _, node := range stateToNodeMap {
  84. for i := range vocab.Values {
  85. token := decodedToks[i]
  86. // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
  87. if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
  88. continue
  89. }
  90. valid := true
  91. curNode := node
  92. consumedSpecialRunes := make(map[rune]bool)
  93. for _, r := range token {
  94. valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
  95. if err != nil {
  96. return err
  97. }
  98. if !valid {
  99. break
  100. }
  101. }
  102. if valid {
  103. node.MaskTokenIDToNode[int32(i)] = curNode.State
  104. }
  105. }
  106. }
  107. return nil
  108. }
  109. func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
  110. if consumedSpecialRunes[r] {
  111. return false, nil, nil
  112. }
  113. specialRune := slices.Contains(stringInvalidRunes, r)
  114. if specialRune {
  115. if curNode.State == StateInString || curNode.State == StateInObjectKey {
  116. return false, nil, nil
  117. }
  118. }
  119. // Check for specific rune transition
  120. if nextNode, ok := curNode.TransitionEdges[r]; ok {
  121. if specialRune {
  122. if curNode.State == nextNode.State {
  123. return false, nil, nil
  124. }
  125. // fmt.Println("special rune", r, "consumed")
  126. consumedSpecialRunes[r] = true
  127. }
  128. return true, nextNode, nil
  129. }
  130. // Check for sentinel value - if present, any rune is valid
  131. if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
  132. return true, nextNode, nil
  133. }
  134. return false, nil, nil
  135. }