ParthSareen 3 miesięcy temu
rodzic
commit
c56a8b7749

+ 10 - 3
model/cmd/main.go

@@ -104,6 +104,8 @@ func temp() error {
 		}
 	}
 
+	pdaSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
+	var stringBuffer string
 	var offset int
 	for range args.n {
 		logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
@@ -118,7 +120,10 @@ func temp() error {
 		}
 
 		// do sampling
-		f64s, err = sample.Sample(f64s, sample.Greedy())
+		// []ints back
+		// ints map to sampled logits
+		f64s, err = sample.Sample(f64s, pdaSampler, sample.Greedy())
+
 		if err != nil {
 			return err
 		}
@@ -129,6 +134,7 @@ func temp() error {
 				outputIDs = append(outputIDs, int32(f64))
 			}
 		}
+		pdaSampler.UpdateState(outputIDs)
 
 		if len(outputIDs) == 0 {
 			break
@@ -141,8 +147,9 @@ func temp() error {
 			return err
 		}
 
-		fmt.Print(s)
-
+		// fmt.Print(s)
+		stringBuffer += s
+		fmt.Println("--- stringBuffer", stringBuffer)
 		inputIDs = append(inputIDs, outputIDs...)
 		if args.cache {
 			offset = len(inputIDs) - 1

+ 1 - 0
model/cmd/test.go

@@ -0,0 +1 @@
+package main

+ 5 - 0
model/process_text.go

@@ -21,6 +21,7 @@ type TextProcessor interface {
 	Encode(string) ([]int32, error)
 	Decode([]int32) (string, error)
 	Is(uint32, Special) bool
+	GetVocabulary() *Vocabulary
 }
 
 type Vocabulary struct {
@@ -104,6 +105,10 @@ type BytePairEncoding struct {
 	*Vocabulary
 }
 
+func (bpe BytePairEncoding) GetVocabulary() *Vocabulary {
+	return bpe.Vocabulary
+}
+
 func (bpe BytePairEncoding) split(s string) ([]string, error) {
 	re, err := regexp2.Compile(bpe.Pretokenizer, regexp2.Unicode|regexp2.RE2)
 	if err != nil {

+ 0 - 220
sample/fast_json.go

@@ -1,11 +1,7 @@
 package sample
 
 import (
-	"errors"
 	"fmt"
-	"math"
-
-	"github.com/ollama/ollama/model"
 )
 
 type JSONState int
@@ -136,219 +132,3 @@ func (s JSONState) String() string {
 		return fmt.Sprintf("Unknown state: %d", s)
 	}
 }
-
-type JSONSampler struct {
-	curNode        *Node
-	proc           model.TextProcessor
-	stack          []*Node
-	bracketCounter int
-}
-
-func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
-	// fmt.Println("Creating new JSON sampler")
-	startNode, err := buildStateMachine(proc)
-	if err != nil {
-		return nil, err
-	}
-	js := &JSONSampler{
-		curNode:        startNode,
-		proc:           proc,
-		stack:          []*Node{},
-		bracketCounter: 0,
-	}
-
-	return js, nil
-}
-
-func isTokenSubset(subset, superset []int32) bool {
-	freq1 := make(map[int32]int)
-	freq2 := make(map[int32]int)
-
-	for _, v := range subset {
-		freq1[v]++
-	}
-	for _, v := range superset {
-		freq2[v]++
-	}
-	isSubset := true
-	for k, count1 := range freq1 {
-		count2 := freq2[k]
-		if count1 > count2 {
-			isSubset = false
-			break
-		}
-	}
-	return isSubset
-}
-
-func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
-	// fmt.Printf("Updating state with token: %v\n", tokenSlice)
-	// fmt.Printf("Current state: %s\n", s.curNode.State)
-
-	// fmt.Println("tokenSlice", tokenSlice)
-	// todo: account for strings here
-
-	objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
-	if err != nil {
-		return err
-	}
-
-	// only move to terminate state if stack is empty
-	if s.curNode.State == StateInObjectEnd {
-		fmt.Println("debug: node.State", s.curNode.State)
-		if len(s.stack) > 0 {
-			s.stack = s.stack[:len(s.stack)-1]
-			fmt.Println("popped and cur state", s.curNode.State)
-			return nil
-		}
-		return nil
-	}
-
-	for node, edge := range s.curNode.TransitionEdges {
-		for _, validToken := range edge {
-			if isTokenSubset(tokenSlice, validToken) {
-				s.curNode = node
-				for _, token := range objectTokens {
-					if isTokenSubset(tokenSlice, token) {
-						fmt.Println("Appending to stack", s.curNode.State)
-						s.stack = append(s.stack, s.curNode)
-					}
-				}
-				// fmt.Printf("Transitioned to state: %s\n", node.State)
-				return nil
-			}
-		}
-	}
-	for node, edge := range s.curNode.TransitionEdges {
-		for _, validToken := range edge {
-			if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 {
-				s.curNode = node
-				// fmt.Printf("Accepting any token, staying in state: %s\n", node.State)
-				return nil
-			}
-		}
-	}
-	fmt.Println("invalid token ", tokenSlice)
-	dec, err := s.proc.Decode(tokenSlice)
-	if err != nil {
-		return err
-	}
-	fmt.Println("decoded token ", dec)
-	return errors.New("invalid token")
-}
-
-func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
-	fmt.Printf("Sampling in state: %s\n", s.curNode.State)
-	var err error
-
-	switch s.curNode.State {
-	case StateTerminate:
-		for i := range logits {
-			if s.proc.Is(uint32(i), model.SpecialEOS) {
-				logits[i] = 1.0
-			} else {
-				logits[i] = math.NaN()
-			}
-		}
-		return logits, nil
-
-	case StateInInt:
-		validStates := []int32{}
-		minus, err := s.proc.Encode("-")
-		if err != nil {
-			return nil, err
-		}
-		digits := make([][]int32, 10)
-		for i := 0; i < 10; i++ {
-			digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i))
-			if err != nil {
-				return nil, err
-			}
-		}
-		// Allow "-" and digits 0-9 at start
-		for i := range logits {
-			for _, d := range digits {
-				if len(d) == 1 && int32(i) == d[0] {
-					validStates = append(validStates, int32(i))
-				}
-			}
-			if len(minus) == 1 && int32(i) == minus[0] {
-				validStates = append(validStates, int32(i))
-			}
-		}
-		return logits, nil
-
-	case StateInString:
-		penalizeNewlineVariants := []string{"\n", " \"\n"}
-		penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
-		if err != nil {
-			return nil, err
-		}
-		penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
-		logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
-		if err != nil {
-			return nil, err
-		}
-		validStates := getValidStates(s.curNode)
-		logits, err = s.maskLogits(logits, validStates)
-		if err != nil {
-			return nil, err
-		}
-		return logits, nil
-
-	default:
-		validStates := getValidStates(s.curNode)
-		logits, err = s.maskLogits(logits, validStates)
-		if err != nil {
-			return nil, err
-		}
-		return logits, nil
-	}
-}
-
-func getValidStates(node *Node) []int32 {
-	validStates := []int32{}
-	for _, edge := range node.TransitionEdges {
-		for _, token := range edge {
-			validStates = append(validStates, token...)
-		}
-	}
-	return validStates
-}
-
-func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
-	// fmt.Printf("Masking logits with valid states: %v\n", validStates)
-	// todo: this can prob be more efficient
-	for i := range logits {
-		isValid := false
-		for _, token := range validStates {
-			if token == -1 {
-				// fmt.Println("Found sentinel token, returning unmasked logits")
-				return logits, nil
-			}
-			if i == int(token) {
-				// fmt.Printf("Found valid token: %d\n", token)
-				isValid = true
-				break
-			}
-		}
-		if !isValid {
-			logits[i] = math.NaN()
-		}
-	}
-	return logits, nil
-}
-
-func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
-	// fmt.Printf("Masking specific logits: %v\n", tokensToMask)
-	for i := range logits {
-		for _, token := range tokensToMask {
-			for _, chunked := range token {
-				if int(chunked) == i {
-					logits[i] = math.NaN()
-				}
-			}
-		}
-	}
-	return logits, nil
-}

