|
@@ -6,8 +6,35 @@ import (
|
|
"github.com/ollama/ollama/model"
|
|
"github.com/ollama/ollama/model"
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+/*
|
|
|
|
+Key JSON rules to consider:
|
|
|
|
+
|
|
|
|
+1. Whitespace handling:
|
|
|
|
+ - Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
|
|
|
|
+ - Current code only handles some whitespace cases
|
|
|
|
+
|
|
|
|
+2. Number validation:
|
|
|
|
+ - Need proper validation for special number cases like -0
|
|
|
|
+ - Should handle .5 style decimals
|
|
|
|
+ - Need limits on scientific notation (e, E)
|
|
|
|
+
|
|
|
|
+3. String escaping:
|
|
|
|
+ - Currently marks \ as invalid but should allow escaped sequences:
|
|
|
|
+ - \"
|
|
|
|
+ - \n
|
|
|
|
+ - \u1234 unicode escapes
|
|
|
|
+
|
|
|
|
+4. Empty object/array transitions:
|
|
|
|
+ - Direct {} and [] cases could be more explicit
|
|
|
|
+ - Need clear transitions for these edge cases
|
|
|
|
+
|
|
|
|
+5. Nested depth limits:
|
|
|
|
+ - No protection against excessive nesting
|
|
|
|
+ - Could cause stack overflow with deeply nested structures
|
|
|
|
+*/
|
|
|
|
+
|
|
// TODO: / should be valid but an escape character
|
|
// TODO: / should be valid but an escape character
|
|
-var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
|
|
|
|
|
+var stringInvalidRunes = []rune{'\n', '\t', '{', '}', ':', ',', '/'}
|
|
|
|
|
|
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
|
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
|
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
|
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
|
@@ -18,31 +45,31 @@ var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
|
|
|
|
|
|
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
|
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
|
|
|
|
|
-type PDANode struct {
|
|
|
|
|
|
+type PDA struct {
|
|
State JSONState
|
|
State JSONState
|
|
- TransitionEdges map[rune]*PDANode
|
|
|
|
- MaskTokenIDToNode map[int32]*PDANode
|
|
|
|
|
|
+ TransitionEdges map[rune]*PDA
|
|
|
|
+ MaskTokenIDToNode map[int32]*PDA
|
|
}
|
|
}
|
|
|
|
|
|
-func NewPDANode(state JSONState) *PDANode {
|
|
|
|
- return &PDANode{
|
|
|
|
|
|
+func NewPDANode(state JSONState) *PDA {
|
|
|
|
+ return &PDA{
|
|
State: state,
|
|
State: state,
|
|
- TransitionEdges: make(map[rune]*PDANode),
|
|
|
|
- MaskTokenIDToNode: make(map[int32]*PDANode),
|
|
|
|
|
|
+ TransitionEdges: make(map[rune]*PDA),
|
|
|
|
+ MaskTokenIDToNode: make(map[int32]*PDA),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
|
|
|
|
- stateToNodeMap := make(map[JSONState]*PDANode)
|
|
|
|
-
|
|
|
|
- // TODO: make this a loop
|
|
|
|
|
|
+type PDAGraphBuilder struct {
|
|
|
|
+ proc model.TextProcessor
|
|
|
|
+ decodedToks []string
|
|
|
|
+ stateToNodeMap map[JSONState]*PDA
|
|
|
|
+}
|
|
|
|
|
|
|
|
+func (b *PDAGraphBuilder) BuildGraph() error {
|
|
|
|
+ stateToNodeMap := make(map[JSONState]*PDA)
|
|
for _, state := range JSONStates {
|
|
for _, state := range JSONStates {
|
|
stateToNodeMap[state] = NewPDANode(state)
|
|
stateToNodeMap[state] = NewPDANode(state)
|
|
}
|
|
}
|
|
- // TODO:
|
|
|
|
- // consider adding a node to just point to values, could be good to compute that
|
|
|
|
- // mask rather than many different nodes
|
|
|
|
|
|
|
|
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
|
|
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
|
|
@@ -51,10 +78,21 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
|
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
|
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
|
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
|
|
|
|
|
- //new line
|
|
|
|
|
|
+ // new line
|
|
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
|
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
|
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
|
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
|
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
|
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
|
|
|
+ stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
|
|
|
+
|
|
|
|
+ // new line end value
|
|
|
|
+ stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
|
|
|
+ stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
|
|
|
+ stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
|
|
|
+
|
|
|
|
+ stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
|
|
|
+ stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
|
|
|
+ // TODO: see if this is needed for formatting
|
|
|
|
+ stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
|
|
|
|
|
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
|
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
|
|
|
|
|
@@ -68,16 +106,16 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|
|
|
|
|
// where values should be
|
|
// where values should be
|
|
// this could be combined but the probl might change, we're alr doing a skip ahead
|
|
// this could be combined but the probl might change, we're alr doing a skip ahead
|
|
- stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
|
|
|
|
|
|
+ stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
|
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
|
|
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
|
|
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
- addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
|
|
|
|
|
|
+ b.addValueConnections(stateToNodeMap[StateInColon])
|
|
|
|
|
|
// Leads to a value
|
|
// Leads to a value
|
|
- stateToNodeMap[StateInSpace].TransitionEdges['['] = stateToNodeMap[StateInList]
|
|
|
|
- stateToNodeMap[StateInSpace].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
|
|
- addValueConnections(stateToNodeMap[StateInSpace], stateToNodeMap)
|
|
|
|
- stateToNodeMap[StateInSpace].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
|
|
|
|
|
+ stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
|
|
|
|
+ stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
|
|
+ b.addValueConnections(stateToNodeMap[StateInSpaceToValue])
|
|
|
|
+ stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
|
|
|
|
|
// Values
|
|
// Values
|
|
// string node
|
|
// string node
|
|
@@ -85,149 +123,142 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
|
|
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
|
|
|
|
|
|
// String end node
|
|
// String end node
|
|
- addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
|
|
|
|
|
|
+ b.addEnds(stateToNodeMap[StateInStringEnd])
|
|
|
|
+ stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
|
|
|
+ stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
|
|
|
|
|
// TODO: add counters for allowable number of decimals, e, E, etc
|
|
// TODO: add counters for allowable number of decimals, e, E, etc
|
|
// number node
|
|
// number node
|
|
for _, r := range validNumberRunes {
|
|
for _, r := range validNumberRunes {
|
|
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
|
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
|
}
|
|
}
|
|
- addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
|
|
|
|
-
|
|
|
|
- // bool node
|
|
|
|
- for _, r := range validBoolRunes {
|
|
|
|
- stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
|
|
|
- }
|
|
|
|
- addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
|
|
|
- stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
|
|
|
|
|
|
+ b.addEnds(stateToNodeMap[StateInNumber])
|
|
|
|
+ stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
|
|
|
+ stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
|
|
|
|
|
// list node
|
|
// list node
|
|
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
|
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
|
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
|
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
|
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
|
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
|
|
|
+
|
|
|
|
+ // list end node
|
|
|
|
+ stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
|
|
|
+ stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
|
|
|
+ stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
|
|
|
+ stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
|
|
|
+
|
|
// empty list
|
|
// empty list
|
|
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
|
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
|
- addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
|
|
|
|
|
|
+ b.addValueConnections(stateToNodeMap[StateInList])
|
|
|
|
|
|
// null node
|
|
// null node
|
|
for _, r := range validNullRunes {
|
|
for _, r := range validNullRunes {
|
|
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
|
|
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
|
|
}
|
|
}
|
|
- addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
|
|
|
|
|
|
+ b.addEnds(stateToNodeMap[StateInNull])
|
|
|
|
+ stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
|
|
|
+ stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
|
|
|
|
|
// list comma
|
|
// list comma
|
|
// should point to values
|
|
// should point to values
|
|
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
|
|
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
|
|
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
|
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
|
- addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
|
|
|
|
|
|
+ b.addValueConnections(stateToNodeMap[StateInListComma])
|
|
|
|
|
|
// list object end
|
|
// list object end
|
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
|
|
|
+ // TODO: not sure if this is needed
|
|
|
|
+ stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
|
|
|
|
|
// bool node
|
|
// bool node
|
|
for _, r := range validBoolRunes {
|
|
for _, r := range validBoolRunes {
|
|
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
|
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
|
}
|
|
}
|
|
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
|
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
|
- addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
|
|
|
-
|
|
|
|
- stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
|
|
|
- stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
|
|
|
|
|
+ b.addEnds(stateToNodeMap[StateInBool])
|
|
|
|
+ stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
|
|
|
+ stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
|
|
|
|
|
|
|
+ // comma node
|
|
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
|
- stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
|
|
|
- stateToNodeMap[StateInComma].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
|
|
|
|
|
+ stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
|
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
|
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
|
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
|
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
|
|
|
|
|
- stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
|
|
|
- stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
|
|
|
|
|
+ // space end value
|
|
|
|
+ stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
|
|
|
+ stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
|
|
|
+ stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
|
|
|
+ stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
|
|
|
|
|
- return stateToNodeMap[StateStart], stateToNodeMap, nil
|
|
|
|
|
|
+ b.stateToNodeMap = stateToNodeMap
|
|
|
|
+ if err := b.preComputeValidStates(); err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ return nil
|
|
}
|
|
}
|
|
|
|
|
|
-func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
|
|
|
- node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
|
|
|
- node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
|
|
|
- node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
|
|
|
|
|
+func (b *PDAGraphBuilder) addEnds(node *PDA) {
|
|
|
|
+ node.TransitionEdges[','] = b.stateToNodeMap[StateInComma]
|
|
|
|
+ node.TransitionEdges['}'] = b.stateToNodeMap[StateInObjectEnd]
|
|
|
|
+ node.TransitionEdges[']'] = b.stateToNodeMap[StateInListEnd]
|
|
}
|
|
}
|
|
|
|
|
|
-func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
|
|
|
- node.TransitionEdges['"'] = stateToNodeMap[StateInString]
|
|
|
|
|
|
+func (b *PDAGraphBuilder) addValueConnections(node *PDA) {
|
|
|
|
+ node.TransitionEdges['"'] = b.stateToNodeMap[StateInString]
|
|
for _, r := range validNumberRunes {
|
|
for _, r := range validNumberRunes {
|
|
- node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
|
|
|
|
|
+ node.TransitionEdges[r] = b.stateToNodeMap[StateInNumber]
|
|
}
|
|
}
|
|
- node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
|
|
|
|
- node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
|
|
|
|
- node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
|
|
|
|
|
|
+ // TODO(parthsareen): force the output and shift similar to structured outputs
|
|
|
|
+ node.TransitionEdges['t'] = b.stateToNodeMap[StateInBool]
|
|
|
|
+ node.TransitionEdges['f'] = b.stateToNodeMap[StateInBool]
|
|
|
|
+ node.TransitionEdges['n'] = b.stateToNodeMap[StateInNull]
|
|
}
|
|
}
|
|
|
|
|
|
-// TODO: tough life fr. plz fix.
|
|
|
|
-func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
|
|
|
-
|
|
|
|
- // TODO; should come from top level
|
|
|
|
- vocab := proc.GetVocabulary()
|
|
|
|
-
|
|
|
|
- decodedToks := make([]string, len(vocab.Values))
|
|
|
|
- for i := range vocab.Values {
|
|
|
|
- token, err := proc.Decode([]int32{int32(i)})
|
|
|
|
- if err != nil {
|
|
|
|
- return err
|
|
|
|
- }
|
|
|
|
- decodedToks[i] = token
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- var err error
|
|
|
|
- for _, node := range stateToNodeMap {
|
|
|
|
- err = CreateMask(node, proc, decodedToks)
|
|
|
|
- if err != nil {
|
|
|
|
|
|
+func (b *PDAGraphBuilder) preComputeValidStates() error {
|
|
|
|
+ for _, node := range b.stateToNodeMap {
|
|
|
|
+ if err := b.CreateMask(node); err != nil {
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
-func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string) error {
|
|
|
|
- for i := range decodedToks {
|
|
|
|
- token := decodedToks[i]
|
|
|
|
|
|
+func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
|
|
|
|
+ for i := range b.decodedToks {
|
|
|
|
+ token := b.decodedToks[i]
|
|
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
|
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
|
- if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
|
|
|
|
|
+ if b.proc.Is(uint32(i), model.SpecialEOS) || b.proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
|
continue
|
|
continue
|
|
}
|
|
}
|
|
- valid := true
|
|
|
|
curNode := node
|
|
curNode := node
|
|
|
|
+ valid := true
|
|
consumedSpecialRunes := make(map[rune]bool)
|
|
consumedSpecialRunes := make(map[rune]bool)
|
|
- var err error
|
|
|
|
for _, r := range token {
|
|
for _, r := range token {
|
|
- valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
|
|
|
|
- if err != nil {
|
|
|
|
- return err
|
|
|
|
- }
|
|
|
|
- if !valid {
|
|
|
|
|
|
+ curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
|
|
|
|
+ if curNode == nil || !valid {
|
|
break
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if valid {
|
|
if valid {
|
|
- // cur node allows skipping
|
|
|
|
node.MaskTokenIDToNode[int32(i)] = curNode
|
|
node.MaskTokenIDToNode[int32(i)] = curNode
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
-// TODO: garbage interface plz fix
|
|
|
|
-func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
|
|
|
|
|
|
+func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
|
|
if consumedSpecialRunes[r] {
|
|
if consumedSpecialRunes[r] {
|
|
- return false, nil, nil
|
|
|
|
|
|
+ return nil, false
|
|
}
|
|
}
|
|
|
|
|
|
specialRune := slices.Contains(stringInvalidRunes, r)
|
|
specialRune := slices.Contains(stringInvalidRunes, r)
|
|
if specialRune {
|
|
if specialRune {
|
|
if curNode.State == StateInString || curNode.State == StateInObjectKey {
|
|
if curNode.State == StateInString || curNode.State == StateInObjectKey {
|
|
- return false, nil, nil
|
|
|
|
|
|
+ return nil, false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -235,17 +266,17 @@ func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (
|
|
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
|
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
|
if specialRune {
|
|
if specialRune {
|
|
if curNode.State == nextNode.State {
|
|
if curNode.State == nextNode.State {
|
|
- return false, nil, nil
|
|
|
|
|
|
+ return nil, false
|
|
}
|
|
}
|
|
consumedSpecialRunes[r] = true
|
|
consumedSpecialRunes[r] = true
|
|
}
|
|
}
|
|
- return true, nextNode, nil
|
|
|
|
|
|
+ return nextNode, true
|
|
}
|
|
}
|
|
|
|
|
|
// Check for sentinel value - if present, any rune is valid
|
|
// Check for sentinel value - if present, any rune is valid
|
|
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
|
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
|
- return true, nextNode, nil
|
|
|
|
|
|
+ return nextNode, true
|
|
}
|
|
}
|
|
|
|
|
|
- return false, nil, nil
|
|
|
|
|
|
+ return nil, false
|
|
}
|
|
}
|