Explorar o código

wip with json stuff and cleanup

ParthSareen hai 2 meses
pai
achega
aa6d5151df

+ 49 - 0
sample/constrained.go

@@ -0,0 +1,49 @@
+package sample
+
+import (
+	"github.com/ollama/ollama/model"
+)
+
+// ConstrainedSampler wraps a PushdownSampler and carries schema-related
+// state. Apply (when schema is nil) and UpdateState forward to pdaSampler.
+type ConstrainedSampler struct {
+	// schema is the optional schema; NewConstrainedSampler currently leaves it nil.
+	schema        *Schema
+	// propIdx is initialised to -1 by NewConstrainedSampler.
+	propIdx       int
+	// propToNodeMap is left nil by NewConstrainedSampler.
+	// NOTE(review): presumably maps schema property names to PDA nodes - confirm.
+	propToNodeMap map[string]*PDA
+	// pdaSampler is the underlying sampler that Apply and UpdateState call into.
+	pdaSampler    *PushdownSampler
+	// decodedToks is not populated by NewConstrainedSampler.
+	decodedToks   []string
+}
+
+// NewConstrainedSampler builds a ConstrainedSampler on top of a freshly
+// constructed PushdownSampler for proc. An error from NewPushdownSampler is
+// returned unchanged.
+//
+// NOTE: the schema argument is not stored yet. The returned sampler always
+// has a nil schema and a nil propToNodeMap, and its propIdx is -1.
+func NewConstrainedSampler(proc model.TextProcessor, schema *Schema) (*ConstrainedSampler, error) {
+	pda, err := NewPushdownSampler(proc)
+	if err != nil {
+		return nil, err
+	}
+
+	sampler := &ConstrainedSampler{
+		propIdx:    -1,
+		pdaSampler: pda,
+	}
+	return sampler, nil
+}
+
+// Apply returns the logits produced by the underlying pushdown sampler.
+//
+// Schema-specific masking is not implemented yet. Rather than handing the
+// caller nil logits together with a nil error when a schema is set, the
+// method always falls back to the pushdown sampler, which is also what
+// UpdateState does regardless of schema.
+func (s *ConstrainedSampler) Apply(logits []float64) ([]float64, error) {
+	// TODO: layer schema constraints on top of the pushdown result once
+	// s.schema is supported.
+	return s.pdaSampler.Apply(logits)
+}
+
+// UpdateState feeds tokenSlice to the underlying pushdown sampler and returns
+// its error, if any.
+func (s *ConstrainedSampler) UpdateState(tokenSlice []int32) error {
+	// No schema-specific state is tracked yet, so the outcome is the pushdown
+	// sampler's result whether or not s.schema is set.
+	return s.pdaSampler.UpdateState(tokenSlice)
+}

+ 32 - 0
sample/feedback.txt

@@ -0,0 +1,32 @@
+// Feedback from code review:
+
+// pushdown_automata.go:
+// 1. The BuildGraph function is quite long and could be split into smaller, more focused functions
+// 2. Consider using constants instead of magic runes like rune(-1) for sentinel values
+// 3. The state machine transitions could be defined more declaratively, perhaps in a config
+// 4. The stringInvalidRunes list needs to handle escape sequences properly
+// 5. The graph building could be optimized to avoid duplicate nodes/transitions
+// 6. Consider adding validation for max nesting depth of braces/brackets
+// 7. The CreateMask function is doing a lot - could be split into smaller pieces
+// 8. isRuneValid has a "garbage interface" per TODO - needs cleaner design
+
+// pushdown_runner.go:
+// 1. The Apply method has a lot of duplicated logic around EOS handling
+// 2. The UpdateState method could use more granular error messages
+// 3. The braceStack validation could be moved to a separate validator
+// 4. Consider adding max length limits for strings/numbers
+// 5. The stateCounter isn't being used effectively yet
+// 6. Need to add penalties for staying in same state too long
+// 7. The maskLogits function could be optimized to avoid allocations
+// 8. Missing proper cleanup/reset functionality
+// 9. Error handling could be more consistent throughout
+// 10. Consider adding debug logging levels instead of raw fmt.Println
+
+// General improvements needed:
+// - More comprehensive testing, especially edge cases
+// - Better documentation of state machine transitions
+// - Performance optimization for large inputs
+// - Memory usage optimization for the graph structure
+// - Cleaner interfaces between components
+// - More robust error handling and recovery
+

+ 11 - 0
sample/fused_mask_sample.go

@@ -0,0 +1,11 @@
+package sample
+
+// type fusedMaskSampler struct{}
+
+// func FusedMaskSampler() Sampler {
+// 	return fusedMaskSampler{}
+// }
+
+// func (f fusedMaskSampler) Sample(logits []float64) (int, error) {
+// 	return int(logits[0]), nil
+// }

+ 15 - 2
sample/greedy.go

@@ -8,6 +8,19 @@ func Greedy() Sampler {
 	return greedy{}
 }
 