+ 296 - 0
sample/hid.txt

@@ -0,0 +1,296 @@
+package sample
+
+import (
+	"slices"
+
+	"github.com/ollama/ollama/model"
+)
+
+var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
+
+var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
+var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
+
+var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
+
+var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
+
+var validNullRunes = []rune{'n', 'u', 'l', 'l'}
+
+type PDANode struct {
+	State             JSONState
+	TransitionEdges   map[rune]*PDANode
+	MaskTokenIDToNode map[int32]JSONState
+}
+
+func NewPDANode(state JSONState) *PDANode {
+	return &PDANode{
+		State:             state,
+		TransitionEdges:   make(map[rune]*PDANode),
+		MaskTokenIDToNode: make(map[int32]JSONState),
+	}
+}
+
+func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
+	stateToNodeMap := make(map[JSONState]*PDANode)
+
+	startNode := NewPDANode(StateStart)
+	stateToNodeMap[StateStart] = startNode
+
+	objNode := NewPDANode(StateInObject)
+	stateToNodeMap[StateInObject] = objNode
+
+	objEndNode := NewPDANode(StateInObjectEnd)
+	stateToNodeMap[StateInObjectEnd] = objEndNode
+
+	objKeyNode := NewPDANode(StateInObjectKey)
+	stateToNodeMap[StateInObjectKey] = objKeyNode
+
+	objKeyEndNode := NewPDANode(StateInObjectKeyEnd)
+	stateToNodeMap[StateInObjectKeyEnd] = objKeyEndNode
+
+	colonNode := NewPDANode(StateInColon)
+	stateToNodeMap[StateInColon] = colonNode
+
+	commaNode := NewPDANode(StateInComma)
+	stateToNodeMap[StateInComma] = commaNode
+
+	newlineNode := NewPDANode(StateInNewline)
+	stateToNodeMap[StateInNewline] = newlineNode
+
+	spaceNode := NewPDANode(StateInSpace)
+	stateToNodeMap[StateInSpace] = spaceNode
+
+	spaceObjNode := NewPDANode(StateInObjSpace)
+	stateToNodeMap[StateInObjSpace] = spaceObjNode
+
+	tabNode := NewPDANode(StateInTab)
+	stateToNodeMap[StateInTab] = tabNode
+
+	stringNode := NewPDANode(StateInString)
+	stateToNodeMap[StateInString] = stringNode
+
+	stringEndNode := NewPDANode(StateInStringEnd)
+	stateToNodeMap[StateInStringEnd] = stringEndNode
+
+	listNode := NewPDANode(StateInList)
+	stateToNodeMap[StateInList] = listNode
+
+	listCommaNode := NewPDANode(StateInListComma)
+	stateToNodeMap[StateInListComma] = listCommaNode
+
+	listEndNode := NewPDANode(StateListEnd)
+	stateToNodeMap[StateListEnd] = listEndNode
+
+	numberNode := NewPDANode(StateInNumber)
+	stateToNodeMap[StateInNumber] = numberNode
+
+	boolNode := NewPDANode(StateInBool)
+	stateToNodeMap[StateInBool] = boolNode
+
+	nullNode := NewPDANode(StateInNull)
+	stateToNodeMap[StateInNull] = nullNode
+
+	// Defined with structured outputs only
+	intNode := NewPDANode(StateInInt)
+	stateToNodeMap[StateInInt] = intNode
+
+	// TODO:
+	// consider adding a node to just point to values, could be good to compute that
+	// mask rather than many different nodes
+
+	// Connect nodes
+	// TODO: if all are single tokens then this can just be connected instead of defining the token
+	startNode.TransitionEdges['{'] = objNode
+
+	objNode.TransitionEdges['"'] = objKeyNode
+	objNode.TransitionEdges['\n'] = newlineNode
+	// objNode.TransitionEdges['\t'] = tabNode
+
+	newlineNode.TransitionEdges['"'] = objKeyNode
+	newlineNode.TransitionEdges['\t'] = tabNode
+
+	tabNode.TransitionEdges['"'] = objKeyNode
+	// tabNode.TransitionEdges['\t'] = tabNode
+
+	objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
+	objKeyNode.TransitionEdges['"'] = objKeyEndNode
+
+	objKeyEndNode.TransitionEdges[':'] = colonNode
+	objEndNode.TransitionEdges[' '] = spaceNode
+
+	// where values should be
+	// this could be combined but the probs might change, we're alr doing a skip ahead
+	colonNode.TransitionEdges[' '] = spaceNode
+
+	// Leads to a value
+	spaceNode.TransitionEdges['"'] = stringNode
+	spaceNode.TransitionEdges['['] = listNode
+	spaceNode.TransitionEdges['{'] = objNode
+
+	for _, r := range validNumberRunes {
+		spaceNode.TransitionEdges[r] = numberNode
+	}
+	for _, r := range validBoolRunes {
+		spaceNode.TransitionEdges[r] = boolNode
+	}
+
+	for _, r := range validNullRunes {
+		spaceNode.TransitionEdges[r] = nullNode
+	}
+
+	// Values
+	// string node
+	stringNode.TransitionEdges[rune(-1)] = stringNode
+	stringNode.TransitionEdges['"'] = stringEndNode
+
+	stringEndNode.TransitionEdges[','] = commaNode
+	stringEndNode.TransitionEdges['}'] = objEndNode
+	stringEndNode.TransitionEdges[']'] = listEndNode
+
+	// TODO: add counters for allowable number of decimals, e, E, etc
+	// number node
+	for _, r := range validNumberRunes {
+		numberNode.TransitionEdges[r] = numberNode
+	}
+	numberNode.TransitionEdges[','] = commaNode
+	numberNode.TransitionEdges['}'] = objEndNode
+	numberNode.TransitionEdges[']'] = listEndNode
+
+	for _, r := range validBoolRunes {
+		boolNode.TransitionEdges[r] = boolNode
+	}
+
+	// list node
+	listNode.TransitionEdges[','] = commaNode
+	listNode.TransitionEdges['"'] = stringNode
+	// squash states to a value
+	for _, r := range validNumberRunes {
+		listNode.TransitionEdges[r] = numberNode
+	}
+	for _, r := range validBoolRunes {
+		listNode.TransitionEdges[r] = boolNode
+	}
+	for _, r := range validNullRunes {
+		listNode.TransitionEdges[r] = nullNode
+	}
+
+	// null node
+	for _, r := range validNullRunes {
+		nullNode.TransitionEdges[r] = nullNode
+	}
+	nullNode.TransitionEdges[','] = commaNode
+	nullNode.TransitionEdges['}'] = objEndNode
+	nullNode.TransitionEdges[']'] = listEndNode
+
+	// list comma
+	// should point to values
+	listCommaNode.TransitionEdges['"'] = stringNode
+	listCommaNode.TransitionEdges[' '] = listCommaNode
+	listCommaNode.TransitionEdges['{'] = objNode
+	listCommaNode.TransitionEdges['\n'] = newlineNode
+
+	for _, r := range validNumberRunes {
+		listCommaNode.TransitionEdges[r] = numberNode
+	}
+	for _, r := range validBoolRunes {
+		listCommaNode.TransitionEdges[r] = boolNode
+	}
+	for _, r := range validNullRunes {
+		listCommaNode.TransitionEdges[r] = nullNode
+	}
+
+	// bool node
+	for _, r := range validBoolRunes {
+		boolNode.TransitionEdges[r] = boolNode
+	}
+	boolNode.TransitionEdges['}'] = objEndNode
+	boolNode.TransitionEdges[']'] = listEndNode
+	boolNode.TransitionEdges[','] = commaNode
+
+	listEndNode.TransitionEdges['}'] = objEndNode
+	listEndNode.TransitionEdges[','] = commaNode
+
+	commaNode.TransitionEdges['{'] = objNode
+	commaNode.TransitionEdges['\n'] = newlineNode
+	commaNode.TransitionEdges['\t'] = tabNode
+	commaNode.TransitionEdges['"'] = objKeyNode
+	commaNode.TransitionEdges[' '] = spaceObjNode
+
+	spaceObjNode.TransitionEdges['"'] = objKeyNode
+
+	return startNode, stateToNodeMap, nil
+}
+
+func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
+
+	vocab := proc.GetVocabulary()
+
+	decodedToks := make([]string, len(vocab.Values))
+	for i := range vocab.Values {
+		token, err := proc.Decode([]int32{int32(i)})
+		if err != nil {
+			return err
+		}
+		decodedToks[i] = token
+	}
+
+	var err error
+	for _, node := range stateToNodeMap {
+		for i := range vocab.Values {
+			token := decodedToks[i]
+			// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
+			if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
+				continue
+			}
+			valid := true
+			curNode := node
+			consumedSpecialRunes := make(map[rune]bool)
+			for _, r := range token {
+				valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
+				if err != nil {
+					return err
+				}
+				if !valid {
+					break
+				}
+			}
+			if valid {
+				node.MaskTokenIDToNode[int32(i)] = curNode.State
+			}
+		}
+	}
+	return nil
+}
+
+func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
+	if consumedSpecialRunes[r] {
+		return false, nil, nil
+	}
+
+	specialRune := slices.Contains(stringInvalidRunes, r)
+	if specialRune {
+		if curNode.State == StateInString || curNode.State == StateInObjectKey {
+			return false, nil, nil
+		}
+	}
+
+	// Check for specific rune transition
+	if nextNode, ok := curNode.TransitionEdges[r]; ok {
+		if specialRune {
+			if curNode.State == nextNode.State {
+				return false, nil, nil
+			}
+			// fmt.Println("special rune", r, "consumed")
+			consumedSpecialRunes[r] = true
+		}
+		return true, nextNode, nil
+	}
+
+	// Check for sentinel value - if present, any rune is valid
+	if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
+		return true, nextNode, nil
+	}
+
+	return false, nil, nil
+}

