فهرست منبع

Enable array type json

ParthSareen 3 ماه پیش
والد
کامیت
198fde82aa
3فایلهای تغییر یافته به همراه28 افزوده شده و 9 حذف شده
  1. 0 4
      sample/fast_json.go
  2. 5 4
      sample/pushdown_automata.go
  3. 23 1
      sample/pushdown_runner.go

+ 0 - 4
sample/fast_json.go

@@ -29,7 +29,6 @@ const (
 	StateInObjSpace
 	StateInObjSpace
 	StateInList
 	StateInList
 	StateInListComma
 	StateInListComma
-	StateListEnd
 	StateInValue
 	StateInValue
 	StateInValueEnd
 	StateInValueEnd
 	StateInListEnd
 	StateInListEnd
@@ -63,7 +62,6 @@ var JSONStates = []JSONState{
 	StateInObjSpace,
 	StateInObjSpace,
 	StateInList,
 	StateInList,
 	StateInListComma,
 	StateInListComma,
-	StateListEnd,
 	StateInValue,
 	StateInValue,
 	StateInValueEnd,
 	StateInValueEnd,
 	StateInListEnd,
 	StateInListEnd,
@@ -118,8 +116,6 @@ func (s JSONState) String() string {
 		return "StateInListObjectEnd"
 		return "StateInListObjectEnd"
 	case StateInListComma:
 	case StateInListComma:
 		return "StateInListComma"
 		return "StateInListComma"
-	case StateListEnd:
-		return "StateListEnd"
 	case StateInListEnd:
 	case StateInListEnd:
 		return "StateInListEnd"
 		return "StateInListEnd"
 	case StateInNewline:
 	case StateInNewline:

+ 5 - 4
sample/pushdown_automata.go

@@ -47,6 +47,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	// Connect nodes
 	// Connect nodes
 	// TODO: if all are single tokens then this can just be connected instead of defining the token
 	// TODO: if all are single tokens then this can just be connected instead of defining the token
 	stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
+	stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
 
 
 	stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
 	stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
 	stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
 	stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
@@ -121,7 +122,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 
 
 	// list object end
 	// list object end
 	stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
 	stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
-	stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateListEnd]
+	stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
 
 
 	// bool node
 	// bool node
 	for _, r := range validBoolRunes {
 	for _, r := range validBoolRunes {
@@ -129,8 +130,8 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	}
 	}
 	addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
 	addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
 
 
-	stateToNodeMap[StateListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
-	stateToNodeMap[StateListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
+	stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
 
 
 	stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
@@ -147,7 +148,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
 func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
 	node.TransitionEdges[','] = stateToNodeMap[StateInComma]
 	node.TransitionEdges[','] = stateToNodeMap[StateInComma]
 	node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
 	node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
-	node.TransitionEdges[']'] = stateToNodeMap[StateListEnd]
+	node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
 }
 }
 
 
 func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
 func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {

+ 23 - 1
sample/pushdown_runner.go

@@ -58,6 +58,27 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 	case StateInString:
 	case StateInString:
 		return s.maskLogits(logits, s.curNode)
 		return s.maskLogits(logits, s.curNode)
 
 
+	case StateInListEnd:
+		fmt.Println("in list end", s.braceStack)
+		// force finish if no braces left
+		if len(s.braceStack) == 0 {
+			s.curNode = NewPDANode(StateTerminate)
+			for i := range logits {
+				if s.proc.Is(uint32(i), model.SpecialEOS) {
+					logits[i] = 1.0
+				} else {
+					logits[i] = math.NaN()
+				}
+			}
+			return logits, nil
+		}
+
+		logits, err := s.maskLogits(logits, s.curNode)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+
 	case StateInObjectEnd:
 	case StateInObjectEnd:
 		// force finish if no braces left
 		// force finish if no braces left
 		if len(s.braceStack) == 0 {
 		if len(s.braceStack) == 0 {
@@ -117,11 +138,12 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 }
 }
 
 
 func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
-	// fmt.Println("update state", s.curNode.State)
+	fmt.Println("update state", s.curNode.State)
 	mappedString, err := s.proc.Decode(tokenSlice)
 	mappedString, err := s.proc.Decode(tokenSlice)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+	fmt.Println("mappedString", mappedString)
 
 
 	// TODO: should force closing for all braces - not doing square yet
 	// TODO: should force closing for all braces - not doing square yet
 	for _, r := range mappedString {
 	for _, r := range mappedString {