constrained.go 926 B

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. package sample
import (
	"errors"

	"github.com/ollama/ollama/model"
)
  5. type ConstrainedSampler struct {
  6. schema *Schema
  7. propIdx int
  8. propToNodeMap map[string]*PDA
  9. pdaSampler *PushdownSampler
  10. decodedToks []string
  11. }
  12. func NewConstrainedSampler(proc model.TextProcessor, schema *Schema) (*ConstrainedSampler, error) {
  13. pdaSampler, err := NewPushdownSampler(proc)
  14. if err != nil {
  15. return nil, err
  16. }
  17. // if schema == nil {
  18. return &ConstrainedSampler{
  19. schema: nil,
  20. propIdx: -1,
  21. propToNodeMap: nil,
  22. pdaSampler: pdaSampler,
  23. }, nil
  24. }
  25. func (s *ConstrainedSampler) Apply(logits []float64) ([]float64, error) {
  26. if s.schema == nil {
  27. return s.pdaSampler.Apply(logits)
  28. }
  29. return nil, nil
  30. }
  31. func (s *ConstrainedSampler) UpdateState(tokenSlice []int32) error {
  32. if err := s.pdaSampler.UpdateState(tokenSlice); err != nil {
  33. return err
  34. }
  35. if s.schema == nil {
  36. return nil
  37. }
  38. return nil
  39. }