+ 0 - 104
sample/json_sampler.go

@@ -1,104 +0,0 @@
-package sample
-
-import (
-	"fmt"
-	"math"
-
-	"github.com/ollama/ollama/model"
-)
-
-type JSONState int
-
-const (
-	StateStart      JSONState = iota // Initial state
-	StateInObject                    // Inside an object {}
-	StateInArray                     // Inside an array []
-	StateInString                    // Inside a string ""
-	StateAfterKey                    // After object key, expecting :
-	StateAfterColon                  // After :, expecting value
-	StateAfterValue                  // After value, expecting , or closing bracket
-	StateDone                        // JSON parsing complete
-)
-
-type JSONSampler struct {
-	state JSONState
-	stack []string
-	proc  model.TextProcessor
-}
-
-func NewJSONSampler(proc model.TextProcessor) *JSONSampler {
-	return &JSONSampler{
-		state: StateStart,
-		proc:  proc,
-	}
-}
-
-func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
-	// Pre-decode valid tokens for current state
-	validTokens := make(map[uint32]bool)
-
-	// Always allow EOS token in any state
-	// TODO: Check for other special tokens if needed
-	for i := range logits {
-		if s.proc.Is(uint32(i), model.SpecialEOS) {
-			validTokens[uint32(i)] = true
-		}
-	}
-
-	// Build set of valid tokens based on current state
-	switch s.state {
-	case StateStart:
-		// Only allow opening brace
-		for i := range logits {
-			text, err := s.proc.Decode([]int32{int32(i)})
-			if err == nil && text == "{" {
-				validTokens[uint32(i)] = true
-			}
-		}
-	case StateInObject, StateInArray:
-		// Allow any token
-		for i := range logits {
-			validTokens[uint32(i)] = true
-		}
-	case StateInString:
-		// Allow any token except closing brace
-		for i := range logits {
-			text, err := s.proc.Decode([]int32{int32(i)})
-			if err == nil && text != "}" {
-				validTokens[uint32(i)] = true
-			}
-		}
-	case StateDone:
-		// No tokens allowed
-	}
-
-	// Mark invalid tokens as NaN in one pass
-	for i := range logits {
-		if !validTokens[uint32(i)] {
-			logits[i] = math.NaN()
-		}
-	}
-	return logits, nil
-}
-
-func (s *JSONSampler) UpdateState(tokenID int) error {
-	text, err := s.proc.Decode([]int32{int32(tokenID)})
-	if err != nil {
-		return fmt.Errorf("failed to decode token: %w", err)
-	}
-
-	switch s.state {
-	case StateStart:
-		if text != "{" {
-			return fmt.Errorf("expected {, got %s", text)
-		}
-		s.state = StateInObject
-	case StateInObject:
-		if text == "}" {
-			s.state = StateDone
-		}
-	case StateDone:
-		return fmt.Errorf("unexpected token after closing bracket: %s", text)
-	}
-	return nil
-}

