json_sampler.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "github.com/ollama/ollama/model"
  6. )
  // JSONState enumerates the positions a JSONSampler can occupy while it
// follows the JSON text being generated.
type JSONState int

// The states are numbered consecutively from zero in the order listed, so
// StateStart is also the zero value of JSONState.
const (
	// StateStart is the initial state, before anything has been emitted.
	StateStart JSONState = iota
	// StateInObject means generation is inside an object {}.
	StateInObject
	// StateInArray means generation is inside an array [].
	StateInArray
	// StateInString means generation is inside a string "".
	StateInString
	// StateAfterKey follows an object key; a ':' is expected next.
	StateAfterKey
	// StateAfterColon follows a ':'; a value is expected next.
	StateAfterColon
	// StateAfterValue follows a value; a ',' or a closing bracket is expected next.
	StateAfterValue
	// StateDone means the JSON document is complete.
	StateDone
)
  18. type JSONSampler struct {
  19. state JSONState
  20. stack []string
  21. proc model.TextProcessor
  22. }
  23. func NewJSONSampler(proc model.TextProcessor) *JSONSampler {
  24. return &JSONSampler{
  25. state: StateStart,
  26. proc: proc,
  27. }
  28. }
  29. func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
  30. // Pre-decode valid tokens for current state
  31. validTokens := make(map[uint32]bool)
  32. // Always allow EOS token in any state
  33. // TODO: Check for other special tokens if needed
  34. for i := range logits {
  35. if s.proc.Is(uint32(i), model.SpecialEOS) {
  36. validTokens[uint32(i)] = true
  37. }
  38. }
  39. // Build set of valid tokens based on current state
  40. switch s.state {
  41. case StateStart:
  42. // Only allow opening brace
  43. for i := range logits {
  44. text, err := s.proc.Decode([]int32{int32(i)})
  45. if err == nil && text == "{" {
  46. validTokens[uint32(i)] = true
  47. }
  48. }
  49. case StateInObject, StateInArray:
  50. // Allow any token
  51. for i := range logits {
  52. validTokens[uint32(i)] = true
  53. }
  54. case StateInString:
  55. // Allow any token except closing brace
  56. for i := range logits {
  57. text, err := s.proc.Decode([]int32{int32(i)})
  58. if err == nil && text != "}" {
  59. validTokens[uint32(i)] = true
  60. }
  61. }
  62. case StateDone:
  63. // No tokens allowed
  64. }
  65. // Mark invalid tokens as NaN in one pass
  66. for i := range logits {
  67. if !validTokens[uint32(i)] {
  68. logits[i] = math.NaN()
  69. }
  70. }
  71. return logits, nil
  72. }
  73. func (s *JSONSampler) UpdateState(tokenID int) error {
  74. text, err := s.proc.Decode([]int32{int32(tokenID)})
  75. if err != nil {
  76. return fmt.Errorf("failed to decode token: %w", err)
  77. }
  78. switch s.state {
  79. case StateStart:
  80. if text != "{" {
  81. return fmt.Errorf("expected {, got %s", text)
  82. }
  83. s.state = StateInObject
  84. case StateInObject:
  85. if text == "}" {
  86. s.state = StateDone
  87. }
  88. case StateDone:
  89. return fmt.Errorf("unexpected token after closing bracket: %s", text)
  90. }
  91. return nil
  92. }