|
@@ -6,7 +6,9 @@ import (
|
|
|
"github.com/ollama/ollama/model"
|
|
|
)
|
|
|
|
|
|
-var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
|
|
|
+// TODO: / should be valid but an escape character
|
|
|
+
|
|
|
+var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
|
|
|
|
|
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
|
|
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
|
@@ -34,6 +36,7 @@ func NewPDANode(state JSONState) *PDANode {
|
|
|
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
|
|
|
stateToNodeMap := make(map[JSONState]*PDANode)
|
|
|
|
|
|
+ // TODO: make this a loop
|
|
|
startNode := NewPDANode(StateStart)
|
|
|
stateToNodeMap[StateStart] = startNode
|
|
|
|
|
@@ -95,6 +98,9 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|
|
intNode := NewPDANode(StateInInt)
|
|
|
stateToNodeMap[StateInInt] = intNode
|
|
|
|
|
|
+ listObjEndNode := NewPDANode(StateInListObjectEnd)
|
|
|
+ stateToNodeMap[StateInListObjectEnd] = listObjEndNode
|
|
|
+
|
|
|
// TODO:
|
|
|
// consider adding a node to just point to values, could be good to compute that
|
|
|
// mask rather than many different nodes
|
|
@@ -105,108 +111,84 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|
|
|
|
|
objNode.TransitionEdges['"'] = objKeyNode
|
|
|
objNode.TransitionEdges['\n'] = newlineNode
|
|
|
- // objNode.TransitionEdges['\t'] = tabNode
|
|
|
+ objNode.TransitionEdges[' '] = spaceObjNode
|
|
|
|
|
|
+ //new line
|
|
|
newlineNode.TransitionEdges['"'] = objKeyNode
|
|
|
newlineNode.TransitionEdges['\t'] = tabNode
|
|
|
|
|
|
tabNode.TransitionEdges['"'] = objKeyNode
|
|
|
- // tabNode.TransitionEdges['\t'] = tabNode
|
|
|
|
|
|
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
|
|
|
objKeyNode.TransitionEdges['"'] = objKeyEndNode
|
|
|
|
|
|
objKeyEndNode.TransitionEdges[':'] = colonNode
|
|
|
- objEndNode.TransitionEdges[' '] = spaceNode
|
|
|
+
|
|
|
+ objEndNode.TransitionEdges[','] = commaNode
|
|
|
+ objEndNode.TransitionEdges['}'] = objEndNode
|
|
|
|
|
|
// where values should be
|
|
|
// this could be combined but the probs might change, we're alr doing a skip ahead
|
|
|
colonNode.TransitionEdges[' '] = spaceNode
|
|
|
+ colonNode.TransitionEdges['['] = listNode
|
|
|
+ colonNode.TransitionEdges['{'] = objNode
|
|
|
+ addValueConnections(colonNode, stateToNodeMap)
|
|
|
|
|
|
// Leads to a value
|
|
|
- spaceNode.TransitionEdges['"'] = stringNode
|
|
|
spaceNode.TransitionEdges['['] = listNode
|
|
|
spaceNode.TransitionEdges['{'] = objNode
|
|
|
-
|
|
|
- for _, r := range validNumberRunes {
|
|
|
- spaceNode.TransitionEdges[r] = numberNode
|
|
|
- }
|
|
|
- for _, r := range validBoolRunes {
|
|
|
- spaceNode.TransitionEdges[r] = boolNode
|
|
|
- }
|
|
|
-
|
|
|
- for _, r := range validNullRunes {
|
|
|
- spaceNode.TransitionEdges[r] = nullNode
|
|
|
- }
|
|
|
+ addValueConnections(spaceNode, stateToNodeMap)
|
|
|
|
|
|
// Values
|
|
|
// string node
|
|
|
stringNode.TransitionEdges[rune(-1)] = stringNode
|
|
|
stringNode.TransitionEdges['"'] = stringEndNode
|
|
|
|
|
|
- stringEndNode.TransitionEdges[','] = commaNode
|
|
|
- stringEndNode.TransitionEdges['}'] = objEndNode
|
|
|
- stringEndNode.TransitionEdges[']'] = listEndNode
|
|
|
+ // String end node
|
|
|
+ addEnds(stringEndNode, stateToNodeMap)
|
|
|
|
|
|
// TODO: add counters for allowable number of decimals, e, E, etc
|
|
|
// number node
|
|
|
for _, r := range validNumberRunes {
|
|
|
numberNode.TransitionEdges[r] = numberNode
|
|
|
}
|
|
|
- numberNode.TransitionEdges[','] = commaNode
|
|
|
- numberNode.TransitionEdges['}'] = objEndNode
|
|
|
- numberNode.TransitionEdges[']'] = listEndNode
|
|
|
+ addEnds(numberNode, stateToNodeMap)
|
|
|
|
|
|
+ // bool node
|
|
|
for _, r := range validBoolRunes {
|
|
|
boolNode.TransitionEdges[r] = boolNode
|
|
|
}
|
|
|
+ addEnds(boolNode, stateToNodeMap)
|
|
|
|
|
|
// list node
|
|
|
listNode.TransitionEdges[','] = commaNode
|
|
|
- listNode.TransitionEdges['"'] = stringNode
|
|
|
- // squash states to a value
|
|
|
- for _, r := range validNumberRunes {
|
|
|
- listNode.TransitionEdges[r] = numberNode
|
|
|
- }
|
|
|
- for _, r := range validBoolRunes {
|
|
|
- listNode.TransitionEdges[r] = boolNode
|
|
|
- }
|
|
|
- for _, r := range validNullRunes {
|
|
|
- listNode.TransitionEdges[r] = nullNode
|
|
|
- }
|
|
|
+ listNode.TransitionEdges['{'] = objNode
|
|
|
+ listNode.TransitionEdges[' '] = listNode
|
|
|
+ listNode.TransitionEdges['\n'] = listNode
|
|
|
+ addValueConnections(listNode, stateToNodeMap)
|
|
|
|
|
|
// null node
|
|
|
for _, r := range validNullRunes {
|
|
|
nullNode.TransitionEdges[r] = nullNode
|
|
|
}
|
|
|
- nullNode.TransitionEdges[','] = commaNode
|
|
|
- nullNode.TransitionEdges['}'] = objEndNode
|
|
|
- nullNode.TransitionEdges[']'] = listEndNode
|
|
|
+ addEnds(nullNode, stateToNodeMap)
|
|
|
|
|
|
// list comma
|
|
|
// should point to values
|
|
|
- listCommaNode.TransitionEdges['"'] = stringNode
|
|
|
listCommaNode.TransitionEdges[' '] = listCommaNode
|
|
|
listCommaNode.TransitionEdges['{'] = objNode
|
|
|
listCommaNode.TransitionEdges['\n'] = newlineNode
|
|
|
+ addValueConnections(listCommaNode, stateToNodeMap)
|
|
|
|
|
|
- for _, r := range validNumberRunes {
|
|
|
- listCommaNode.TransitionEdges[r] = numberNode
|
|
|
- }
|
|
|
- for _, r := range validBoolRunes {
|
|
|
- listCommaNode.TransitionEdges[r] = boolNode
|
|
|
- }
|
|
|
- for _, r := range validNullRunes {
|
|
|
- listCommaNode.TransitionEdges[r] = nullNode
|
|
|
- }
|
|
|
+ // list object end
|
|
|
+ listObjEndNode.TransitionEdges[','] = listCommaNode
|
|
|
+ listObjEndNode.TransitionEdges[']'] = listEndNode
|
|
|
|
|
|
// bool node
|
|
|
for _, r := range validBoolRunes {
|
|
|
boolNode.TransitionEdges[r] = boolNode
|
|
|
}
|
|
|
- boolNode.TransitionEdges['}'] = objEndNode
|
|
|
- boolNode.TransitionEdges[']'] = listEndNode
|
|
|
- boolNode.TransitionEdges[','] = commaNode
|
|
|
+ addEnds(boolNode, stateToNodeMap)
|
|
|
|
|
|
listEndNode.TransitionEdges['}'] = objEndNode
|
|
|
listEndNode.TransitionEdges[','] = commaNode
|
|
@@ -218,10 +200,27 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|
|
commaNode.TransitionEdges[' '] = spaceObjNode
|
|
|
|
|
|
spaceObjNode.TransitionEdges['"'] = objKeyNode
|
|
|
+ spaceObjNode.TransitionEdges['\n'] = newlineNode
|
|
|
|
|
|
return startNode, stateToNodeMap, nil
|
|
|
}
|
|
|
|
|
|
+func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
|
|
+ node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
|
|
+ node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
|
|
+ node.TransitionEdges[']'] = stateToNodeMap[StateListEnd]
|
|
|
+}
|
|
|
+
|
|
|
+func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
|
|
+ node.TransitionEdges['"'] = stateToNodeMap[StateInString]
|
|
|
+ for _, r := range validNumberRunes {
|
|
|
+ node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
|
|
+ }
|
|
|
+ node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
|
|
|
+ node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
|
|
|
+ node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
|
|
|
+}
|
|
|
+
|
|
|
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
|
|
|
|
|
vocab := proc.GetVocabulary()
|
|
@@ -240,7 +239,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
|
|
|
for i := range vocab.Values {
|
|
|
token := decodedToks[i]
|
|
|
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
|
|
- if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
|
|
|
+ if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
|
|
continue
|
|
|
}
|
|
|
valid := true
|
|
@@ -263,6 +262,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+// garbage interface plz fix
|
|
|
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
|
|
|
if consumedSpecialRunes[r] {
|
|
|
return false, nil, nil
|
|
@@ -281,7 +281,6 @@ func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (
|
|
|
if curNode.State == nextNode.State {
|
|
|
return false, nil, nil
|
|
|
}
|
|
|
- // fmt.Println("special rune", r, "consumed")
|
|
|
consumedSpecialRunes[r] = true
|
|
|
}
|
|
|
return true, nextNode, nil
|