ParthSareen 1 månad sedan
förälder
incheckning
5ec6bb52a0

+ 1 - 0
model/process_text.go

@@ -32,6 +32,7 @@ type TextProcessor interface {
 	Encode(s string, addSpecial bool) ([]int32, error)
 	Encode(s string, addSpecial bool) ([]int32, error)
 	Decode([]int32) (string, error)
 	Decode([]int32) (string, error)
 	Is(int32, Special) bool
 	Is(int32, Special) bool
+	Vocab() *Vocabulary
 }
 }
 
 
 type Vocabulary struct {
 type Vocabulary struct {

+ 4 - 0
model/process_text_spm.go

@@ -53,6 +53,10 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
 	return spm.vocab.Is(id, special)
 	return spm.vocab.Is(id, special)
 }
 }
 
 
+func (spm SentencePieceModel) Vocab() *Vocabulary {
+	return spm.vocab
+}
+
 func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
 func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
 	return func(yield func(string) bool) {
 	return func(yield func(string) bool) {
 		for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
 		for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {

+ 32 - 0
runner/ollamarunner/runner.go

@@ -468,6 +468,20 @@ func (s *Server) processBatch() error {
 			return fmt.Errorf("failed to sample token: %w", err)
 			return fmt.Errorf("failed to sample token: %w", err)
 		}
 		}
 
 
+		if seq.sampler.JSONSampler != nil {
+			_, err = seq.sampler.JSONSampler.UpdateState([]int32{token})
+			if err != nil {
+				return fmt.Errorf("failed to update state: %w", err)
+			}
+		}
+
+		if seq.sampler.PythonSampler != nil {
+			err = seq.sampler.PythonSampler.UpdateState(token)
+			if err != nil {
+				return fmt.Errorf("failed to update state: %w", err)
+			}
+		}
+
 		// if it's an end of sequence token, break
 		// if it's an end of sequence token, break
 		if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
 		if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
 			// TODO (jmorganca): we should send this back
 			// TODO (jmorganca): we should send this back
@@ -562,6 +576,22 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		}
 		}
 	}
 	}
 
 
