ParthSareen 2 months ago
parent
commit
a4265c278a

+ 7 - 0
llama/runner/runner.go

@@ -443,6 +443,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		s.lc.Synchronize()
 	}
 
+	var totalSamplingTime time.Duration
 	for i, seq := range s.seqs {
 		if seq == nil {
 			continue
@@ -477,8 +478,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		}
 
 		// sample a token
+		samplingStart := time.Now()
 		token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
 		seq.samplingCtx.Accept(token, true)
+		samplingTime := time.Since(samplingStart)
+		totalSamplingTime += samplingTime
+		slog.Info("sampling time", "time", samplingTime)
 		piece := s.model.TokenToPiece(token)
 
 		seq.numPredicted++
@@ -635,6 +640,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
+	start := time.Now()
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict:     req.NumPredict,
 		stop:           req.Stop,
@@ -642,6 +648,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		samplingParams: &samplingParams,
 		embedding:      false,
 	})
+	slog.Info("new sequence created", "duration", time.Since(start))
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 		return

+ 23 - 32
model/cmd/main.go

@@ -28,7 +28,7 @@ var args struct {
 }
 
 func temp() error {
-	start := time.Now()
+	// start := time.Now()
 	flag.IntVar(&args.n, "n", 10, "number of samples")
 	flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
 	flag.StringVar(&args.image, "image", "", "path to image file")
@@ -106,10 +106,12 @@ func temp() error {
 		}
 	}
 
-	// simple schema
-	// This schema maps to JSON like:
+	// Schema for a list of friends with their info
+	// Maps to JSON like:
 	// {
