|
@@ -108,10 +108,13 @@ type PythonSampler struct {
|
|
|
proc model.TextProcessor
|
|
|
decodedToks []string
|
|
|
curNode *Node
|
|
|
+ completed int
|
|
|
+ functions []PythonFunction
|
|
|
}
|
|
|
|
|
|
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
|
|
|
s.proc = proc
|
|
|
+ s.functions = functions
|
|
|
decodedToks := make([]string, len(proc.Vocab().Values))
|
|
|
for i := range proc.Vocab().Values {
|
|
|
token, err := proc.Decode([]int32{int32(i)})
|
|
@@ -194,7 +197,7 @@ func (s *PythonSampler) BuildGraph() error {
|
|
|
|
|
|
// String end
|
|
|
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
|
|
- s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
|
|
+ // s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
|
|
// Number
|
|
|
for _, r := range validNumberRunes {
|
|
|
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
|
|
@@ -237,6 +240,16 @@ func (s *PythonSampler) UpdateState(token int32) error {
|
|
|
if !ok {
|
|
|
return fmt.Errorf("invalid token: %q", mappedString)
|
|
|
}
|
|
|
+
|
|
|
+ if mappedString == "\"" {
|
|
|
+ if s.curNode.State == PStateInStringEnd {
|
|
|
+ s.completed++
|
|
|
+ }
|
|
|
+ if s.completed == len(s.functions) {
|
|
|
+ s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
|
|
+ s.CreateMask(s.curNode)
|
|
|
+ }
|
|
|
+ }
|
|
|
s.curNode = nextNode
|
|
|
fmt.Println("curNode", s.curNode.State)
|
|
|
for r, node := range s.curNode.TransitionEdges {
|