+	// jsonSampler, err := sample.NewJSONSampler(s.model.(model.TextProcessor), nil)
+	// if err != nil {
+	// 	http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
+	// 	return
+	// }
+	// jsonSampler = nil
+	// pythonSampler := sample.NewPythonSampler(s.model.(model.TextProcessor), nil)
+	// pythonSampler := &sample.PythonSampler{}
+	// functions := []sample.PythonFunction{
+	// 	{
+	// 		Name:  "add_two_strings",
+	// 		Args:  []string{"s1", "s2"},
+	// 		Types: []string{"string", "string"},
+	// 	},
+	// }
+	// pythonSampler.Init(functions, s.model.(model.TextProcessor))
 	sampler := sample.NewSampler(
 	sampler := sample.NewSampler(
 		req.Options.Temperature,
 		req.Options.Temperature,
 		req.Options.TopK,
 		req.Options.TopK,
@@ -569,6 +599,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		req.Options.MinP,
 		req.Options.MinP,
 		req.Options.Seed,
 		req.Options.Seed,
 		grammar,
 		grammar,
+		nil,
+		nil,
 	)
 	)
 
 
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{

+ 53 - 0
sample/gtf.go

@@ -0,0 +1,53 @@
+package sample
+
+var DefaultGrammar = map[string]string{
+	"unicode": `\x{hex}{2} | \u{hex}{4} | \U{hex}{8}`,
+	"null":    `"null"`,
+	"object":  `"{" (kv ("," kv)*)? "}"`,
+	"array":   `"[" (value ("," value)*)? "]"`,
+	"kv":      `string ":" value`,
+	"integer": `"0" | [1-9] [0-9]*`,
+	"number":  `"-"? integer frac? exp?`,
+	"frac":    `"." [0-9]+`,
+	"exp":     `("e" | "E") ("+" | "-") [0-9]+`,
+	"string":  `"\"" char* "\""`,
+	"escape":  `["/" | "b" | "f" | "n" | "r" | "t" | unicode]`,
+	"char":    `[^"\\] | escape`,
+	"space":   `(" " | "\t" | "\n" | "\r")*`,
+	"hex":     `[0-9] | [a-f] | [A-F]`,
+	"boolean": `"true" | "false"`,
+	"value":   `object | array | string | number | boolean | "null"`,
+}
+
+const jsonString = `object | array`
+
+type StateMachine struct {
+	states map[rune]State
+}
+
+type State struct {
+	NextStates []string
+	// bitmask?
+	Mask       []bool
+	IsTerminal bool
+}
+
+func NewStateMachine(grammar map[string]string, startRule string) *StateMachine {
+	states := make(map[rune]State)
+
+	var cumu string
+	flag := false
+	for _, r := range startRule {
+		if r == '"' {
+			flag = !flag
+		}
+		if flag {
+			cumu += string(r)
+		}
+	}
+
+	sm := &StateMachine{
+		states: states,
+	}
+	return sm
+}

+ 138 - 0
sample/gtf_test.go

@@ -0,0 +1,138 @@
+package sample
+
+import (
+	"testing"
+)
+
+func TestGrammarParsing(t *testing.T) {
+	tests := []struct {
+		name      string
+		grammar   map[string]string
+		startRule string
+		input     string
+		want      bool
+	}{
+		{
+			name: "simple object",
+			grammar: map[string]string{
+				"object": `"{" "}"`,
+			},
+			startRule: "object",
+			input:     "{}",
+			want:      true,
+		},
+		{
+			name: "simple array",
+			grammar: map[string]string{
+				"array": `"[" "]"`,
+			},
+			startRule: "array",
+			input:     "[]",
+			want:      true,
+		},
+		{
+			name: "character class",
+			grammar: map[string]string{
+				"digit": `[0-9]`,
+			},
+			startRule: "digit",
+			input:     "5",
+			want:      true,
+		},
+		{
+			name: "alternation",
+			grammar: map[string]string{
+				"bool": `"true" | "false"`,
+			},
+			startRule: "bool",
+			input:     "true",
+			want:      true,
+		},
+		{
+			name: "repetition",
+			grammar: map[string]string{
+				"digits": `[0-9]+`,
+			},
+			startRule: "digits",
+			input:     "123",
+			want:      true,
+		},
+		{
+			name: "nested rules",
+			grammar: map[string]string{
+				"value":  `object | array`,
+				"object": `"{" "}"`,
+				"array":  `"[" "]"`,
+			},
+			startRule: "value",
+			input:     "{}",
+			want:      true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			parser := NewParser(tt.grammar)
+			machine, err := parser.Parse(tt.startRule)
+			if err != nil {
+				t.Fatalf("Parse() error = %v", err)
+			}
+
+			matcher := NewMatcher(machine)
+			got, err := matcher.Match(tt.input)
+			if err != nil {
+				t.Fatalf("Match() error = %v", err)
+			}
+			if got != tt.want {
+				t.Errorf("Match() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestJSONGrammar(t *testing.T) {
+	tests := []struct {
+		name  string
+		input string
+		want  bool
+	}{
+		{"empty object", "{}", true},
+		{"empty array", "[]", true},
+		{"simple string", `"hello"`, true},
+		{"simple number", "123", true},
+		{"simple boolean", "true", true},
+		{"simple null", "null", true},
+		{"object with string", `{"key": "value"}`, true},
+		{"array with numbers", "[1, 2, 3]", true},
+		{"nested object", `{"obj": {"key": "value"}}`, true},
+		{"nested array", `[1, [2, 3], 4]`, true},
+		{"invalid object", "{", false},
+		{"invalid array", "[1, 2", false},
+		{"invalid string", `"hello`, false},
+	}
+
+	parser := NewParser(DefaultGrammar)
+	machine, err := parser.Parse("value")
+	if err != nil {
+		t.Fatalf("Parse() error = %v", err)
+	}
+
+	matcher := NewMatcher(machine)
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := matcher.Match(tt.input)
+			if tt.want {
+				if err != nil {
+					t.Errorf("Match() error = %v", err)
+				}
+				if !got {
+					t.Errorf("Match() = false, want true")
+				}
+			} else {
+				if err == nil && got {
+					t.Errorf("Match() = true, want false")
+				}
+			}
+		})
+	}
+}

+ 160 - 0
sample/json_types.go

@@ -0,0 +1,160 @@
+package sample
+
+import (
+	"fmt"
+)
+
+type JSONState int
+
+const (
+	StateStart JSONState = iota
+	StateInObject
+	StateInObjectKey
+	StateInStructuredKey
+	StateInStructuredValue
+	StateNewline
+	StateTab
+	StateSpace
+	StateInString
+	StateInInt
+	StateInFloat
+	StateInBool
+	StateInNull
+	StateInColon
+	StateInComma
+	StateInTab
+	StateInSpaceToValue
+	StateInSpaceEndValue
+	StateInNewlineEndValue
+	StateInObjSpace
+	StateInList
+	StateInListComma
+	StateInValue
+	StateInValueEnd
+	StateInListEnd
+	StateInListObjectEnd
+	StateInNewline
+	StateInNumber
+	StateInNumberEnd
+	StateInStringEnd
+	StateInObjectKeyEnd
+	StateTerminate
+	StateInObjectEnd
+	StateTransitioningToTerminate
+	StateInListStartJSON
+)
+
+var JSONStates = []JSONState{
+	StateStart,
+	StateInObject,
+	StateInObjectKey,
+	StateInStructuredKey,
+	StateInStructuredValue,
+	StateNewline,
+	StateTab,
+	StateSpace,
+	StateInString,
+	StateInInt,
+	StateInFloat,
+	StateInBool,
+	StateInNull,
+	StateInColon,
+	StateInComma,
+	StateInTab,
+	StateInSpaceToValue,
+	StateInSpaceEndValue,
+	StateInNewlineEndValue,
+	StateInObjSpace,
+	StateInListStartJSON,
+	StateInList,
+	StateInListComma,
+	StateInValue,
+	StateInValueEnd,
+	StateInListEnd,
+	StateInListObjectEnd,
+	StateInNewline,
+	StateInNumber,
+	StateInNumberEnd,
+	StateInStringEnd,
+	StateInObjectKeyEnd,
+	StateTerminate,
+	StateInObjectEnd,
+	StateTransitioningToTerminate,
+}
+
+func (s JSONState) String() string {
+	switch s {
+	case StateStart:
+		return "StateStart"
+	case StateInObject:
+		return "StateInObject"
+	case StateInObjectKey:
+		return "StateInObjectKey"
+	case StateInStructuredKey:
+		return "StateInStructuredKey"
+	case StateInStructuredValue:
+		return "StateInStructuredValue"
+	case StateNewline:
+		return "StateNewline"
+	case StateTab:
+		return "StateTab"
+	case StateSpace:
+		return "StateSpace"
+	case StateInString:
+		return "StateInString"
+	case StateInInt:
+		return "StateInInt"
+	case StateInFloat:
+		return "StateInFloat"
+	case StateInBool:
+		return "StateInBool"
+	case StateInNull:
+		return "StateInNull"
+	case StateInColon:
+		return "StateInColon"
+	case StateInComma:
+		return "StateInComma"
+	case StateInTab:
+		return "StateInTab"
+	case StateInSpaceToValue:
+		return "StateInSpaceToValue"
+	case StateInSpaceEndValue:
+		return "StateInSpaceEndValue"
+	case StateInNewlineEndValue:
+		return "StateInNewlineEndValue"
+	case StateInObjSpace:
+		return "StateInObjSpace"
+	case StateInList:
+		return "StateInList"
+	case StateInListComma:
+		return "StateInListComma"
+	case StateInValue:
+		return "StateInValue"
+	case StateInValueEnd:
+		return "StateInValueEnd"
+	case StateInListEnd:
+		return "StateInListEnd"
+	case StateInListObjectEnd:
+		return "StateInListObjectEnd"
+	case StateInNewline:
+		return "StateInNewline"
+	case StateInNumber:
+		return "StateInNumber"
+	case StateInNumberEnd:
+		return "StateInNumberEnd"
+	case StateInStringEnd:
+		return "StateInStringEnd"
+	case StateInObjectKeyEnd:
+		return "StateInObjectKeyEnd"
+	case StateTerminate:
+		return "StateTerminate"
+	case StateInObjectEnd:
+		return "StateInObjectEnd"
+	case StateTransitioningToTerminate:
+		return "StateTransitioningToTerminate"
+	case StateInListStartJSON:
+		return "StateInListStartJSON"
+	default:
+		return fmt.Sprintf("Unknown state: %d", s)
+	}
+}

+ 327 - 0
sample/pushdown_automata.go

@@ -0,0 +1,327 @@
+package sample
+
+import (
+	"fmt"
+	"slices"
+
+	"github.com/ollama/ollama/model"
+)
+
+/*
+Key JSON rules to consider:
+
+1. Whitespace handling:
+   - Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
+   - Current code only handles some whitespace cases
+
+2. Number validation:
+   - Need proper validation for special number cases like -0
+   - Should handle .5 style decimals
+   - Need limits on scientific notation (e, E)
+
+3. String escaping:
+   - Currently marks \ as invalid but should allow escaped sequences:
+     - \"
+     - \n
+     - \u1234 unicode escapes
+
+4. Empty object/array transitions:
+   - Direct {} and [] cases could be more explicit
+   - Need clear transitions for these edge cases
+
+5. Nested depth limits:
+   - No protection against excessive nesting
+   - Could cause stack overflow with deeply nested structures
+*/
+
+// TODO: / should be valid but an escape character
+var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
+
+var (
+	intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
+	validIntRunes   = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
+)
+
+var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
+
+var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
+
+var validNullRunes = []rune{'n', 'u', 'l', 'l'}
+
+type PDA struct {
+	State             JSONState
+	TransitionEdges   map[rune]*PDA
+	MaskTokenIDToNode map[int32]*PDA
+}
+
+func NewPDANode(state JSONState) *PDA {
+	return &PDA{
+		State:             state,
+		TransitionEdges:   make(map[rune]*PDA),
+		MaskTokenIDToNode: make(map[int32]*PDA),
+	}
+}
+
+type PDAGraphBuilder struct {
+	proc             model.TextProcessor
+	decodedToks      []string
+	stateToNodeMap   map[JSONState]*PDA
+	tokenToStatesMap map[int32][]JSONState
+}
+
+func (b *PDAGraphBuilder) BuildGraph() error {
+	stateToNodeMap := make(map[JSONState]*PDA)
+	for _, state := range JSONStates {
+		stateToNodeMap[state] = NewPDANode(state)
+	}
+
+	stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
+	stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
+
+	// TODO: update naming here - and revisit values
+	stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
+	stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
+
+	stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
+	stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
+	stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
+	stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+
+	// new line
+	stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
+	stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
+	stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
+	// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
+
+	// new line end value
+	// stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
+
+	stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
+	stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
+	// TODO: see if this is needed for formatting
+	stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
+
+	stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
+	stateToNodeMap[StateInTab].TransitionEdges['\t'] = stateToNodeMap[StateInNewline]
+
+	stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey]
+	stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd]
+
+	stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon]
+
+	stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
+	stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+
+	// where values should be
+	// this could be combined but the probl might change, we're alr doing a skip ahead
+	stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
+	stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
+
+	stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
+	stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
+	addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
+
+	// Leads to a value
+	stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
+	stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
+	addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
+	stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
+
+	// Values
+	// string node
+	stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString]
+	stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
+
+	// String end node
+	addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
+	// stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
+
+	// TODO: add counters for allowable number of decimals, e, E, etc
+	// number node
+	for _, r := range validNumberRunes {
+		stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
+	}
+	addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
+	// stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
+
+	// list node
+	stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
+	stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
+	stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
+	stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
+	// early end
+	stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
+
+	// list end node
+	stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	// stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
+	stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
+
+	// empty list
+	stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
+	addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
+
+	// null node
+	for _, r := range validNullRunes {
+		stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
+	}
+	addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
+	stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
+	stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
+
+	// list comma
+	// should point to values
+	stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
+	stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
+	stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
+	stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList]
+	stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
+
+	addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
+
+	// list object end
+	stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
+	stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
+	// TODO: not sure if this is needed
+	stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
+
+	// bool node
+	for _, r := range validBoolRunes {
+		stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
+	}
+	stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
+	addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
+	// stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
+
+	// comma node
+	stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
+	stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
+	stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
+	stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
+	// todo: review this space transition
+	// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
+
+	// space end value
+	// stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
+	stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
+	stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
+
+	b.stateToNodeMap = stateToNodeMap
+	if err := b.preComputeValidStates(); err != nil {
+		return err
+	}
+	return nil
+}
+
+func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
+	node.TransitionEdges[','] = stateToNodeMap[StateInComma]
+	node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
+	node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
+}
+
+func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
+	node.TransitionEdges['"'] = stateToNodeMap[StateInString]
+	for _, r := range validNumberRunes {
+		node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
+	}
+	// TODO(parthsareen): force the output and shift similar to structured outputs
+	node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
+	node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
+	node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
+}
+
+func (b *PDAGraphBuilder) preComputeValidStates() error {
+	for _, node := range b.stateToNodeMap {
+		// if node.State == StateInObjectKey {
+		// 	if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
+		// 		b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
+		// 		fmt.Println("copying string mask to object key mask")
+		// 	}
+		// }
+		if err := b.CreateMask(node); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
+	// TODO: make can be somewhere else too
+	b.tokenToStatesMap = make(map[int32][]JSONState)
+	for i, t := range b.decodedToks {
+		for _, r := range t {
+			if r == '"' {
+				b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString)
+			}
+		}
+	}
+	return nil
+}
+
+// TODO: the mask for obj key and string should be the same?
+func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
+	if node == nil {
+		return fmt.Errorf("node cannot be nil")
+	}
+	for i := range b.decodedToks {
+		token := b.decodedToks[i]
+		// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
+		if b.proc.Is(int32(i), model.SpecialEOS) || b.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
+			continue
+		}
+		curNode := node
+		valid := true
+		consumedSpecialRunes := make(map[rune]bool)
+		for _, r := range token {
+			curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
+			if curNode == nil || !valid {
+				break
+			}
+		}
+		if valid {
+			node.MaskTokenIDToNode[int32(i)] = curNode
+		}
+	}
+	return nil
+}
+
+func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
+	if consumedSpecialRunes[r] {
+		return nil, false
+	}
+
+	specialRune := slices.Contains(stringInvalidRunes, r)
+	if specialRune {
+		if curNode.State == StateInString || curNode.State == StateInObjectKey {
+			return nil, false
+		}
+	}
+
+	// Check for specific rune transition
+	if nextNode, ok := curNode.TransitionEdges[r]; ok {
+		// fmt.Println("next node", nextNode)
+		if specialRune {
+			if curNode.State == nextNode.State {
+				return nil, false
+			}
+			consumedSpecialRunes[r] = true
+		}
+		return nextNode, true
+	}
+
+	// Check for sentinel value - if present, any rune is valid
+	if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
+		return nextNode, true
+	}
+
+	return nil, false
+}

