|
@@ -0,0 +1,104 @@
|
|
|
+package sample
|
|
|
+
|
|
|
+import (
|
|
|
+ "fmt"
|
|
|
+ "math"
|
|
|
+
|
|
|
+ "github.com/ollama/ollama/model"
|
|
|
+)
|
|
|
+
|
|
|
+type JSONState int
|
|
|
+
|
|
|
+const (
|
|
|
+ StateStart JSONState = iota // Initial state
|
|
|
+ StateInObject // Inside an object {}
|
|
|
+ StateInArray // Inside an array []
|
|
|
+ StateInString // Inside a string ""
|
|
|
+ StateAfterKey // After object key, expecting :
|
|
|
+ StateAfterColon // After :, expecting value
|
|
|
+ StateAfterValue // After value, expecting , or closing bracket
|
|
|
+ StateDone // JSON parsing complete
|
|
|
+)
|
|
|
+
|
|
|
+type JSONSampler struct {
|
|
|
+ state JSONState
|
|
|
+ stack []string
|
|
|
+ proc model.TextProcessor
|
|
|
+}
|
|
|
+
|
|
|
+func NewJSONSampler(proc model.TextProcessor) *JSONSampler {
|
|
|
+ return &JSONSampler{
|
|
|
+ state: StateStart,
|
|
|
+ proc: proc,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
|
|
|
+ // Pre-decode valid tokens for current state
|
|
|
+ validTokens := make(map[uint32]bool)
|
|
|
+
|
|
|
+ // Always allow EOS token in any state
|
|
|
+ // TODO: Check for other special tokens if needed
|
|
|
+ for i := range logits {
|
|
|
+ if s.proc.Is(uint32(i), model.SpecialEOS) {
|
|
|
+ validTokens[uint32(i)] = true
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Build set of valid tokens based on current state
|
|
|
+ switch s.state {
|
|
|
+ case StateStart:
|
|
|
+ // Only allow opening brace
|
|
|
+ for i := range logits {
|
|
|
+ text, err := s.proc.Decode([]int32{int32(i)})
|
|
|
+ if err == nil && text == "{" {
|
|
|
+ validTokens[uint32(i)] = true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ case StateInObject, StateInArray:
|
|
|
+ // Allow any token
|
|
|
+ for i := range logits {
|
|
|
+ validTokens[uint32(i)] = true
|
|
|
+ }
|
|
|
+ case StateInString:
|
|
|
+ // Allow any token except closing brace
|
|
|
+ for i := range logits {
|
|
|
+ text, err := s.proc.Decode([]int32{int32(i)})
|
|
|
+ if err == nil && text != "}" {
|
|
|
+ validTokens[uint32(i)] = true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ case StateDone:
|
|
|
+ // No tokens allowed
|
|
|
+ }
|
|
|
+
|
|
|
+ // Mark invalid tokens as NaN in one pass
|
|
|
+ for i := range logits {
|
|
|
+ if !validTokens[uint32(i)] {
|
|
|
+ logits[i] = math.NaN()
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return logits, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (s *JSONSampler) UpdateState(tokenID int) error {
|
|
|
+ text, err := s.proc.Decode([]int32{int32(tokenID)})
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("failed to decode token: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ switch s.state {
|
|
|
+ case StateStart:
|
|
|
+ if text != "{" {
|
|
|
+ return fmt.Errorf("expected {, got %s", text)
|
|
|
+ }
|
|
|
+ s.state = StateInObject
|
|
|
+ case StateInObject:
|
|
|
+ if text == "}" {
|
|
|
+ s.state = StateDone
|
|
|
+ }
|
|
|
+ case StateDone:
|
|
|
+ return fmt.Errorf("unexpected token after closing bracket: %s", text)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|