瀏覽代碼

err handling + fixing scope issue

ParthSareen 2 月之前
父節點
當前提交
ffd6428758
共有 2 個文件被更改,包括 29 次插入和 20 次删除
  1. + 24 - 19
      sample/pushdown_automata.go
  2. + 5 - 1
      sample/structured_outputs.go

+ 24 - 19
sample/pushdown_automata.go

@@ -1,6 +1,7 @@
 package sample
 package sample
 
 
 import (
 import (
+	"fmt"
 	"slices"
 	"slices"
 
 
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model"
@@ -34,7 +35,7 @@ Key JSON rules to consider:
 */
 */
 
 
 // TODO: / should be valid but an escape character
 // TODO: / should be valid but an escape character
-var stringInvalidRunes = []rune{'\n', '\t', '{', '}', ':', ',', '/'}
+var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
 
 
 var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
 var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
 var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
 var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
@@ -109,12 +110,12 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
 	stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
 	stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
-	b.addValueConnections(stateToNodeMap[StateInColon])
+	addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
 
 
 	// Leads to a value
 	// Leads to a value
 	stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
-	b.addValueConnections(stateToNodeMap[StateInSpaceToValue])
+	addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
 	stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
 	stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
 
 
 	// Values
 	// Values
@@ -123,7 +124,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
 	stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
 
 
 	// String end node
 	// String end node
-	b.addEnds(stateToNodeMap[StateInStringEnd])
+	addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
 	stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 	stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
 
@@ -132,7 +133,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	for _, r := range validNumberRunes {
 	for _, r := range validNumberRunes {
 		stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
 		stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
 	}
 	}
-	b.addEnds(stateToNodeMap[StateInNumber])
+	addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
 	stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 	stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
 
@@ -150,13 +151,13 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 
 
 	// empty list
 	// empty list
 	stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
 	stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
-	b.addValueConnections(stateToNodeMap[StateInList])
+	addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
 
 
 	// null node
 	// null node
 	for _, r := range validNullRunes {
 	for _, r := range validNullRunes {
 		stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
 		stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
 	}
 	}
-	b.addEnds(stateToNodeMap[StateInNull])
+	addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
 	stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
 	stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
 	stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 	stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
 
@@ -165,7 +166,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
 	stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
 	stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
-	b.addValueConnections(stateToNodeMap[StateInListComma])
+	addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
 
 
 	// list object end
 	// list object end
 	stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
 	stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
@@ -178,7 +179,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 		stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
 		stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
 	}
 	}
 	stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
 	stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
-	b.addEnds(stateToNodeMap[StateInBool])
+	addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
 	stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 	stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
 
@@ -201,21 +202,21 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	return nil
 	return nil
 }
 }
 
 
-func (b *PDAGraphBuilder) addEnds(node *PDA) {
-	node.TransitionEdges[','] = b.stateToNodeMap[StateInComma]
-	node.TransitionEdges['}'] = b.stateToNodeMap[StateInObjectEnd]
-	node.TransitionEdges[']'] = b.stateToNodeMap[StateInListEnd]
+func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
+	node.TransitionEdges[','] = stateToNodeMap[StateInComma]
+	node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
 }
 }
 
 
-func (b *PDAGraphBuilder) addValueConnections(node *PDA) {
-	node.TransitionEdges['"'] = b.stateToNodeMap[StateInString]
+func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
+	node.TransitionEdges['"'] = stateToNodeMap[StateInString]
 	for _, r := range validNumberRunes {
 	for _, r := range validNumberRunes {
-		node.TransitionEdges[r] = b.stateToNodeMap[StateInNumber]
+		node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
 	}
 	}
 	// TODO(parthsareen): force the output and shift similar to structured outputs
 	// TODO(parthsareen): force the output and shift similar to structured outputs
-	node.TransitionEdges['t'] = b.stateToNodeMap[StateInBool]
-	node.TransitionEdges['f'] = b.stateToNodeMap[StateInBool]
-	node.TransitionEdges['n'] = b.stateToNodeMap[StateInNull]
+	node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
+	node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
+	node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
 }
 }
 
 
 func (b *PDAGraphBuilder) preComputeValidStates() error {
 func (b *PDAGraphBuilder) preComputeValidStates() error {
@@ -228,6 +229,9 @@ func (b *PDAGraphBuilder) preComputeValidStates() error {
 }
 }
 
 
 func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
 func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
+	if node == nil {
+		return fmt.Errorf("node cannot be nil")
+	}
 	for i := range b.decodedToks {
 	for i := range b.decodedToks {
 		token := b.decodedToks[i]
 		token := b.decodedToks[i]
 		// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
 		// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
@@ -264,6 +268,7 @@ func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA
 
 
 	// Check for specific rune transition
 	// Check for specific rune transition
 	if nextNode, ok := curNode.TransitionEdges[r]; ok {
 	if nextNode, ok := curNode.TransitionEdges[r]; ok {
+		// fmt.Println("next node", nextNode)
 		if specialRune {
 		if specialRune {
 			if curNode.State == nextNode.State {
 			if curNode.State == nextNode.State {
 				return nil, false
 				return nil, false

+ 5 - 1
sample/structured_outputs.go

@@ -17,9 +17,13 @@ type JSONSampler struct {
 }
 }
 
 
 func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) {
 func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) {
+	if proc == nil {
+		return nil, fmt.Errorf("TextProcessor cannot be nil")
+	}
+
 	pdaSampler, err := NewPushdownSampler(proc)
 	pdaSampler, err := NewPushdownSampler(proc)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
 	}
 	}
 
 
 	if schema == nil {
 	if schema == nil {