+ 264 - 0
sample/pushdown_runner.go

@@ -0,0 +1,264 @@
+package sample
+
+import (
+	"fmt"
+	"math"
+	"runtime"
+	"time"
+
+	"github.com/ollama/ollama/model"
+)
+
+// TODO: safety in case of invalid json
+// TODO: partial JSON matching?
+// TODO: interfaces to cleanup with return values
+// TODO this interface shouldn't be the sampler - should just use Sampler
+// TODO: add penalties for string \n stuff
+// TODO: minimize number of fwd passes if there is only one match
+// TODO: greedy sample initially and then backtrack if no match
+
+type PushdownSampler struct {
+	PDAGraphBuilder
+	curNode      *PDA
+	braceStack   []rune
+	stateCounter uint32
+}
+
+// graph should be built once and reused per tokenizer
+func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
+	start := time.Now()
+
+	fmt.Println("--------------------------------")
+	fmt.Println("PDA sampler")
+	fmt.Println("--------------------------------")
+	var m runtime.MemStats
+	runtime.ReadMemStats(&m)
+	before := m.Alloc
+	fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
+
+	vocab := proc.Vocab()
+	decodedToks := make([]string, len(vocab.Values))
+	for i := range vocab.Values {
+		token, err := proc.Decode([]int32{int32(i)})
+		if err != nil {
+			return nil, err
+		}
+		decodedToks[i] = token
+	}
+
+	gb := &PDAGraphBuilder{
+		proc:        proc,
+		decodedToks: decodedToks,
+	}
+
+	if err := gb.BuildGraph(); err != nil {
+		return nil, err
+	}
+
+	runtime.ReadMemStats(&m)
+	after := m.Alloc
+	fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
+	fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
+	fmt.Printf("Graph build time = %v\n", time.Since(start))
+
+	// TODO: this can be simplified
+	return &PushdownSampler{
+		curNode:         gb.stateToNodeMap[StateStart],
+		PDAGraphBuilder: *gb,
+		braceStack:      []rune{},
+		stateCounter:    0,
+	}, nil
+}
+
+// TODO: need to add resampling logic if the first sample was not good
+// greedy sample + backtrack?
+func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) {
+	switch s.curNode.State {
+	case StateInString:
+		return s.maskLogits(logits, s.curNode)
+
+	case StateInListEnd:
+		// force finish if no braces left
+		if len(s.braceStack) == 0 {
+			s.curNode = NewPDANode(StateTerminate)
+			return forceFinish(s, logits)
+		}
+
+		logits, err := s.maskLogits(logits, s.curNode)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+
+	case StateTerminate:
+		return forceFinish(s, logits)
+
+	case StateInObjectEnd:
+		// force finish if no braces left
+		if len(s.braceStack) == 0 {
+			s.curNode = NewPDANode(StateTerminate)
+			return forceFinish(s, logits)
+		}
+
+		peek := s.braceStack[len(s.braceStack)-1]
+		if peek == rune('[') {
+			s.curNode = s.stateToNodeMap[StateInListObjectEnd]
+		}
+
+		logits, err := s.maskLogits(logits, s.curNode)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+
+	case StateInComma:
+		peek := s.braceStack[len(s.braceStack)-1]
+		if peek == rune('[') {
+			s.curNode = s.stateToNodeMap[StateInListComma]
+		}
+
+		logits, err := s.maskLogits(logits, s.curNode)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+
+	default:
+		fmt.Println("masking logits current state", s.curNode.State)
+		logits, err := s.maskLogits(logits, s.curNode)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+	}
+}
+
+func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) {
+	for i := range logits {
+		if s.proc.Is(int32(i), model.SpecialEOS) {
+			logits[i] = 1.0
+		} else {
+			logits[i] = float32(math.Inf(-1))
+		}
+	}
+	return logits, nil
+}
+
+func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
+	fmt.Println("current state - updating", s.curNode.State)
+	mappedString, err := s.proc.Decode(tokenSlice)
+	if err != nil {
+		return nil, err
+	}
+	fmt.Printf(">>> mappedString: %q\n", mappedString)
+
+	// Special handling for EOS token in terminate state
+	if s.curNode.State == StateTerminate {
+		for _, tokenID := range tokenSlice {
+			if s.proc.Is(tokenID, model.SpecialEOS) {
+				return tokenSlice, nil
+			}
+		}
+	}
+
+	// flag := -1
+	// endBraceRunes := []rune{'}', ']'}
+	for _, r := range mappedString {
+		// TODO: if this is enabled again, make sure to appropriately handle the state transitions
+		// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
+		// 	fmt.Printf("stack is empty, extra closing brace %c\n", r)
+		// 	// flag = i
+		// 	break
+
+		// }
+		if r == rune('{') {
+			s.braceStack = append(s.braceStack, r)
+		}
+		if r == rune('[') {
+			s.braceStack = append(s.braceStack, r)
+		}
+		if r == rune('}') {
+			if len(s.braceStack) == 0 {
+				return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
+			}
+			top := s.braceStack[len(s.braceStack)-1]
+			if top != rune('{') {
+				return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
+			}
+			s.braceStack = s.braceStack[:len(s.braceStack)-1]
+		}
+
+		if r == rune(']') {
+			if len(s.braceStack) == 0 {
+				return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
+			}
+			top := s.braceStack[len(s.braceStack)-1]
+			if top != rune('[') {
+				return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
+			}
+			s.braceStack = s.braceStack[:len(s.braceStack)-1]
+		}
+	}
+
+	// if flag != -1 {
+	// 	tokenSlice = tokenSlice[:flag]
+	// }
+	// fmt.Println("flag!", flag)
+	for _, tokenID := range tokenSlice {
+		// transition to the next node
+		nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
+		if !ok {
+			return nil, fmt.Errorf("invalid token: %q", mappedString)
+		}
+		fmt.Println("transitioning to", nextNode.State)
+
+		// TODO: add a penalty for staying in the same state too long
+		if nextNode.State == s.curNode.State {
+			s.stateCounter++
+		} else {
+			s.stateCounter = 0
+		}
+		s.curNode = nextNode
+		fmt.Println("updated curNode state", s.curNode.State)
+	}
+	return tokenSlice, nil
+}
+
+// greedy sample + backtrack?
+func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) {
+	// Create a new slice with same length as logits, initialized to -Inf
+	maskedLogits := make([]float32, len(logits))
+	for i := range maskedLogits {
+		maskedLogits[i] = float32(math.Inf(-1))
+	}
+
+	// Only update values for valid token IDs from the mask map
+	for tokenID := range node.MaskTokenIDToNode {
+		if int(tokenID) < len(logits) {
+			maskedLogits[tokenID] = logits[tokenID]
+		}
+	}
+
+	return maskedLogits, nil
+}
+
+func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) {
+	maxLogit := float32(math.Inf(-1))
+	maxIndex := -1
+
+	// Find the maximum logit value among valid tokens
+	for tokenID := range node.MaskTokenIDToNode {
+		if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
+			maxLogit = logits[tokenID]
+			maxIndex = int(tokenID)
+		}
+	}
+
+	if maxIndex == -1 {
+		return nil, fmt.Errorf("no valid tokens found in mask")
+	}
+
+	logits[0] = float32(maxIndex)
+	return logits, nil
+	// return maxIndex, nil
+}