-	//   "name": "some string value"
+	// 	"name": "string",
+	// 	"age": integer,
+	// 	"is_available": boolean
 	// }
 	schema := &sample.Schema{
 		Name: "root",
@@ -117,20 +119,24 @@ func temp() error {
 		Properties: []*sample.Schema{
 			{Name: "name", Type: "string"},
 			{Name: "age", Type: "integer"},
-			{Name: "is_student", Type: "boolean"},
-			// {Name: "is_student", Type: "boolean"},
+			{Name: "is_available", Type: "boolean"},
 		},
 	}
 
-	// pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
-	pushdownSampler, err := sample.NewSOSampler(schema, m.(model.TextProcessor))
+	// fmt.Println("schema", schema)
+	// schema = nil
+	jsonTransform, err := sample.NewJSONSampler(m.(model.TextProcessor), schema)
 	if err != nil {
 		return err
 	}
 
+	transforms := []sample.Transform{
+		jsonTransform,
+	}
+
 	var offset int
 	var stringBuffer string
-	var ttft time.Duration
+	// var ttft time.Duration
 	var totalSamplingTime time.Duration
 	count := 0
 	for range args.n {
@@ -139,24 +145,9 @@ func temp() error {
 			return err
 		}
 
-		// f64s := make([]float64, len(f32s))
-		// for i, f32 := range f32s {
-		// 	f64s[i] = float64(f32)
-		// }
-		// samplers := []sample.Transform{
-		// pushdownSampler,
-		// sample.Weighed(),
-		// sample.TopP(0.9),
-		// sample.Weighed(),
-		// sample.Greedy(),
-		// }
-		transforms := []sample.Transform{
-			pushdownSampler,
-		}
-
 		samplingStart := time.Now()
-		sampler := sample.NewSampler(transforms, sample.Greedy())
-		sampledIdx, err := sampler.Sample(logits.Floats())
+		sampler := sample.Greedy()
+		sampledIdx, err := sampler.Sample(logits.Floats(), transforms...)
 		if err != nil {
 			return err
 		}
@@ -164,7 +155,7 @@ func temp() error {
 		samplingTime := time.Since(samplingStart)
 		totalSamplingTime += samplingTime
 
-		fmt.Println("sampling time", samplingTime)
+		// fmt.Println("sampling time", samplingTime)
 		// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
 
 		var outputIDs []int32
@@ -184,10 +175,10 @@ func temp() error {
 			return err
 		}
 
-		if ttft == 0 {
-			ttft = time.Since(start)
-			fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds())
-		}
+		// if ttft == 0 {
+		// 	ttft = time.Since(start)
+		// fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds())
+		// }
 
 		// fmt.Printf("--- token: %q\n", s)
 		// fmt.Printf("--- outputIDs: %v\n", outputIDs)
@@ -195,7 +186,7 @@ func temp() error {
 		count++
 		fmt.Println("--- stringBuffer", stringBuffer)
 
-		err = pushdownSampler.UpdateState(outputIDs)
+		outputIDs, err = jsonTransform.UpdateState(outputIDs)
 		if err != nil {
 			return err
 		}

+ 0 - 49
sample/constrained.go

@@ -1,49 +0,0 @@
-package sample
-
-import (
-	"github.com/ollama/ollama/model"
-)
-
-type ConstrainedSampler struct {
-	schema        *Schema
-	propIdx       int
-	propToNodeMap map[string]*PDA
-	pdaSampler    *PushdownSampler
-	decodedToks   []string
-}
-
-func NewConstrainedSampler(proc model.TextProcessor, schema *Schema) (*ConstrainedSampler, error) {
-	pdaSampler, err := NewPushdownSampler(proc)
-	if err != nil {
-		return nil, err
-	}
-
-	// if schema == nil {
-	return &ConstrainedSampler{
-		schema:        nil,
-		propIdx:       -1,
-		propToNodeMap: nil,
-		pdaSampler:    pdaSampler,
-	}, nil
-
-}
-
-func (s *ConstrainedSampler) Apply(logits []float64) ([]float64, error) {
-	if s.schema == nil {
-		return s.pdaSampler.Apply(logits)
-	}
-
-	return nil, nil
-}
-
-func (s *ConstrainedSampler) UpdateState(tokenSlice []int32) error {
-	if err := s.pdaSampler.UpdateState(tokenSlice); err != nil {
-		return err
-	}
-
-	if s.schema == nil {
-		return nil
-	}
-
-	return nil
-}

+ 20 - 7
sample/json_types.go

@@ -41,6 +41,7 @@ const (
 	StateTerminate
 	StateInObjectEnd
 	StateTransitioningToTerminate
+	StateInListStartJSON
 )
 
 var JSONStates = []JSONState{
@@ -48,6 +49,7 @@ var JSONStates = []JSONState{
 	StateInObject,
 	StateInObjectKey,
 	StateInStructuredKey,
+	StateInStructuredValue,
 	StateNewline,
 	StateTab,
 	StateSpace,
@@ -63,6 +65,7 @@ var JSONStates = []JSONState{
 	StateInSpaceEndValue,
 	StateInNewlineEndValue,
 	StateInObjSpace,
+	StateInListStartJSON,
 	StateInList,
 	StateInListComma,
 	StateInValue,
@@ -89,6 +92,8 @@ func (s JSONState) String() string {
 		return "StateInObjectKey"
 	case StateInStructuredKey:
 		return "StateInStructuredKey"
+	case StateInStructuredValue:
+		return "StateInStructuredValue"
 	case StateNewline:
 		return "StateNewline"
 	case StateTab:
@@ -112,21 +117,27 @@ func (s JSONState) String() string {
 	case StateInTab:
 		return "StateInTab"
 	case StateInSpaceToValue:
-		return "StateInSpace"
+		return "StateInSpaceToValue"
+	case StateInSpaceEndValue:
+		return "StateInSpaceEndValue"
+	case StateInNewlineEndValue:
+		return "StateInNewlineEndValue"
 	case StateInObjSpace:
 		return "StateInObjSpace"
 	case StateInList:
 		return "StateInList"
-	case StateInListObjectEnd:
-		return "StateInListObjectEnd"
 	case StateInListComma:
 		return "StateInListComma"
+	case StateInValue:
+		return "StateInValue"
+	case StateInValueEnd:
+		return "StateInValueEnd"
 	case StateInListEnd:
 		return "StateInListEnd"
+	case StateInListObjectEnd:
+		return "StateInListObjectEnd"
 	case StateInNewline:
 		return "StateInNewline"
-	case StateInNewlineEndValue:
-		return "StateInNewlineEndValue"
 	case StateInNumber:
 		return "StateInNumber"
 	case StateInNumberEnd:
@@ -135,12 +146,14 @@ func (s JSONState) String() string {
 		return "StateInStringEnd"
 	case StateInObjectKeyEnd:
 		return "StateInObjectKeyEnd"
-	case StateInSpaceEndValue:
-		return "StateInSpaceEndValue"
 	case StateTerminate:
 		return "StateTerminate"
 	case StateInObjectEnd:
 		return "StateInObjectEnd"
+	case StateTransitioningToTerminate:
+		return "StateTransitioningToTerminate"
+	case StateInListStartJSON:
+		return "StateInListStartJSON"
 	default:
 		return fmt.Sprintf("Unknown state: %d", s)
 	}

+ 51 - 12
sample/pushdown_automata.go

@@ -37,8 +37,10 @@ Key JSON rules to consider:
 // TODO: / should be valid but an escape character
 var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
 
-var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
-var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
+var (
+	intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
+	validIntRunes   = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
+)
 
 var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
 
@@ -61,9 +63,10 @@ func NewPDANode(state JSONState) *PDA {
 }
 
 type PDAGraphBuilder struct {
-	proc           model.TextProcessor
-	decodedToks    []string
-	stateToNodeMap map[JSONState]*PDA
+	proc             model.TextProcessor
+	decodedToks      []string
+	stateToNodeMap   map[JSONState]*PDA
+	tokenToStatesMap map[int32][]JSONState
 }
 
 func (b *PDAGraphBuilder) BuildGraph() error {
@@ -73,20 +76,26 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	}
 
 	stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
-	stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
+	stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
+
+	// TODO: update naming here - and revisit values
+	stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
+	stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
 
 	stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
 	stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
 	stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
+	stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
 
 	// new line
 	stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
 	stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
 	stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
 	stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
+	// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 
 	// new line end value
-	stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	// stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
 	stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
 
@@ -108,6 +117,8 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	// where values should be
 	// this could be combined but the probl might change, we're alr doing a skip ahead
 	stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
+	stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
+
 	stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
@@ -117,6 +128,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
 	stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
 
 	// Values
 	// string node
@@ -125,7 +137,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 
 	// String end node
 	addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
-	stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	// stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
 	// TODO: add counters for allowable number of decimals, e, E, etc
@@ -134,7 +146,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 		stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
 	}
 	addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
-	stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	// stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
 	// list node
@@ -142,10 +154,12 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
+	// early end
+	stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
 
 	// list end node
 	stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
-	stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	// stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
 	stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
@@ -166,6 +180,9 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
 	stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
+	stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList]
+	stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
+
 	addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
 
 	// list object end
@@ -180,7 +197,7 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	}
 	stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
 	addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
-	stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	// stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
 
 	// comma node
@@ -188,9 +205,11 @@ func (b *PDAGraphBuilder) BuildGraph() error {
 	stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
 	stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
 	stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
+	// TODO: review this space transition
+	// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
 
 	// space end value
-	stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	// stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
 	stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
 	stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
 	stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
@@ -221,6 +240,12 @@ func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
 
 func (b *PDAGraphBuilder) preComputeValidStates() error {
 	for _, node := range b.stateToNodeMap {
+		// if node.State == StateInObjectKey {
+		// 	if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
+		// 		b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
+		// 		fmt.Println("copying string mask to object key mask")
+		// 	}
+		// }
 		if err := b.CreateMask(node); err != nil {
 			return err
 		}
@@ -228,6 +253,20 @@ func (b *PDAGraphBuilder) preComputeValidStates() error {
 	return nil
 }
 
+func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
+	// Rebuilds b.tokenToStatesMap from b.decodedToks, keyed by each token's index in that slice. TODO: the make call could live somewhere else too.
+	b.tokenToStatesMap = make(map[int32][]JSONState)
+	for i, t := range b.decodedToks {
+		for _, r := range t { // ranging over a string yields its runes
+			if r == '"' {
+				b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString) // appended once per '"' rune, so a token with several quotes gets repeated StateInString entries
+			}
+		}
+	}
+	return nil // no failure path: the returned error is always nil
+}
+
+// TODO: the mask for obj key and string should be the same?
 func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
 	if node == nil {
 		return fmt.Errorf("node cannot be nil")

+ 25 - 10
sample/pushdown_runner.go

@@ -10,9 +10,13 @@ import (
 )
 
 // TODO: safety in case of invalid json
+// TODO: partial JSON matching?
 // TODO: interfaces to cleanup with return values
 // TODO this interface shouldn't be the sampler - should just use Sampler
 // TODO: add penalties for string \n stuff
+// TODO: minimize number of fwd passes if there is only one match
+// TODO: greedy sample initially and then backtrack if no match
+
 type PushdownSampler struct {
 	PDAGraphBuilder
 	curNode      *PDA
@@ -140,16 +144,24 @@ func forceFinish(s *PushdownSampler, logits []float64) ([]float64, error) {
 	return logits, nil
 }
 
-func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
+func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
 	fmt.Println("current state - updating", s.curNode.State)
 	mappedString, err := s.proc.Decode(tokenSlice)
 	if err != nil {
-		return err
+		return nil, err
 	}
 	fmt.Printf(">>> mappedString: %q\n", mappedString)
 
-	// TODO: should force closing for all braces - not doing square yet
+	// flag := -1
+	// endBraceRunes := []rune{'}', ']'}
 	for _, r := range mappedString {
+		// TODO: if this is enabled again, make sure to appropriately handle the state transitions
+		// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
+		// 	fmt.Printf("stack is empty, extra closing brace %c\n", r)
+		// 	// flag = i
+		// 	break
+
+		// }
 		if r == rune('{') {
 			s.braceStack = append(s.braceStack, r)
 		}
@@ -158,32 +170,36 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 		}
 		if r == rune('}') {
 			if len(s.braceStack) == 0 {
-				return fmt.Errorf("stack is empty, extra closing brace %c", r)
+				return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
 			}
 			top := s.braceStack[len(s.braceStack)-1]
 			if top != rune('{') {
-				return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
+				return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
 			}
 			s.braceStack = s.braceStack[:len(s.braceStack)-1]
 		}
 
 		if r == rune(']') {
 			if len(s.braceStack) == 0 {
-				return fmt.Errorf("stack is empty, extra closing brace %c", r)
+				return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
 			}
 			top := s.braceStack[len(s.braceStack)-1]
 			if top != rune('[') {
-				return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
+				return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
 			}
 			s.braceStack = s.braceStack[:len(s.braceStack)-1]
 		}
 	}
 
+	// if flag != -1 {
+	// 	tokenSlice = tokenSlice[:flag]
+	// }
+	// fmt.Println("flag!", flag)
 	for _, tokenID := range tokenSlice {
 		// transition to the next node
 		nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
 		if !ok {
-			return fmt.Errorf("invalid token: %q", mappedString)
+			return nil, fmt.Errorf("invalid token: %q", mappedString)
 		}
 		fmt.Println("transitioning to", nextNode.State)
 
@@ -196,12 +212,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 		s.curNode = nextNode
 		fmt.Println("updated curNode state", s.curNode.State)
 	}
-	return nil
+	return tokenSlice, nil
 }
 
 // greedy sample + backtrack?
 func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) {
-
 	// Create a new slice with same length as logits, initialized to -Inf
 	maskedLogits := make([]float64, len(logits))
 	for i := range maskedLogits {

+ 89 - 23
sample/structured_outputs.go

@@ -35,7 +35,7 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err
 		}, nil
 	}
 
-	fmt.Println("schema not nil")
+	// fmt.Println("schema not nil")
 	so := &JSONSampler{
 		schema:        schema,
 		propIdx:       -1,
@@ -87,7 +87,7 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err
 		for curState == StateInStructuredKey {
 			// there is only one edge
 			for r, toNode := range fromNode.TransitionEdges {
-				// fmt.Println("rune", r, "edge", toNode.State)
+				fmt.Println("rune", r, "edge", toNode.State)
 				so.pdaSampler.CreateMask(toNode)
 				fmt.Printf("created mask for %c\n", r)
 				curState = toNode.State
@@ -96,13 +96,27 @@ func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, err
 				fromNode = toNode
 			}
 		}
+
+		if curState != StateInColon {
+			return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
+		}
+
+		// so.pdaSampler.CreateMask(fromNode)
+
+		fromNode = fromNode.TransitionEdges[' ']
+
+		so.pdaSampler.CreateMask(fromNode)
+		curState = fromNode.State
+		for _, toNode := range fromNode.TransitionEdges {
+			fmt.Println("toNode", toNode.State)
+		}
 	}
 
-	runtime.ReadMemStats(&m)
-	after = m.Alloc
-	fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
-	fmt.Printf("Mask creation time = %v\n", time.Since(start))
-	fmt.Println("--------------------------------")
+	// runtime.ReadMemStats(&m)
+	// after = m.Alloc
+	// fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
+	// fmt.Printf("Mask creation time = %v\n", time.Since(start))
+	// fmt.Println("--------------------------------")
 
 	return so, nil
 }
@@ -130,14 +144,66 @@ func (s *JSONSampler) schemaToGraph() {
 					TransitionEdges:   make(map[rune]*PDA),
 					MaskTokenIDToNode: make(map[int32]*PDA),
 				}
-				fmt.Println("runeNode created", runeNode.State)
-				fmt.Printf("runeNode created %c\n", r)
+				// fmt.Println("runeNode created", runeNode.State)
+				// fmt.Printf("runeNode created %c\n", r)
+
 				// since alloc is on the heap, connections will still map
 				prevNode.TransitionEdges[r] = runeNode
 				prevNode = runeNode
 			}
+
 			// point to end of object key node after all chars are done
-			prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
+			// prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
+
+			// link to value node
+			// Create a node for the end of the key (after the closing quote)
+			stringEndNode := &PDA{
+				State:             StateInStructuredKey,
+				TransitionEdges:   make(map[rune]*PDA),
+				MaskTokenIDToNode: make(map[int32]*PDA),
+			}
+			prevNode.TransitionEdges['"'] = stringEndNode
+			prevNode = stringEndNode
+
+			// Add transition for colon after key
+			colonNode := &PDA{
+				State:             StateInColon,
+				TransitionEdges:   make(map[rune]*PDA),
+				MaskTokenIDToNode: make(map[int32]*PDA),
+			}
+			prevNode.TransitionEdges[':'] = colonNode
+			prevNode = colonNode
+
+			// Add transition for space after colon
+			spaceNode := &PDA{
+				State:             StateInSpaceToValue,
+				TransitionEdges:   make(map[rune]*PDA),
+				MaskTokenIDToNode: make(map[int32]*PDA),
+			}
+			prevNode.TransitionEdges[' '] = spaceNode
+			prevNode = spaceNode
+
+			value := prop.Type
+			switch value {
+			case "object":
+				fmt.Println("object under key: ", name)
+			case "array":
+				fmt.Println("array under key: ", name)
+			case "string":
+				fmt.Println("string under key: ", name)
+				prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
+			case "number":
+				fmt.Println("number under key: ", name)
+				for _, r := range validNumberRunes {
+					prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
+				}
+			case "boolean":
+				fmt.Println("boolean under key: ", name)
+				prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
+				prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
+				prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
+			}
+
 			// points to start of the key
 			s.propToNodeMap[name] = keyNode
 			fmt.Println("name", name, "keyNode", keyNode.State)
@@ -152,7 +218,7 @@ func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
 	}
 
 	switch s.pdaSampler.curNode.State {
-	// doesnt account for multi rune case
+	// TODO: doesn't account for the multi-rune case
 	case StateInObjectKey:
 		if s.propIdx > len(s.schema.Properties)-1 {
 			return nil, fmt.Errorf("propIdx out of bounds")
@@ -196,18 +262,17 @@ func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
 		}
 		return s.pdaSampler.Apply(logits)
 	}
-
 }
 
-func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
-	err := s.pdaSampler.UpdateState(tokenSlice)
+func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
+	tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	if s.schema == nil {
 		// Don't need to update state for unconstrained JSON sampling
-		return nil
+		return tokenSlice, nil
 	}
 
 	switch s.pdaSampler.curNode.State {
@@ -217,14 +282,15 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
 		prop := s.schema.Properties[s.propIdx]
 		fmt.Println("prop", prop.Name)
 		s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
-		str, err := s.pdaSampler.proc.Decode(tokenSlice)
-		if err != nil {
-			return err
-		}
-		fmt.Println("str", str)
+		// TODO: this does not work - mike
+		// str, err := s.pdaSampler.proc.Decode(tokenSlice)
+		// if err != nil {
+		// 	return nil, err
+		// }
+		// fmt.Println("str", str)
 
-		return nil
+		return tokenSlice, nil
 	default:
-		return nil
+		return tokenSlice, nil
 	}
 }