123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- package sample
- import (
- "fmt"
- "math"
- "github.com/ollama/ollama/model"
- )
// JSONState identifies a position in the JSON grammar that the sampler can be
// in while output is being generated.
type JSONState int

// The states of the JSON sampler. Values are assigned sequentially from zero
// in declaration order.
const (
	// StateStart is the initial state.
	StateStart JSONState = iota
	// StateInObject means the sampler is inside an object {}.
	StateInObject
	// StateInArray means the sampler is inside an array [].
	StateInArray
	// StateInString means the sampler is inside a string "".
	StateInString
	// StateAfterKey follows an object key; a ':' is expected next.
	StateAfterKey
	// StateAfterColon follows a ':'; a value is expected next.
	StateAfterColon
	// StateAfterValue follows a value; a ',' or a closing bracket is expected next.
	StateAfterValue
	// StateDone means JSON parsing is complete.
	StateDone
)
- type JSONSampler struct {
- state JSONState
- stack []string
- proc model.TextProcessor
- }
- func NewJSONSampler(proc model.TextProcessor) *JSONSampler {
- return &JSONSampler{
- state: StateStart,
- proc: proc,
- }
- }
- func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
- // Pre-decode valid tokens for current state
- validTokens := make(map[uint32]bool)
- // Always allow EOS token in any state
- // TODO: Check for other special tokens if needed
- for i := range logits {
- if s.proc.Is(uint32(i), model.SpecialEOS) {
- validTokens[uint32(i)] = true
- }
- }
- // Build set of valid tokens based on current state
- switch s.state {
- case StateStart:
- // Only allow opening brace
- for i := range logits {
- text, err := s.proc.Decode([]int32{int32(i)})
- if err == nil && text == "{" {
- validTokens[uint32(i)] = true
- }
- }
- case StateInObject, StateInArray:
- // Allow any token
- for i := range logits {
- validTokens[uint32(i)] = true
- }
- case StateInString:
- // Allow any token except closing brace
- for i := range logits {
- text, err := s.proc.Decode([]int32{int32(i)})
- if err == nil && text != "}" {
- validTokens[uint32(i)] = true
- }
- }
- case StateDone:
- // No tokens allowed
- }
- // Mark invalid tokens as NaN in one pass
- for i := range logits {
- if !validTokens[uint32(i)] {
- logits[i] = math.NaN()
- }
- }
- return logits, nil
- }
- func (s *JSONSampler) UpdateState(tokenID int) error {
- text, err := s.proc.Decode([]int32{int32(tokenID)})
- if err != nil {
- return fmt.Errorf("failed to decode token: %w", err)
- }
- switch s.state {
- case StateStart:
- if text != "{" {
- return fmt.Errorf("expected {, got %s", text)
- }
- s.state = StateInObject
- case StateInObject:
- if text == "}" {
- s.state = StateDone
- }
- case StateDone:
- return fmt.Errorf("unexpected token after closing bracket: %s", text)
- }
- return nil
- }
|