+ 30 - 13
sample/samplers.go

@@ -17,12 +17,14 @@ type token struct {
 }
 }
 
 
 type Sampler struct {
 type Sampler struct {
-	rng         *rand.Rand
-	topK        int
-	topP        float32
-	minP        float32
-	temperature float32
-	grammar     *Grammar
+	rng           *rand.Rand
+	topK          int
+	topP          float32
+	minP          float32
+	temperature   float32
+	grammar       *Grammar
+	JSONSampler   *JSONSampler
+	PythonSampler *PythonSampler
 }
 }
 
 
 func (s *Sampler) Sample(logits []float32) (int32, error) {
 func (s *Sampler) Sample(logits []float32) (int32, error) {
@@ -30,6 +32,19 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
 		return -1, errors.New("sample: no logits provided to sample")
 		return -1, errors.New("sample: no logits provided to sample")
 	}
 	}
 
 
+	var err error
+	if s.JSONSampler != nil {
+		logits, err = s.JSONSampler.Apply(logits)
+		if err != nil {
+			return -1, err
+		}
+	}
+	if s.PythonSampler != nil {
+		logits, err = s.PythonSampler.ApplyMask(logits)
+		if err != nil {
+			return -1, err
+		}
+	}
 	tokens := make([]token, len(logits))
 	tokens := make([]token, len(logits))
 	for i := range logits {
 	for i := range logits {
 		tokens[i].id = int32(i)
 		tokens[i].id = int32(i)
@@ -127,7 +142,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 }
 }
 
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
-func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
+func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler {
 	var rng *rand.Rand
 	var rng *rand.Rand
 	if seed != -1 {
 	if seed != -1 {
 		// PCG requires two parameters: sequence and stream
 		// PCG requires two parameters: sequence and stream
@@ -155,12 +170,14 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
 	}
 	}
 
 
 	return Sampler{
 	return Sampler{
-		rng:         rng,
-		topK:        topK,
-		topP:        topP,
-		minP:        minP,
-		temperature: temperature,
-		grammar:     grammar,
+		rng:           rng,
+		topK:          topK,
+		topP:          topP,
+		minP:          minP,
+		temperature:   temperature,
+		grammar:       grammar,
+		JSONSampler:   jsonSampler,
+		PythonSampler: pythonSampler,
 	}
 	}
 }
 }
 
 