+ 4 - 14
sample/sample.go

@@ -165,9 +165,10 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
 	if len(logitsCopy) == 0 {
 		return nil, errors.New("no valid tokens found")
 	}
-
-	// usually, a softmax is applied to sample from the logits
-	// in this case the uv sampler normalizes the logits so that the sum of the weights is 1
+	logitsCopy, err := computeSoftmax(logitsCopy)
+	if err != nil {
+		return nil, err
+	}
 	w := sampleuv.NewWeighted(logitsCopy, nil)
 	if v, ok := w.Take(); ok {
 		// returns the token ID
@@ -176,17 +177,6 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
 	return nil, errors.New("weighed sampler failed")
 }
 
-// TODO: remove after next PR merge
-type greedy struct{}
-
-func Greedy() Sampler {
-	return greedy{}
-}
-
-func (greedy) Sample(logits []float64) ([]float64, error) {
-	return []float64{float64(floats.MaxIdx(logits))}, nil
-}
-
 func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
 	var err error
 	for _, sampler := range samplers {

+ 0 - 5
sample/sample_test.go

@@ -3,14 +3,9 @@ package sample
 import (
 	"fmt"
 	"math"
-	"math/rand"
-	"os"
-	"runtime"
 	"slices"
 	"testing"
 
-	"runtime/trace"
-
 	"gonum.org/v1/gonum/floats"
 )
 

+ 0 - 218
sample/state_machine.go

@@ -1,218 +0,0 @@
-package sample
-
-import (
-	"fmt"
-
-	"github.com/ollama/ollama/model"
-)
-
-type token []int32
-
-type Node struct {
-	State           JSONState
-	TransitionEdges map[*Node][]token
-}
-
-func NewNode(state JSONState) *Node {
-	return &Node{
-		State:           state,
-		TransitionEdges: make(map[*Node][]token),
-	}
-}
-
-var (
-	// startToken             token
-	startTokenVariants []token
-	// endToken               token
-	// stringToken            token
-	// objectKeyToken         token
-	tabToken     token
-	spaceToken   token
-	newlineToken token
-	newlineSpace token
-	// commaToken             token
-	// commaToken2            token
-	// commaToken3            token
-	// colonToken             token
-	// colonToken2            token
-	colonTokenVariants           []token
-	commaTokenVariants           []token
-	stringTokenVariants          []token
-	endTokenVariants             []token
-	objectKeyTokenVariants       []token
-	objKeyToColonVariants        []token
-	stringToObjectKeyVariants    []token
-	stringToCommaVariants        []token
-	stringToObjectVariants       []token
-	stringEndToObjectEndVariants []token
-	stringEndToCommaVariants     []token
-)
-
-func ComputeTokenVariants(variants []string, proc model.TextProcessor) ([]token, error) {
-	var allTokens token
-	for _, variant := range variants {
-		if t, err := proc.Encode(variant); err == nil {
-			allTokens = append(allTokens, t...)
-		}
-	}
-	if len(allTokens) == 0 {
-		return nil, fmt.Errorf("no valid tokens found for variants")
-	}
-	return []token{allTokens}, nil
-}
-func initTokens(proc model.TextProcessor) error {
-	var err error
-
-	s, err := proc.Decode([]int32{761})
-	fmt.Printf("761 decoded %q\n", s)
-
-	// Compute start token variants
-	startVariants := []string{"{", " {", "{\n", " {\n"}
-	startTokenVariants, err = ComputeTokenVariants(startVariants, proc)
-	if err != nil {
-		return err
-	}
-	// Compute end token variants
-	endVariants := []string{"}", " }", "}\n", " }\n"}
-	endTokenVariants, err = ComputeTokenVariants(endVariants, proc)
-	if err != nil {
-		return err
-	}
-
-	// Compute string token variants
-	// TODO: removed \n
-	stringVariants := []string{"\"", " \""}
-	stringTokenVariants, err = ComputeTokenVariants(stringVariants, proc)
-	if err != nil {
-		return err
-	}
-	stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\",\n"}, proc)
-	if err != nil {
-		return err
-	}
-	// objectKeyTokenVariants = []token{stringTokenVariants[0], stringTokenVariants[1]}
-	objectKeyTokenVariants = stringTokenVariants
-	// Compute whitespace tokens
-	tabToken, err = proc.Encode("\t")
-	if err != nil {
-		return err
-	}
-	spaceToken, err = proc.Encode(" ")
-	if err != nil {
-		return err
-	}
-	newlineToken, err = proc.Encode("\n")
-	if err != nil {
-		return err
-	}
-	newlineSpace, err = proc.Encode(" \n")
-	if err != nil {
-		return err
-	}
-
-	// Compute colon variants
-	colonVariants := []string{":"}
-	colonTokenVariants, err = ComputeTokenVariants(colonVariants, proc)
-	if err != nil {
-		return err
-	}
-	objKeyToColonVariants, err = ComputeTokenVariants([]string{"\":"}, proc)
-	if err != nil {
-		return err
-	}
-
-	// Compute comma variants
-	commaVariants := []string{",", " ,", ",\n", "\",", "\", "}
-	commaTokenVariants, err = ComputeTokenVariants(commaVariants, proc)
-	if err != nil {
-		return err
-	}
-	fmt.Printf("commaTokenVariants: %v\n", commaTokenVariants)
-	stringToCommaVariants, err = ComputeTokenVariants([]string{"\",", "\","}, proc)
-	if err != nil {
-		return err
-	}
-
-	stringEndToCommaVariants, err = ComputeTokenVariants([]string{",", ",\n"}, proc)
-	stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\","}, proc)
-	stringToObjectVariants, err = ComputeTokenVariants([]string{"\",\n"}, proc)
-	stringEndToObjectEndVariants, err = ComputeTokenVariants([]string{"\n"}, proc)
-
-	return nil
-}
-
-func buildStateMachine(proc model.TextProcessor) (*Node, error) {
-	if err := initTokens(proc); err != nil {
-		return nil, err
-	}
-
-	startNode := NewNode(StateStart)
-	objectNode := NewNode(StateInObject)
-	objectKeyNode := NewNode(StateInObjectKey)
-	objectKeyEndNode := NewNode(StateInObjectKeyEnd)
-	stringNode := NewNode(StateInString)
-	// intNode := NewNode(StateInInt)
-	commaNode := NewNode(StateInComma)
-	colonNode := NewNode(StateInColon)
-	stringEndNode := NewNode(StateInStringEnd)
-	endNode := NewNode(StateEnd)
-	terminateNode := NewNode(StateTerminate)
-
-	sentinelToken := token([]int32{-1})
-	// intSentinelToken := token([]int32{-2})
-
-	// TODO: cleanup connections of rules
-	startNode.TransitionEdges[objectNode] = startTokenVariants
-
-	objectNode.TransitionEdges[objectKeyNode] = stringTokenVariants
-	objectNode.TransitionEdges[objectNode] = []token{newlineToken}
-	objectNode.TransitionEdges[objectNode] = []token{spaceToken}
-
-	// objectNode.TransitionEdges[objectNode] = []token{newlineToken}
-	// objectNode.TransitionEdges[objectNode] = []token{spaceToken}
-
-	objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
-	// characterize end of object key
-	objectKeyNode.TransitionEdges[objectKeyEndNode] = stringTokenVariants
-	objectKeyNode.TransitionEdges[colonNode] = objKeyToColonVariants
-
-	// TODO: enable this - key -> object
-	// objectKeyNode.TransitionEdges[objectNode] = startTokenVariants
-
-	// objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}
-
-	// intNode.TransitionEdges[intNode] = []token{intSentinelToken}
-	// intNode.TransitionEdges[commaNode] = commaTokenVariants
-	// TODO: handle
-	// intNode.TransitionEdges[terminateNode] = endTokenVariants
-
-	commaNode.TransitionEdges[objectKeyNode] = stringTokenVariants
-	// commaNode.TransitionEdges[objectNode] = startTokenVariants
-
-	colonNode.TransitionEdges[stringNode] = stringTokenVariants
-	//TODO: enable
-	// colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
-	colonNode.TransitionEdges[objectNode] = startTokenVariants
-
-	stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
-	stringNode.TransitionEdges[stringEndNode] = stringTokenVariants
-	// TODO: "\""," Case not accounted for
-	stringNode.TransitionEdges[commaNode] = stringToCommaVariants
-
-	// TODO: "\"",\"" Case not accounted for
-	stringNode.TransitionEdges[objectNode] = stringToObjectVariants
-
-	stringEndNode.TransitionEdges[commaNode] = stringEndToCommaVariants
-	stringEndNode.TransitionEdges[objectNode] = stringToObjectKeyVariants
-	stringEndNode.TransitionEdges[endNode] = stringEndToObjectEndVariants
-	// stringEndNode.TransitionEdges[terminateNode] = endTokenVariants
-
-	// Should be obj end
-	// TODO: handle
-	endNode.TransitionEdges[terminateNode] = []token{}
-
-	endNode.TransitionEdges[commaNode] = commaTokenVariants
-
-	terminateNode.TransitionEdges[terminateNode] = []token{}
-	return startNode, nil
-}