-func (s greedy) Sample(t []float64) (int, error) {
-	return floats.MaxIdx(t), nil
+// Sample widens logits to float64, runs every transform over them in the
+// order given, and returns the index chosen by floats.MaxIdx on the result.
+// The first transform error aborts sampling and is returned with index -1.
+func (s greedy) Sample(logits []float32, transforms ...Transform) (int, error) {
+	scores := make([]float64, 0, len(logits))
+	for _, v := range logits {
+		scores = append(scores, float64(v))
+	}
+
+	for _, transform := range transforms {
+		out, err := transform.Apply(scores)
+		if err != nil {
+			return -1, err
+		}
+		scores = out
+	}
+
+	return floats.MaxIdx(scores), nil
 }

+ 11 - 3
sample/fast_json.go → sample/json_types.go

@@ -23,7 +23,9 @@ const (
 	StateInColon
 	StateInComma
 	StateInTab
-	StateInSpace
+	StateInSpaceToValue
+	StateInSpaceEndValue
+	StateInNewlineEndValue
 	StateInObjSpace
 	StateInList
 	StateInListComma
@@ -57,7 +59,9 @@ var JSONStates = []JSONState{
 	StateInColon,
 	StateInComma,
 	StateInTab,
-	StateInSpace,
+	StateInSpaceToValue,
+	StateInSpaceEndValue,
+	StateInNewlineEndValue,
 	StateInObjSpace,
 	StateInList,
 	StateInListComma,
@@ -107,7 +111,7 @@ func (s JSONState) String() string {
 		return "StateInComma"
 	case StateInTab:
 		return "StateInTab"
-	case StateInSpace:
+	case StateInSpaceToValue:
 		return "StateInSpace"
 	case StateInObjSpace:
 		return "StateInObjSpace"
@@ -121,6 +125,8 @@ func (s JSONState) String() string {
 		return "StateInListEnd"
 	case StateInNewline:
 		return "StateInNewline"
+	case StateInNewlineEndValue:
+		return "StateInNewlineEndValue"
 	case StateInNumber:
 		return "StateInNumber"
 	case StateInNumberEnd:
@@ -129,6 +135,8 @@ func (s JSONState) String() string {
 		return "StateInStringEnd"
 	case StateInObjectKeyEnd:
 		return "StateInObjectKeyEnd"
+	case StateInSpaceEndValue:
+		return "StateInSpaceEndValue"
 	case StateTerminate:
 		return "StateTerminate"
 	case StateInObjectEnd:

+ 123 - 92
sample/pushdown_automata.go

@@ -6,8 +6,35 @@ import (
 	"github.com/ollama/ollama/model"
 )
 
+/*
+Key JSON rules to consider:
+
+1. Whitespace handling:
+   - Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
+   - Current code only handles some whitespace cases
+
+2. Number validation:
+   - Need proper validation for special number cases like -0
+   - Should reject .5 style decimals (JSON requires at least one digit before the decimal point)
+   - Need limits on scientific notation (e, E)
+
+3. String escaping:
+   - Currently marks \ as invalid but should allow escaped sequences:
+     - \"
+     - \n
+     - \u1234 unicode escapes
+
+4. Empty object/array transitions:
+   - Direct {} and [] cases could be more explicit
+   - Need clear transitions for these edge cases
+
+5. Nested depth limits:
+   - No protection against excessive nesting
+   - Could cause stack overflow with deeply nested structures
+*/
+
 // TODO: '/' is valid unescaped inside JSON strings (it may also be written as \/); stop listing it as invalid
-var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
+var stringInvalidRunes = []rune{'\n', '\t', '{', '}', ':', ',', '/'}
 
 var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
 var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
@@ -18,31 +45,31 @@ var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
 
 var validNullRunes = []rune{'n', 'u', 'l', 'l'}
 
-type PDANode struct {
+type PDA struct {
 	State             JSONState
-	TransitionEdges   map[rune]*PDANode
-	MaskTokenIDToNode map[int32]*PDANode
+	TransitionEdges   map[rune]*PDA
+	MaskTokenIDToNode map[int32]*PDA
 }
 
-func NewPDANode(state JSONState) *PDANode {
-	return &PDANode{
+func NewPDANode(state JSONState) *PDA {
+	return &PDA{
 		State:             state,
-		TransitionEdges:   make(map[rune]*PDANode),
-		MaskTokenIDToNode: make(map[int32]*PDANode),
+		TransitionEdges:   make(map[rune]*PDA),
+		MaskTokenIDToNode: make(map[int32]*PDA),
 	}
 }
 
-func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
-	stateToNodeMap := make(map[JSONState]*PDANode)
-
-	// TODO: make this a loop
+// PDAGraphBuilder holds the inputs used to build the PDA graph and the
+// resulting state-to-node map.
+type PDAGraphBuilder struct {
+	// proc is queried with Is(...) in CreateMask to skip EOS/BOS tokens.
+	proc           model.TextProcessor
+	// decodedToks holds one decoded string per token ID; CreateMask walks it by index.
+	decodedToks    []string
+	// stateToNodeMap maps each JSONState to its node. BuildGraph assigns it
+	// only after all transitions have been wired, just before computing masks.
+	stateToNodeMap map[JSONState]*PDA
+}
 
+func (b *PDAGraphBuilder) BuildGraph() error {
+	stateToNodeMap := make(map[JSONState]*PDA)
 	for _, state := range JSONStates {
 		stateToNodeMap[state] = NewPDANode(state)
 	}
-	// TODO:
-	// consider adding a node to just point to values, could be good to compute that
-	// mask rather than many different nodes
 
 	stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
@@ -51,10 +78,21 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
 	stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
 
-	//new line
+	// new line
 	stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
 	stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
 	stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
+
+	// new line end value
+	stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
+
+	stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
+	stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
+	// TODO: see if this is needed for formatting
+	stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
 
 	stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
 
@@ -68,16 +106,16 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 
 	// where values should be
 	// this could be combined but the probl might change, we're alr doing a skip ahead
-	stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
+	stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
 	stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
-	addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
+	b.addValueConnections(stateToNodeMap[StateInColon])
 
 	// Leads to a value
-	stateToNodeMap[StateInSpace].TransitionEdges['['] = stateToNodeMap[StateInList]
-	stateToNodeMap[StateInSpace].TransitionEdges['{'] = stateToNodeMap[StateInObject]
-	addValueConnections(stateToNodeMap[StateInSpace], stateToNodeMap)
-	stateToNodeMap[StateInSpace].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
+	stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
+	b.addValueConnections(stateToNodeMap[StateInSpaceToValue])
+	stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
 
 	// Values
 	// string node
@@ -85,149 +123,142 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
 
 	// String end node
-	addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
+	b.addEnds(stateToNodeMap[StateInStringEnd])
+	stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
 	// TODO: add counters for allowable number of decimals, e, E, etc
 	// number node
 	for _, r := range validNumberRunes {
 		stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
 	}
-	addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
-
-	// bool node
-	for _, r := range validBoolRunes {
-		stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
-	}
-	addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
-	stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
+	b.addEnds(stateToNodeMap[StateInNumber])
+	stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
 	// list node
 	stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
 	stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
+
+	// list end node
+	stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
+	stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
+
 	// empty list
 	stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
-	addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
+	b.addValueConnections(stateToNodeMap[StateInList])
 
 	// null node
 	for _, r := range validNullRunes {
 		stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
 	}
-	addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
+	b.addEnds(stateToNodeMap[StateInNull])
+	stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
+	stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
 	// list comma
 	// should point to values
 	stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
 	stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
-	addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
+	b.addValueConnections(stateToNodeMap[StateInListComma])
 
 	// list object end
 	stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
 	stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
+	// TODO: not sure if this is needed
+	stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
 	// bool node
 	for _, r := range validBoolRunes {
 		stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
 	}
 	stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
-	addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
-
-	stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
-	stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
+	b.addEnds(stateToNodeMap[StateInBool])
+	stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
+	// comma node
 	stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
-	stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
-	stateToNodeMap[StateInComma].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
+	stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
 	stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
 	stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
 
-	stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
-	stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
+	// space end value
+	stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
+	stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
-	return stateToNodeMap[StateStart], stateToNodeMap, nil
+	b.stateToNodeMap = stateToNodeMap
+	if err := b.preComputeValidStates(); err != nil {
+		return err
+	}
+	return nil
 }
 
-func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
-	node.TransitionEdges[','] = stateToNodeMap[StateInComma]
-	node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
-	node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
+// addEnds wires the three value terminators out of node: ',' leads to the
+// StateInComma node, '}' to StateInObjectEnd and ']' to StateInListEnd. The
+// target nodes are looked up in b.stateToNodeMap at call time.
+func (b *PDAGraphBuilder) addEnds(node *PDA) {
+	edges := node.TransitionEdges
+	nodes := b.stateToNodeMap
+	edges[']'] = nodes[StateInListEnd]
+	edges['}'] = nodes[StateInObjectEnd]
+	edges[','] = nodes[StateInComma]
 }
 
-func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
-	node.TransitionEdges['"'] = stateToNodeMap[StateInString]
+func (b *PDAGraphBuilder) addValueConnections(node *PDA) {
+	node.TransitionEdges['"'] = b.stateToNodeMap[StateInString]
 	for _, r := range validNumberRunes {
-		node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
+		node.TransitionEdges[r] = b.stateToNodeMap[StateInNumber]
 	}
-	node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
-	node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
-	node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
+	// TODO(parthsareen): force the output and shift similar to structured outputs
+	node.TransitionEdges['t'] = b.stateToNodeMap[StateInBool]
+	node.TransitionEdges['f'] = b.stateToNodeMap[StateInBool]
+	node.TransitionEdges['n'] = b.stateToNodeMap[StateInNull]
 }
 
-// TODO: tough life fr. plz fix.
-func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
-
-	// TODO; should come from top level
-	vocab := proc.GetVocabulary()
-
-	decodedToks := make([]string, len(vocab.Values))
-	for i := range vocab.Values {
-		token, err := proc.Decode([]int32{int32(i)})
-		if err != nil {
-			return err
-		}
-		decodedToks[i] = token
-	}
-
-	var err error
-	for _, node := range stateToNodeMap {
-		err = CreateMask(node, proc, decodedToks)
-		if err != nil {
+// preComputeValidStates runs CreateMask on every node in b.stateToNodeMap and
+// stops at the first error.
+func (b *PDAGraphBuilder) preComputeValidStates() error {
+	for state := range b.stateToNodeMap {
+		err := b.CreateMask(b.stateToNodeMap[state])
+		if err != nil {
 			return err
 		}
 	}
 	return nil
 }
 
-func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string) error {
-	for i := range decodedToks {
-		token := decodedToks[i]
+func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
+	for i := range b.decodedToks {
+		token := b.decodedToks[i]
 		// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
-		if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
+		if b.proc.Is(uint32(i), model.SpecialEOS) || b.proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
 			continue
 		}
-		valid := true
 		curNode := node
+		valid := true
 		consumedSpecialRunes := make(map[rune]bool)
-		var err error
 		for _, r := range token {
-			valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
-			if err != nil {
-				return err
-			}
-			if !valid {
+			curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
+			if curNode == nil || !valid {
 				break
 			}
 		}
 		if valid {
-			// cur node allows skipping
 			node.MaskTokenIDToNode[int32(i)] = curNode
 		}
 	}
 	return nil
 }
 
-// TODO: garbage interface plz fix
-func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
+func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
 	if consumedSpecialRunes[r] {
-		return false, nil, nil
+		return nil, false
 	}
 
 	specialRune := slices.Contains(stringInvalidRunes, r)
 	if specialRune {
 		if curNode.State == StateInString || curNode.State == StateInObjectKey {
-			return false, nil, nil
+			return nil, false
 		}
 	}
 
@@ -235,17 +266,17 @@ func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (
 	if nextNode, ok := curNode.TransitionEdges[r]; ok {
 		if specialRune {
 			if curNode.State == nextNode.State {
-				return false, nil, nil
+				return nil, false
 			}
 			consumedSpecialRunes[r] = true
 		}
-		return true, nextNode, nil
+		return nextNode, true
 	}
 
 	// Check for sentinel value - if present, any rune is valid
 	if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
-		return true, nextNode, nil
+		return nextNode, true
 	}
 
-	return false, nil, nil
+	return nil, false
 }

+ 70 - 48
sample/pushdown_runner.go

@@ -11,17 +11,17 @@ import (
 
 // TODO: safety in case of invalid json
 // TODO: interfaces to cleanup with return values
+// TODO this interface shouldn't be the sampler - should just use Sampler
+// TODO: add penalties for string \n stuff
 type PushdownSampler struct {
-	// stateful
-	curNode        *PDANode
-	proc           model.TextProcessor
-	stateToNodeMap map[JSONState]*PDANode
-	braceStack     []rune
-	stateCounter   uint32
+	PDAGraphBuilder
+	curNode      *PDA
+	braceStack   []rune
+	stateCounter uint32
 }
 
 // graph should be built once and reused per tokenizer
-func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
+func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
 	start := time.Now()
 
 	fmt.Println("--------------------------------")
@@ -32,27 +32,38 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
 	before := m.Alloc
 	fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
 
-	startNode, stateToNodeMap, err := BuildGraph(proc)
-	if err != nil {
-		panic(err)
+	vocab := proc.GetVocabulary()
+	decodedToks := make([]string, len(vocab.Values))
+	for i := range vocab.Values {
+		token, err := proc.Decode([]int32{int32(i)})
+		if err != nil {
+			return nil, err
+		}
+		decodedToks[i] = token
 	}
-	err = PreComputeValidStates(stateToNodeMap, proc)
-	if err != nil {
-		panic(err)
+
+	gb := &PDAGraphBuilder{
+		proc:        proc,
+		decodedToks: decodedToks,
+	}
+
+	if err := gb.BuildGraph(); err != nil {
+		return nil, err
 	}
+
 	runtime.ReadMemStats(&m)
 	after := m.Alloc
 	fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
 	fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
 	fmt.Printf("Graph build time = %v\n", time.Since(start))
 
+	// TODO: this can be simplified
 	return &PushdownSampler{
-		curNode:        startNode,
-		proc:           proc,
-		stateToNodeMap: stateToNodeMap,
-		braceStack:     []rune{},
-		stateCounter:   0,
-	}
+		curNode:         gb.stateToNodeMap[StateStart],
+		PDAGraphBuilder: *gb,
+		braceStack:      []rune{},
+		stateCounter:    0,
+	}, nil
 }
 
 // TODO: need to add resampling logic if the first sample was not good
@@ -66,14 +77,7 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
 		// force finish if no braces left
 		if len(s.braceStack) == 0 {
 			s.curNode = NewPDANode(StateTerminate)
-			for i := range logits {
-				if s.proc.Is(uint32(i), model.SpecialEOS) {
-					logits[i] = 1.0
-				} else {
-					logits[i] = math.Inf(-1)
-				}
-			}
-			return logits, nil
+			return forceFinish(s, logits)
 		}
 
 		logits, err := s.maskLogits(logits, s.curNode)
@@ -82,18 +86,14 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
 		}
 		return logits, nil
 
+	case StateTerminate:
+		return forceFinish(s, logits)
+
 	case StateInObjectEnd:
 		// force finish if no braces left
 		if len(s.braceStack) == 0 {
 			s.curNode = NewPDANode(StateTerminate)
-			for i := range logits {
-				if s.proc.Is(uint32(i), model.SpecialEOS) {
-					logits[i] = 1.0
-				} else {
-					logits[i] = math.Inf(-1)
-				}
-			}
-			return logits, nil
+			return forceFinish(s, logits)
 		}
 
 		peek := s.braceStack[len(s.braceStack)-1]
@@ -112,22 +112,13 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
 		if peek == rune('[') {
 			s.curNode = s.stateToNodeMap[StateInListComma]
 		}
+
 		logits, err := s.maskLogits(logits, s.curNode)
 		if err != nil {
 			return nil, err
 		}
 		return logits, nil
 
-	case StateTerminate:
-		for i := range logits {
-			if s.proc.Is(uint32(i), model.SpecialEOS) {
-				logits[i] = 1.0
-			} else {
-				logits[i] = math.Inf(-1)
-			}
-		}
-		return logits, nil
-
 	default:
 		fmt.Println("masking logits current state", s.curNode.State)
 		logits, err := s.maskLogits(logits, s.curNode)
@@ -138,13 +129,24 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
 	}
 }
 
+// forceFinish overwrites logits in place so that only end-of-sequence tokens
+// stay selectable: entries for which s.proc reports model.SpecialEOS become
+// 1.0 and all others become negative infinity. The same slice is returned
+// with a nil error.
+func forceFinish(s *PushdownSampler, logits []float64) ([]float64, error) {
+	negInf := math.Inf(-1)
+	for id := range logits {
+		switch {
+		case s.proc.Is(uint32(id), model.SpecialEOS):
+			logits[id] = 1.0
+		default:
+			logits[id] = negInf
+		}
+	}
+	return logits, nil
+}
+
 func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 	fmt.Println("current state - updating", s.curNode.State)
 	mappedString, err := s.proc.Decode(tokenSlice)
 	if err != nil {
 		return err
 	}
-	fmt.Println(">>> mappedString", mappedString)
+	fmt.Printf(">>> mappedString: %q\n", mappedString)
 
 	// TODO: should force closing for all braces - not doing square yet
 	for _, r := range mappedString {
@@ -198,7 +200,8 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 }
 
 // greedy sample + backtrack?
-func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
+func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) {
+
 	// Create a new slice with same length as logits, initialized to -Inf
 	maskedLogits := make([]float64, len(logits))
 	for i := range maskedLogits {
@@ -215,4 +218,23 @@ func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64
 	return maskedLogits, nil
 }
 
-// TODO: add penalties for string \n stuff
+// fastMaskLogits finds the highest-scoring token among those allowed by node
+// and stores its ID, converted to float64, in logits[0]; the remaining
+// entries are left untouched. Token IDs at or beyond len(logits) are ignored.
+// An error is returned when no allowed token has a logit above negative
+// infinity.
+//
+// NOTE(review): the returned slice carries a token ID in element 0 rather
+// than a masked distribution, so callers must read logits[0] accordingly.
+func (s *PushdownSampler) fastMaskLogits(logits []float64, node *PDA) ([]float64, error) {
+	best := -1
+	bestLogit := math.Inf(-1)
+
+	for id := range node.MaskTokenIDToNode {
+		if int(id) >= len(logits) {
+			continue
+		}
+		if v := logits[id]; v > bestLogit {
+			best, bestLogit = int(id), v
+		}
+	}
+
+	if best == -1 {
+		return nil, fmt.Errorf("no valid tokens found in mask")
+	}
+
+	logits[0] = float64(best)
+	return logits, nil
+}

+ 69 - 92
sample/sample.go

@@ -6,6 +6,8 @@ import (
 	"math"
 	"slices"
 
+	pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
+	"golang.org/x/exp/rand"
 	"gonum.org/v1/gonum/floats"
 	"gonum.org/v1/gonum/stat/sampleuv"
 )
@@ -15,33 +17,34 @@ type Transform interface {
 }
 
 type Sampler interface {
-	Sample([]float64) (int, error)
+	Sample([]float32, ...Transform) (int, error)
 }
 
-type SamplerConfig struct {
-	transforms []Transform
-	sampler    Sampler
-}
-
-// NewSampler creates a sampler with the given transforms and sampling method
-func NewSampler(transforms []Transform, sampler Sampler) *SamplerConfig {
-	return &SamplerConfig{
-		transforms: transforms,
-		sampler:    sampler,
+// softmax converts logits into a probability distribution and returns it in a
+// newly allocated slice; logits itself is not modified. An empty input yields
+// an empty result.
+//
+// The maximum logit is subtracted before exponentiating. This leaves the
+// result mathematically unchanged but keeps math.Exp from overflowing to +Inf
+// (and the output from degenerating to NaN) when the logits are large, e.g.
+// when a caller has not rescaled them first.
+//
+// TODO(parthsareen): potentially cache softmax values
+func softmax(logits []float64) []float64 {
+	tt := make([]float64, len(logits))
+	if len(logits) == 0 {
+		// slices.Max panics on an empty slice.
+		return tt
+	}
+
+	maxLogit := slices.Max(logits)
+	var sum float64
+	for i, v := range logits {
+		tt[i] = math.Exp(v - maxLogit)
+		sum += tt[i]
 	}
+	floats.Scale(1/sum, tt)
+	return tt
 }
 
 type Temperature float64
 
 func (t Temperature) Apply(logits []float64) ([]float64, error) {
+	if t == 0 {
+		return nil, errors.New("use Greedy sampler instead of Temperature(0)")
+	}
 	if t < 0 || t > 2 {
 		return nil, errors.New("temperature must be between 0 and 2")
 	}
+	temp := math.Max(float64(t), 1e-7)
 
 	// subtracting max logit to avoid under/overflow
-	maxLogit := floats.Max(logits)
-
-	temp := math.Max(float64(t), 1e-7)
+	maxLogit := slices.Max(logits)
 	for i := range logits {
 		logits[i] = (logits[i] - maxLogit) / temp
 	}
@@ -49,52 +52,41 @@ func (t Temperature) Apply(logits []float64) ([]float64, error) {
 	return logits, nil
 }
 
-type softmax struct{}
-
-func Softmax() Transform {
-	return softmax{}
+// logitMap pairs a logit value with its position in the logits slice so the
+// position survives reordering in the priority queue used by TopK.
+type logitMap struct {
+	index int
+	logit float64
 }
 
-func (softmax) Apply(logits []float64) ([]float64, error) {
-	return computeSoftmax(logits), nil
-}
-
-// TODO: cache softmax values
-func computeSoftmax(logits []float64) []float64 {
-	copiedLogits := make([]float64, len(logits))
-	copy(copiedLogits, logits)
-	for i := range copiedLogits {
-		copiedLogits[i] = math.Exp(copiedLogits[i])
-	}
-
-	floatSum := floats.Sum(copiedLogits)
-	floats.Scale(1.0/floatSum, copiedLogits)
-
-	return copiedLogits
+// logitMapComparator orders logitMap values by descending logit: it returns a
+// negative number when a has the larger logit, zero when the logits compare
+// equal, and a positive number otherwise.
+func logitMapComparator(a, b logitMap) int {
+	return cmp.Compare(b.logit, a.logit)
 }
 
 type TopK int
 
+// TODO(parthsareen): avoid having to check all logits after this transform
 func (k TopK) Apply(logits []float64) ([]float64, error) {
 	if k <= 0 {
-		return nil, errors.New("k must be positive")
+		return nil, errors.New("k must be greater than 0")
 	}
 	if int(k) >= len(logits) {
 		return logits, nil
 	}
 
-	indices := make([]int, len(logits))
-	for i := range indices {
-		indices[i] = i
+	q := pq.NewWith(logitMapComparator)
+	for i, logit := range logits {
+		q.Enqueue(logitMap{index: i, logit: logit})
 	}
 
-	// sort in descending order
-	slices.SortFunc(indices, func(i, j int) int {
-		return cmp.Compare(logits[j], logits[i])
-	})
+	validLogits := make(map[int]float64)
+	for range k {
+		logitMap, _ := q.Dequeue()
+		validLogits[logitMap.index] = logitMap.logit
+	}
 
-	for _, idx := range indices[k:] {
-		logits[idx] = math.Inf(-1)
+	for i := range logits {
+		if _, ok := validLogits[i]; !ok {
+			logits[i] = math.Inf(-1)
+		}
 	}
 
 	return logits, nil
@@ -107,8 +99,7 @@ func (p TopP) Apply(logits []float64) ([]float64, error) {
 		return nil, errors.New("p must be between 0 and 1")
 	}
 
-	probs := computeSoftmax(logits)
-
+	probs := softmax(logits)
 	indices := make([]int, len(probs))
 	for i := range indices {
 		indices[i] = i
@@ -139,17 +130,11 @@ func (p MinP) Apply(logits []float64) ([]float64, error) {
 		return nil, errors.New("p must be between 0 and 1")
 	}
 
-	probs := computeSoftmax(logits)
-	copiedProbs := make([]float64, len(probs))
-	copy(copiedProbs, probs)
-
-	slices.Sort(copiedProbs)
-
-	maxProb := copiedProbs[len(copiedProbs)-1]
-	probThreshold := float64(p) * maxProb
+	probs := softmax(logits)
+	threshold := slices.Max(probs) * float64(p)
 
-	for i := range probs {
-		if probs[i] < probThreshold {
+	for i, prob := range probs {
+		if prob < threshold {
 			logits[i] = math.Inf(-1)
 		}
 	}
@@ -157,18 +142,35 @@ func (p MinP) Apply(logits []float64) ([]float64, error) {
 	return logits, nil
 }
 
-type weighed struct{}
+// weighted is the Sampler returned by Weighted.
+type weighted struct {
+	// src is the random source handed to sampleuv.NewWeighted in Sample; it
+	// is nil when Weighted was called without a seed.
+	src rand.Source
+}
 
-func Weighed() Sampler {
-	return weighed{}
+// Weighted returns a Sampler that draws tokens through a weighted random
+// choice. A non-nil seed creates a dedicated random source from that value;
+// with a nil seed the sampler's source is left nil.
+func Weighted(seed *int64) Sampler {
+	if seed == nil {
+		return weighted{}
+	}
+	return weighted{src: rand.NewSource(uint64(*seed))}
 }
 
-// should return single value
-func (s weighed) Sample(logits []float64) (int, error) {
+func (s weighted) Sample(logits []float32, transforms ...Transform) (int, error) {
+	logits64 := make([]float64, len(logits))
+	for i, v := range logits {
+		logits64[i] = float64(v)
+	}
+
+	var err error
+	for _, t := range transforms {
+		logits64, err = t.Apply(logits64)
+		if err != nil {
+			return -1, err
+		}
+	}
+
 	logitsCopy := make([]float64, 0, len(logits))
 	indices := make([]int, 0, len(logits))
-	// the uv sampler does not support NaN values
-	for i, logit := range logits {
+	for i, logit := range logits64 {
 		if !math.IsInf(logit, -1) {
 			logitsCopy = append(logitsCopy, logit)
 			indices = append(indices, i)
@@ -176,38 +178,13 @@ func (s weighed) Sample(logits []float64) (int, error) {
 	}
 
 	if len(logitsCopy) == 0 {
-		return -1, errors.New("no valid tokens found")
+		return -1, errors.New("no valid logits found for weighed sampling")
 	}
 
-	softmax := computeSoftmax(logitsCopy)
-	w := sampleuv.NewWeighted(softmax, nil)
+	probs := softmax(logitsCopy)
+	w := sampleuv.NewWeighted(probs, s.src)
 	if idx, ok := w.Take(); ok {
-		// returns the token ID
 		return indices[idx], nil
 	}
-	return -1, errors.New("weighed sampler failed")
-}
-
-// Sample applies transforms and samples a token ID
-func (s *SamplerConfig) Sample(input []float32) (int, error) {
-	logits := make([]float64, len(input))
-	for i, v := range input {
-		logits[i] = float64(v)
-	}
-
-	var err error
-	for _, t := range s.transforms {
-		if t == Temperature(0) {
-			// early return with greedy if temperature is 0
-			s.sampler = Greedy()
-			break
-		}
-
-		logits, err = t.Apply(logits)
-		if err != nil {
-			return -1, err
-		}
-	}
-
-	return s.sampler.Sample(logits)
+	return -1, errors.New("weighed sampler failed, no valid token found")
 }

+ 120 - 65
sample/sample_test.go

@@ -3,116 +3,129 @@ package sample
 import (
 	"fmt"
 	"math"
-	"slices"
+	"math/rand/v2"
 	"testing"
 
-	"gonum.org/v1/gonum/floats"
+	"github.com/google/go-cmp/cmp"
 )
 
 func TestTemperature(t *testing.T) {
-	logits, err := Temperature(0.5).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
+	logits, err := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
 	if err != nil {
-		t.Fatal(err)
+		t.Error(err)
+		return
 	}
-	want := []float64{-14, -12, -10, -8, -6, -4, 0}
-	if !floats.Equal(logits, want) {
-		t.Fatalf("got: %v, want: %v", logits, want)
+	want := []float64{-4, -10, 0, -14, -6, -12, -8}
+	if diff := cmp.Diff(want, logits); diff != "" {
+		t.Errorf("logits mismatch (-want +got):\n%s", diff)
 	}
 
-	if _, err := Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
-		t.Fatalf("expected error for temperature=-1, got %v", logits)
+	logits, err = Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
+	if err == nil {
+		t.Errorf("expected error for temperature=-1, got %v", logits)
+	}
+	logits, err = Temperature(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
+	if err == nil {
+		t.Errorf("expected error for temperature=0, got %v", logits)
 	}
-	if _, err := Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
-		t.Fatalf("expected error for temperature=2.1, got %v", logits)
+	logits, err = Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
+	if err == nil {
+		t.Errorf("expected error for temperature=2.1, got %v", logits)
 	}
 }
 
 func TestSoftmax(t *testing.T) {
-	probs, err := Softmax().Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
-	if err != nil {
-		t.Fatal(err)
-	}
+	probs := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
 
 	expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
-	if !floats.Equal(probs, expectedProbs) {
-		t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs)
+	if diff := cmp.Diff(expectedProbs, probs); diff != "" {
+		t.Errorf("probs mismatch (-want +got):\n%s", diff)
 	}
 }
 
 func TestTopK(t *testing.T) {
 	logits, err := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
 	if err != nil {
-		t.Fatal(err)
+		t.Error(err)
+		return
 	}
 	expectedlogits := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
-	if !floats.Same(logits, expectedlogits) {
-		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
+	if diff := cmp.Diff(expectedlogits, logits); diff != "" {
+		t.Errorf("logits mismatch (-want +got):\n%s", diff)
 	}
-	logits, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
+
+	_, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
 	if err == nil {
-		t.Fatalf("expected error for k=0, got %v", logits)
+		t.Errorf("expected error for k=0, got %v", err)
 	}
 
 	logits, err = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
 	if err != nil {
-		t.Fatal(err)
+		t.Error(err)
+		return
 	}
 	expectedlogits = []float64{-3, -2, -1, 0, 1, 2, 4}
-	if !floats.Same(logits, expectedlogits) {
-		t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
+	if diff := cmp.Diff(expectedlogits, logits); diff != "" {
+		t.Errorf("logits mismatch (-want +got):\n%s", diff)
 	}
 }
 
 func TestTopP(t *testing.T) {
 	logits, err := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
 	if err != nil {
-		t.Fatal(err)
+		t.Error(err)
+		return
 	}
 	want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
-	if !floats.Same(logits, want) {
-		t.Fatalf("got: %v, want: %v", logits, want)
+	if diff := cmp.Diff(want, logits); diff != "" {
+		t.Errorf("logits mismatch (-want +got):\n%s", diff)
 	}
-	logits, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
+
+	_, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
 	if err == nil {
-		t.Fatalf("expected error for p=1.0, got %v", logits)
+		t.Error("expected error for p=1.0")
 	}
-	logits, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
+	_, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
 	if err == nil {
-		t.Fatalf("expected error for p=0.0, got %v", logits)
+		t.Error("expected error for p=0.0")
 	}
 }
 
 func TestMinP(t *testing.T) {
-	logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
+	logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
 	if err != nil {
-		t.Fatal(err)
+		t.Error(err)
+		return
 	}
-	want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 3, 4}
-	if !floats.Same(logits, want) {
-		t.Fatalf("got: %v, want: %v", logits, want)
+	want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
+	if diff := cmp.Diff(want, logits); diff != "" {
+		t.Errorf("logits mismatch (-want +got):\n%s", diff)
 	}
-	logits, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
+
+	_, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
 	if err == nil {
-		t.Fatalf("expected error for p=1.0, got %v", logits)
+		t.Error("expected error for p=1.0")
 	}
-	logits, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
+	_, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
 	if err == nil {
-		t.Fatalf("expected error for p=0.0, got %v", logits)
+		t.Error("expected error for p=0.0")
 	}
 }
 
 func TestWeighed(t *testing.T) {
-	idx, err := Weighed().Sample([]float64{math.Inf(-1), 2, math.Inf(-1), math.Inf(-1)})
+	idx, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
 	if err != nil {
-		t.Fatal(err)
+		t.Error(err)
+		return
 	}
 	want := 1
-	if idx != want {
-		t.Fatalf("got: %v, want: %v", idx, want)
+	if diff := cmp.Diff(want, idx); diff != "" {
+		t.Errorf("index mismatch (-want +got):\n%s", diff)
 	}
-	idx, err = Weighed().Sample([]float64{math.Inf(-1), math.Inf(-1), math.Inf(-1)})
+
+	idx, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
 	if err == nil {
-		t.Fatalf("expected error for no valid tokens, got %v", idx)
+		t.Error("expected error for no valid tokens, got index", idx)
 	}
 }
 
@@ -132,27 +145,32 @@ func TestSample(t *testing.T) {
 		id:        3,
 		callOrder: &callOrder,
 	}
-	sampler := NewSampler([]Transform{mock1, mock2, mock3}, Greedy())
 
-	got, err := sampler.Sample(input)
+	got, err := Greedy().Sample(input, mock1, mock2, mock3)
 	if err != nil {
-		t.Fatal(err)
+		t.Error(err)
+		return
 	}
 
-	if !slices.Equal(callOrder, []int{1, 2, 3}) {
-		t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3})
+	want := 3 // Greedy sampler should pick highest logit
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("sampled index mismatch (-want +got):\n%s", diff)
 	}
 
-	want := 3 // Greedy sampler should pick highest logit
-	if got != want {
-		t.Errorf("got %v, want %v", got, want)
+	_, err = Weighted(nil).Sample(input, mock1, mock2, mock3)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	wantOrder := []int{1, 2, 3}
+	if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
+		t.Errorf("call order mismatch (-want +got):\n%s", diff)
 	}
 
 	errMock := &testTransform{
 		returnErr: fmt.Errorf("mock error"),
 	}
-	sampler = NewSampler([]Transform{mock1, errMock, mock2}, Greedy())
-	_, err = sampler.Sample(input)
+	_, err = Weighted(nil).Sample(input, mock1, errMock, mock2)
 	if err == nil {
 		t.Error("Expected error from sampler")
 	}
@@ -174,14 +192,51 @@ func (ts *testTransform) Apply(logits []float64) ([]float64, error) {
 	return logits, nil
 }
 
-func TestSampleTemperatureZero(t *testing.T) {
-	sampler := NewSampler([]Transform{Temperature(0)}, Greedy())
-	got, err := sampler.Sample([]float32{1, 2, 3, 4})
-	if err != nil {
-		t.Fatal(err)
+func BenchmarkTransform(b *testing.B) {
+	transforms := map[string]Transform{
+		"Temperature": Temperature(0.5),
+		"TopK":        TopK(10),
+		"TopP":        TopP(0.9),
+		"MinP":        MinP(0.2),
+	}
+
+	logits := make([]float64, 1<<16)
+	for i := range logits {
+		logits[i] = rand.Float64()
+	}
+
+	for name, transform := range transforms {
+		b.Run(name, func(b *testing.B) {
+			b.ResetTimer()
+			for range b.N {
+				_, err := transform.Apply(logits)
+				if err != nil {
+					b.Error(err)
+				}
+			}
+		})
 	}
-	want := 3 // Greedy sampler should pick highest logit index
-	if got != want {
-		t.Fatalf("got: %v, want: %v", got, want)
+}
+
+func BenchmarkSample(b *testing.B) {
+	samplers := map[string]Sampler{
+		"Greedy":   Greedy(),
+		"Weighted": Weighted(nil),
+	}
+
+	logits := make([]float32, 1<<16)
+	for i := range logits {
+		logits[i] = rand.Float32()
+	}
+
+	for name, s := range samplers {
+		b.Run(name, func(b *testing.B) {
+			b.ResetTimer()
+			for range b.N {
+				if _, err := s.Sample(logits); err != nil {
+					b.Error(err)
+				}
+			}
+		})
 	}
 }

+ 59 - 26
sample/structured_outputs.go

@@ -8,27 +8,45 @@ import (
 	"github.com/ollama/ollama/model"
 )
 
-type SOSampler struct {
+type JSONSampler struct {
 	schema        *Schema
 	propIdx       int
-	propToNodeMap map[string]*PDANode
+	propToNodeMap map[string]*PDA
 	pdaSampler    *PushdownSampler
 	decodedToks   []string
 }
 
-func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
-	pdaSampler := NewPushdownSampler(proc)
+func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) {
+	pdaSampler, err := NewPushdownSampler(proc)
+	if err != nil {
+		return nil, err
+	}
+
+	if schema == nil {
+		return &JSONSampler{
+			schema:        nil,
+			propIdx:       -1,
+			propToNodeMap: nil,
+			pdaSampler:    pdaSampler,
+		}, nil
+	}
 
-	so := &SOSampler{
+	fmt.Println("schema not nil")
+	so := &JSONSampler{
 		schema:        schema,
 		propIdx:       -1,
-		propToNodeMap: make(map[string]*PDANode),
+		propToNodeMap: make(map[string]*PDA),
 		pdaSampler:    pdaSampler,
 	}
 
 	so.schemaToGraph()
 
-	// This is prob slow
+	// Benchmark token decoding
+	start := time.Now()
+	var m runtime.MemStats
+	runtime.ReadMemStats(&m)
+	before := m.Alloc
+
 	vocab := proc.GetVocabulary()
 	decodedToks := make([]string, len(vocab.Values))
 	for i := range vocab.Values {
@@ -40,14 +58,18 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
 	}
 	so.decodedToks = decodedToks
 
+	runtime.ReadMemStats(&m)
+	after := m.Alloc
+	fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
+	fmt.Printf("Token decode time = %v\n", time.Since(start))
+
 	fmt.Println("--------------------------------")
 	fmt.Println("SOSampler")
 	fmt.Println("--------------------------------")
 	// Benchmark this section
-	start := time.Now()
-	var m runtime.MemStats
+	start = time.Now()
 	runtime.ReadMemStats(&m)
-	before := m.Alloc
+	before = m.Alloc
 
 	// TODO: still messed up
 	// TODO: recursion use case
@@ -57,12 +79,12 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
 		// propName -> node
 		curState := node.State
 		fromNode := node
-		CreateMask(fromNode, proc, decodedToks)
+		so.pdaSampler.CreateMask(fromNode)
 		for curState == StateInStructuredKey {
 			// there is only one edge
 			for r, toNode := range fromNode.TransitionEdges {
 				// fmt.Println("rune", r, "edge", toNode.State)
-				CreateMask(toNode, proc, decodedToks)
+				so.pdaSampler.CreateMask(toNode)
 				fmt.Printf("created mask for %c\n", r)
 				curState = toNode.State
 				fmt.Println("next state", curState)
@@ -73,7 +95,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
 	}
 
 	runtime.ReadMemStats(&m)
-	after := m.Alloc
+	after = m.Alloc
 	fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
 	fmt.Printf("Mask creation time = %v\n", time.Since(start))
 	fmt.Println("--------------------------------")
@@ -81,7 +103,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
 	return so, nil
 }
 
-func (s *SOSampler) schemaToGraph() {
+func (s *JSONSampler) schemaToGraph() {
 	schemaType := s.schema.EffectiveType()
 	switch schemaType {
 	case "object":
@@ -91,18 +113,18 @@ func (s *SOSampler) schemaToGraph() {
 		for _, prop := range s.schema.Properties {
 			// name of key
 			name := prop.Name
-			keyNode := &PDANode{
+			keyNode := &PDA{
 				State:             StateInStructuredKey, // this is unchanging, will impact sampling
-				TransitionEdges:   make(map[rune]*PDANode),
-				MaskTokenIDToNode: make(map[int32]*PDANode),
+				TransitionEdges:   make(map[rune]*PDA),
+				MaskTokenIDToNode: make(map[int32]*PDA),
 			}
 
 			prevNode := keyNode
 			for _, r := range name {
-				runeNode := &PDANode{
+				runeNode := &PDA{
 					State:             StateInStructuredKey, // this is unchanging, will impact sampling
-					TransitionEdges:   make(map[rune]*PDANode),
-					MaskTokenIDToNode: make(map[int32]*PDANode),
+					TransitionEdges:   make(map[rune]*PDA),
+					MaskTokenIDToNode: make(map[int32]*PDA),
 				}
 				fmt.Println("runeNode created", runeNode.State)
 				fmt.Printf("runeNode created %c\n", r)
@@ -117,9 +139,14 @@ func (s *SOSampler) schemaToGraph() {
 			fmt.Println("name", name, "keyNode", keyNode.State)
 		}
 	}
+	// TODO: do values + recursion
 }
 
-func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
+func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
+	if s.schema == nil {
+		return s.pdaSampler.Apply(logits)
+	}
+
 	switch s.pdaSampler.curNode.State {
 	// doesn't account for the multi-rune case
 	case StateInObjectKey:
@@ -148,17 +175,18 @@ func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
 			// TODO: if I increment propIdx then I know I'm in the last value as well
 			switch s.pdaSampler.curNode.State {
 			case StateInObjectEnd:
-				fmt.Println("<<<<< in obj end- generating mask for", s.pdaSampler.curNode.State)
-				s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDANode)
+				fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
+				s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
 				s.pdaSampler.curNode = NewPDANode(StateTerminate)
 				s.propIdx++
 
+			// TODO: this needs to be optimized in some way; computing the mask on the fly is expensive
 			case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
 				fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
 				delete(s.pdaSampler.curNode.TransitionEdges, ',')
-				s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDANode)
+				s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
 
-				CreateMask(s.pdaSampler.curNode, s.pdaSampler.proc, s.decodedToks)
+				s.pdaSampler.CreateMask(s.pdaSampler.curNode)
 				s.propIdx++
 			}
 		}
@@ -167,12 +195,17 @@ func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
 
 }
 
-func (s *SOSampler) UpdateState(tokenSlice []int32) error {
+func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
 	err := s.pdaSampler.UpdateState(tokenSlice)
 	if err != nil {
 		return err
 	}
 
+	if s.schema == nil {
+		// Don't need to update state for unconstrained JSON sampling
+		return nil
+	}
+
 	switch s.pdaSampler.curNode.State {
 	case StateInObjectKey:
 		s.propIdx++