+ 299 - 0
sample/structured_outputs.go

@@ -0,0 +1,299 @@
+package sample
+
+import (
+	"fmt"
+	"log/slog"
+	"runtime"
+	"time"
+
+	"github.com/ollama/ollama/grammar/jsonschema"
+	"github.com/ollama/ollama/model"
+)
+
+type JSONSampler struct {
+	schema        *jsonschema.Schema
+	propIdx       int
+	propToNodeMap map[string]*PDA
+	pdaSampler    *PushdownSampler
+	decodedToks   []string
+}
+
+func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
+	slog.Info("NewJSONSampler", "schema", schema)
+	if proc == nil {
+		return nil, fmt.Errorf("TextProcessor cannot be nil")
+	}
+
+	pdaSampler, err := NewPushdownSampler(proc)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
+	}
+
+	if schema == nil {
+		return &JSONSampler{
+			schema:        nil,
+			propIdx:       -1,
+			propToNodeMap: nil,
+			pdaSampler:    pdaSampler,
+		}, nil
+	}
+
+	// fmt.Println("schema not nil")
+	so := &JSONSampler{
+		schema:        schema,
+		propIdx:       -1,
+		propToNodeMap: make(map[string]*PDA),
+		pdaSampler:    pdaSampler,
+	}
+
+	so.schemaToGraph()
+
+	// Benchmark token decoding
+	start := time.Now()
+	var m runtime.MemStats
+	runtime.ReadMemStats(&m)
+	before := m.Alloc
+
+	vocab := proc.Vocab()
+	decodedToks := make([]string, len(vocab.Values))
+	for i := range vocab.Values {
+		token, err := proc.Decode([]int32{int32(i)})
+		if err != nil {
+			return nil, err
+		}
+		decodedToks[i] = token
+	}
+	so.decodedToks = decodedToks
+
+	runtime.ReadMemStats(&m)
+	after := m.Alloc
+	fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
+	fmt.Printf("Token decode time = %v\n", time.Since(start))
+
+	fmt.Println("--------------------------------")
+	fmt.Println("SOSampler")
+	fmt.Println("--------------------------------")
+	// Benchmark this section
+	start = time.Now()
+	runtime.ReadMemStats(&m)
+	before = m.Alloc
+
+	// TODO: still messed up
+	// TODO: recursion use case
+	// key masks
+	for _, prop := range so.schema.Properties {
+		node := so.propToNodeMap[prop.Name]
+		// propName -> node
+		curState := node.State
+		fromNode := node
+		so.pdaSampler.CreateMask(fromNode)
+		for curState == StateInStructuredKey {
+			// there is only one edge
+			for r, toNode := range fromNode.TransitionEdges {
+				fmt.Println("rune", r, "edge", toNode.State)
+				so.pdaSampler.CreateMask(toNode)
+				fmt.Printf("created mask for %c\n", r)
+				curState = toNode.State
+				fmt.Println("next state", curState)
+				// TODO: theres an extra gen for " right now
+				fromNode = toNode
+			}
+		}
+
+		if curState != StateInColon {
+			return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
+		}
+
+		// so.pdaSampler.CreateMask(fromNode)
+
+		fromNode = fromNode.TransitionEdges[' ']
+
+		so.pdaSampler.CreateMask(fromNode)
+		curState = fromNode.State
+		for _, toNode := range fromNode.TransitionEdges {
+			fmt.Println("toNode", toNode.State)
+		}
+	}
+
+	// runtime.ReadMemStats(&m)
+	// after = m.Alloc
+	// fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
+	// fmt.Printf("Mask creation time = %v\n", time.Since(start))
+	// fmt.Println("--------------------------------")
+
+	return so, nil
+}
+
+func (s *JSONSampler) schemaToGraph() {
+	schemaType := s.schema.EffectiveType()
+	switch schemaType {
+	case "object":
+		// TODO: see if we need to connect these to the JSON graph
+
+		// each prop is a key
+		for _, prop := range s.schema.Properties {
+			// name of key
+			name := prop.Name
+			keyNode := &PDA{
+				State:             StateInStructuredKey, // this is unchanging, will impact sampling
+				TransitionEdges:   make(map[rune]*PDA),
+				MaskTokenIDToNode: make(map[int32]*PDA),
+			}
+
+			prevNode := keyNode
+			for _, r := range name {
+				runeNode := &PDA{
+					State:             StateInStructuredKey, // this is unchanging, will impact sampling
+					TransitionEdges:   make(map[rune]*PDA),
+					MaskTokenIDToNode: make(map[int32]*PDA),
+				}
+				// fmt.Println("runeNode created", runeNode.State)
+				// fmt.Printf("runeNode created %c\n", r)
+
+				// since alloc on heap connections wil still map
+				prevNode.TransitionEdges[r] = runeNode
+				prevNode = runeNode
+			}
+
+			// point to end of object key node after all chars are done
+			// prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
+
+			// link to value node
+			// Create a node for the end of the key (after the closing quote)
+			stringEndNode := &PDA{
+				State:             StateInStructuredKey,
+				TransitionEdges:   make(map[rune]*PDA),
+				MaskTokenIDToNode: make(map[int32]*PDA),
+			}
+			prevNode.TransitionEdges['"'] = stringEndNode
+			prevNode = stringEndNode
+
+			// Add transition for colon after key
+			colonNode := &PDA{
+				State:             StateInColon,
+				TransitionEdges:   make(map[rune]*PDA),
+				MaskTokenIDToNode: make(map[int32]*PDA),
+			}
+			prevNode.TransitionEdges[':'] = colonNode
+			prevNode = colonNode
+
+			// Add transition for space after colon
+			spaceNode := &PDA{
+				State:             StateInSpaceToValue,
+				TransitionEdges:   make(map[rune]*PDA),
+				MaskTokenIDToNode: make(map[int32]*PDA),
+			}
+			prevNode.TransitionEdges[' '] = spaceNode
+			prevNode = spaceNode
+
+			value := prop.Type
+			switch value {
+			case "object":
+				fmt.Println("object under key: ", name)
+			case "array":
+				fmt.Println("array under key: ", name)
+			case "string":
+				fmt.Println("string under key: ", name)
+				prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
+			case "number":
+				fmt.Println("number under key: ", name)
+				for _, r := range validNumberRunes {
+					prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
+				}
+			case "boolean":
+				fmt.Println("boolean under key: ", name)
+				prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
+				prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
+				prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
+			}
+
+			// points to start of the key
+			s.propToNodeMap[name] = keyNode
+			fmt.Println("name", name, "keyNode", keyNode.State)
+		}
+	}
+	// TODO: do values + recursion
+}
+
+func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
+	if s.schema == nil {
+		return s.pdaSampler.Apply(logits)
+	}
+
+	switch s.pdaSampler.curNode.State {
+	// TODO: doesnt account for multi rune case
+	case StateInObjectKey:
+		if s.propIdx > len(s.schema.Properties)-1 {
+			return nil, fmt.Errorf("propIdx out of bounds")
+		}
+		// fmt.Println("in object key - structured outputs")
+		// TODO: this tracking should probably be coming from a stack to track nested objects
+		// simple case
+		s.propIdx++
+		fmt.Println("propIdx", s.propIdx)
+		prop := s.schema.Properties[s.propIdx]
+		fmt.Println("prop", prop.Name)
+		s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
+		fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
+		logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+
+	default:
+
+		// Will only happen for the last prop - can also be precomputed.
+		if s.propIdx == len(s.schema.Properties)-1 {
+			// todo: if i incremenet propidx then i know im in last value as well
+			switch s.pdaSampler.curNode.State {
+			case StateInObjectEnd:
+				fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
+				s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
+				s.pdaSampler.curNode = NewPDANode(StateTerminate)
+				s.propIdx++
+
+			// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
+			case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
+				fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
+				delete(s.pdaSampler.curNode.TransitionEdges, ',')
+				s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
+
+				s.pdaSampler.CreateMask(s.pdaSampler.curNode)
+				s.propIdx++
+			}
+		}
+		return s.pdaSampler.Apply(logits)
+	}
+}
+
+func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
+	tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
+	if err != nil {
+		return nil, err
+	}
+
+	if s.schema == nil {
+		// Don't need to update state for unconstrained JSON sampling
+		return tokenSlice, nil
+	}
+
+	switch s.pdaSampler.curNode.State {
+	case StateInObjectKey:
+		s.propIdx++
+		fmt.Println("propIdx", s.propIdx)
+		prop := s.schema.Properties[s.propIdx]
+		fmt.Println("prop", prop.Name)
+		s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
+		// TODO: this does not work - mike
+		// str, err := s.pdaSampler.proc.Decode(tokenSlice)
+		// if err != nil {
+		// 	return nil, err
+		// }
+		// fmt.Println("str", str)
+
+		return tokenSlice, nil
+	default:
+		return tokenSlice, nil
+	}
+}