+ 0 - 86
sample/structured_outputs.go

@@ -7,92 +7,6 @@ type StructuredOutput struct {
 }
 
 func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode {
-	// _, stateToNodeMap, err := BuildGraph(proc)
-	// if err != nil {
-	// 	panic(err)
-	// }
 
 	return nil
 }
-
-// func constrainGraph(graph *PDANode, schema *Schema) *PDANode {
-// 	// If no schema constraints, return original graph node
-// 	if schema == nil {
-// 		return graph
-// 	}
-
-// 	// Create a new node with same state
-// 	constrainedNode := NewPDANode(graph.State)
-
-// 	// Copy over existing transitions and masks
-// 	constrainedNode.TransitionEdges = make(map[rune]*PDANode)
-// 	for r, node := range graph.TransitionEdges {
-// 		constrainedNode.TransitionEdges[r] = node
-// 	}
-// 	constrainedNode.MaskTokenIDToNode = graph.MaskTokenIDToNode
-
-// 	// Apply schema constraints based on type
-// 	switch schema.EffectiveType() {
-// 	case "object":
-// 		// Only allow defined property names in object keys
-// 		if graph.State == StateInObjectKey {
-// 			// TODO: Add property name validation
-// 		}
-
-// 		// Constrain property values based on schema
-// 		if graph.State == StateInColon || graph.State == StateInSpace {
-// 			// Clear transitions to only allow valid types
-// 			constrainedNode.TransitionEdges = make(map[rune]*PDANode)
-
-// 			// Add transitions based on property schemas
-// 			for _, prop := range schema.Properties {
-// 				switch prop.EffectiveType() {
-// 				case "object":
-// 					if objNode, ok := graph.TransitionEdges['{']; ok {
-// 						constrainedNode.TransitionEdges['{'] = constrainGraph(objNode, prop)
-// 					}
-// 				case "array":
-// 					if arrNode, ok := graph.TransitionEdges['[']; ok {
-// 						constrainedNode.TransitionEdges['['] = constrainGraph(arrNode, prop)
-// 					}
-// 				case "string":
-// 					if strNode, ok := graph.TransitionEdges['"']; ok {
-// 						constrainedNode.TransitionEdges['"'] = constrainGraph(strNode, prop)
-// 					}
-// 				case "number":
-// 					for _, r := range validNumberRunes {
-// 						if numNode, ok := graph.TransitionEdges[r]; ok {
-// 							constrainedNode.TransitionEdges[r] = constrainGraph(numNode, prop)
-// 						}
-// 					}
-// 				case "integer":
-// 					for _, r := range validIntRunes {
-// 						if intNode, ok := graph.TransitionEdges[r]; ok {
-// 							constrainedNode.TransitionEdges[r] = constrainGraph(intNode, prop)
-// 						}
-// 					}
-// 				case "boolean":
-// 					for _, r := range []rune{'t', 'f'} {
-// 						if boolNode, ok := graph.TransitionEdges[r]; ok {
-// 							constrainedNode.TransitionEdges[r] = constrainGraph(boolNode, prop)
-// 						}
-// 					}
-// 				case "null":
-// 					if nullNode, ok := graph.TransitionEdges['n']; ok {
-// 						constrainedNode.TransitionEdges['n'] = constrainGraph(nullNode, prop)
-// 					}
-// 				}
-// 			}
-// 		}
-
-// 	case "array":
-// 		// Constrain array items based on schema
-// 		if schema.Items != nil {
-// 			for r, node := range graph.TransitionEdges {
-// 				constrainedNode.TransitionEdges[r] = constrainGraph(node, schema.Items)
-// 			}
-// 		}
-// 	}
-
-// 	return constrainedNode
-// }

BIN
sample/trace.out