+ 339 - 0
sample/structured_python.go

@@ -0,0 +1,339 @@
+package sample
+
+import (
+	"fmt"
+	"math"
+	"slices"
+
+	"github.com/ollama/ollama/model"
+)
+
+type PythonState int
+
+const (
+	PythonStateStart PythonState = iota
+	StateInFunction
+	StateInFunctionArgs
+	StateInFunctionArgsType
+	StateInFunctionEnd
+	PStateInString
+	PStateInStringEnd
+	PStateInNumber
+	PStateInList
+	PStateInListEnd
+	PStateInDict
+	PStateInDictEnd
+	PStateInTuple
+	PStateInTupleEnd
+	PStateTerminate
+)
+
+func (s PythonState) String() string {
+	switch s {
+	case PythonStateStart:
+		return "PythonStateStart"
+	case StateInFunction:
+		return "StateInFunction"
+	case StateInFunctionArgs:
+		return "StateInFunctionArgs"
+	case StateInFunctionArgsType:
+		return "StateInFunctionArgsType"
+	case StateInFunctionEnd:
+		return "StateInFunctionEnd"
+	case PStateInString:
+		return "PStateInString"
+	case PStateInStringEnd:
+		return "PStateInStringEnd"
+	case PStateInNumber:
+		return "PStateInNumber"
+	case PStateInList:
+		return "PStateInList"
+	case PStateInListEnd:
+		return "PStateInListEnd"
+	case PStateInDict:
+		return "PStateInDict"
+	case PStateInDictEnd:
+		return "PStateInDictEnd"
+	case PStateInTuple:
+		return "PStateInTuple"
+	case PStateInTupleEnd:
+		return "PStateInTupleEnd"
+	case PStateTerminate:
+		return "PStateTerminate"
+	default:
+		return fmt.Sprintf("PythonState(%d)", s)
+	}
+}
+
+var PythonStates = []PythonState{
+	PythonStateStart,
+	StateInFunction,
+	StateInFunctionArgs,
+	StateInFunctionArgsType,
+	StateInFunctionEnd,
+	PStateInString,
+	PStateInStringEnd,
+	PStateInNumber,
+	PStateInList,
+	PStateInListEnd,
+	PStateInDict,
+	PStateInDictEnd,
+	PStateInTuple,
+	PStateInTupleEnd,
+	PStateTerminate,
+}
+
+type Node struct {
+	State             PythonState
+	TransitionEdges   map[rune]*Node
+	MaskTokenIDToNode map[int32]*Node
+}
+
+func NewNode(state PythonState) *Node {
+	return &Node{
+		State:             state,
+		TransitionEdges:   make(map[rune]*Node),
+		MaskTokenIDToNode: make(map[int32]*Node),
+	}
+}
+
+type PythonFunction struct {
+	Name  string
+	Args  []string
+	Types []string
+}
+
+type PythonSampler struct {
+	stateToNodes map[PythonState]*Node
+	proc         model.TextProcessor
+	decodedToks  []string
+	curNode      *Node
+}
+
+func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
+	s.proc = proc
+	decodedToks := make([]string, len(proc.Vocab().Values))
+	for i := range proc.Vocab().Values {
+		token, err := proc.Decode([]int32{int32(i)})
+		if err != nil {
+			return err
+		}
+		decodedToks[i] = token
+	}
+	s.decodedToks = decodedToks
+	s.BuildGraph()
+	for _, function := range functions {
+		prevNode := s.stateToNodes[PythonStateStart]
+
+		for _, r := range function.Name {
+			nextNode := NewNode(StateInFunction)
+			prevNode.TransitionEdges[r] = nextNode
+			if err := s.CreateMask(nextNode); err != nil {
+				return err
+			}
+			fmt.Println("prevNode", prevNode.State)
+			fmt.Printf("transition edge: %q\n", r)
+			fmt.Println("nextNode", nextNode.State)
+			prevNode = nextNode
+		}
+		prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs]
+		s.CreateMask(prevNode)
+		prevNode = s.stateToNodes[StateInFunctionArgs]
+		for i, arg := range function.Args {
+			for _, r := range arg {
+				nextNode := NewNode(StateInFunctionArgs)
+				prevNode.TransitionEdges[r] = nextNode
+				s.CreateMask(prevNode)
+				prevNode = nextNode
+			}
+			prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
+			// prevNode = s.stateToNodes[StateInFunctionArgs]
+			prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType)
+			s.CreateMask(prevNode)
+			prevNode = prevNode.TransitionEdges['=']
+			switch function.Types[i] {
+			case "string":
+				prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString]
+				s.CreateMask(prevNode.TransitionEdges['"'])
+			case "number":
+				prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber]
+				s.CreateMask(prevNode.TransitionEdges['"'])
+			}
+		}
+
+	}
+	s.curNode = s.stateToNodes[PythonStateStart]
+	fmt.Println("curNode", s.curNode.State)
+	fmt.Println("transition edges", s.curNode.TransitionEdges)
+	if err := s.CreateMask(s.curNode); err != nil {
+		return err
+	}
+	fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode)
+	for tokenID, node := range s.curNode.MaskTokenIDToNode {
+		fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State)
+	}
+
+	return nil
+}
+
+func (s *PythonSampler) BuildGraph() error {
+	s.stateToNodes = make(map[PythonState]*Node)
+	for _, state := range PythonStates {
+		s.stateToNodes[state] = NewNode(state)
+	}
+
+	for _, state := range s.stateToNodes {
+		if err := s.CreateMask(state); err != nil {
+			return err
+		}
+	}
+
+	// String
+	s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString]
+	s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd]
+
+	// String end
+	s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
+	s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
+	// Number
+	for _, r := range validNumberRunes {
+		s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
+	}
+	s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
+	s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
+	s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs]
+
+	return nil
+}
+
+func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) {
+	if s.curNode.State == PStateTerminate {
+		logits, err := finish(s, logits)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+	}
+	logits, err := s.maskLogits(logits, s.curNode)
+	if err != nil {
+		return nil, err
+	}
+	return logits, nil
+}
+
+func (s *PythonSampler) UpdateState(token int32) error {
+	mappedString, err := s.proc.Decode([]int32{token})
+	if err != nil {
+		return err
+	}
+	fmt.Printf(">>> mappedString: %q\n", mappedString)
+
+	if s.curNode.State == PStateTerminate {
+		if s.proc.Is(token, model.SpecialEOS) {
+			return nil
+		}
+	}
+	nextNode, ok := s.curNode.MaskTokenIDToNode[token]
+	if !ok {
+		return fmt.Errorf("invalid token: %q", mappedString)
+	}
+	s.curNode = nextNode
+	fmt.Println("curNode", s.curNode.State)
+	for r, node := range s.curNode.TransitionEdges {
+		fmt.Printf("transition edge: %q -> %v\n", r, node.State)
+	}
+	if err := s.CreateMask(s.curNode); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (s *PythonSampler) CreateMask(node *Node) error {
+	if node == nil {
+		return fmt.Errorf("node cannot be nil")
+	}
+	for i := range s.decodedToks {
+		token := s.decodedToks[i]
+		// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
+		if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
+			continue
+		}
+		curNode := node
+		valid := true
+		consumedSpecialRunes := make(map[rune]bool)
+		for _, r := range token {
+			curNode, valid = isRValid(r, curNode, consumedSpecialRunes)
+			if curNode == nil || !valid {
+				break
+			}
+		}
+		if valid {
+			if curNode.State == StateInFunction {
+				// fmt.Println("cm curNode", curNode.State)
+				// fmt.Println("cm token", s.decodedToks[i])
+			}
+			node.MaskTokenIDToNode[int32(i)] = curNode
+		}
+	}
+	return nil
+}
+
+func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) {
+	if consumedSpecialRunes[r] {
+		return nil, false
+	}
+
+	specialRune := slices.Contains(stringInvalidRunes, r)
+	if specialRune {
+		if curNode.State == PStateInString || curNode.State == PStateInStringEnd {
+			return nil, false
+		}
+	}
+
+	// Check for specific rune transition
+	if nextNode, ok := curNode.TransitionEdges[r]; ok {
+		// fmt.Println("next node", nextNode)
+		if specialRune {
+			if curNode.State == nextNode.State {
+				return nil, false
+			}
+			consumedSpecialRunes[r] = true
+		}
+		return nextNode, true
+	}
+
+	// Check for sentinel value - if present, any rune is valid
+	if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
+		return nextNode, true
+	}
+
+	return nil, false
+}
+
+func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) {
+	// Create a new slice with same length as logits, initialized to -Inf
+	maskedLogits := make([]float32, len(logits))
+	for i := range maskedLogits {
+		maskedLogits[i] = float32(math.Inf(-1))
+	}
+
+	// Only update values for valid token IDs from the mask map
+	for tokenID := range node.MaskTokenIDToNode {
+		if int(tokenID) < len(logits) {
+			maskedLogits[tokenID] = logits[tokenID]
+		}
+	}
+
+	return maskedLogits, nil
+}
+
+func finish(s *PythonSampler, logits []float32) ([]float32, error) {
+	for i := range logits {
+		if s.proc.Is(int32(i), model.SpecialEOS) {
+			logits[i] = 1.0
+		} else {
+			logits[i] = float32(math.Inf(-1))
+		}
+	}
+	return logits, nil
+}