structured_outputs.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package sample
  2. import (
  3. "fmt"
  4. "runtime"
  5. "time"
  6. "github.com/ollama/ollama/model"
  7. )
// SOSampler constrains sampling to a schema by layering per-property key
// handling on top of a PushdownSampler.
type SOSampler struct {
	// schema is the schema whose Properties drive object-key generation.
	schema *Schema
	// propIdx is the index into schema.Properties of the key most recently
	// started; NewSOSampler initializes it to -1 and Sample advances it each
	// time the pushdown sampler is at StateInObjectKey.
	propIdx int
	// propStateMap maps a property name to the head node of the key chain
	// built for it by schemaToGraph.
	propStateMap map[string]*PDANode
	// pdaSampler is the underlying pushdown sampler; Sample and UpdateState
	// delegate to it.
	pdaSampler *PushdownSampler
}
  14. func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
  15. pdaSampler := NewPushdownSampler(proc)
  16. so := &SOSampler{
  17. schema: schema,
  18. propIdx: -1,
  19. propStateMap: make(map[string]*PDANode),
  20. pdaSampler: pdaSampler,
  21. }
  22. so.schemaToGraph()
  23. vocab := proc.GetVocabulary()
  24. decodedToks := make([]string, len(vocab.Values))
  25. for i := range vocab.Values {
  26. token, err := proc.Decode([]int32{int32(i)})
  27. if err != nil {
  28. return nil, err
  29. }
  30. decodedToks[i] = token
  31. }
  32. fmt.Println("--------------------------------")
  33. fmt.Println("SOSampler")
  34. fmt.Println("--------------------------------")
  35. // Benchmark this section
  36. start := time.Now()
  37. var m runtime.MemStats
  38. runtime.ReadMemStats(&m)
  39. before := m.Alloc
  40. // TODO: still messed up
  41. for _, node := range so.propStateMap {
  42. // propName -> node
  43. curState := node.State
  44. fromNode := node
  45. CreateMask(fromNode, proc, decodedToks, vocab)
  46. for curState == StateInStructuredKey {
  47. // there is only one edge
  48. for r, toNode := range fromNode.TransitionEdges {
  49. // fmt.Println("rune", r, "edge", toNode.State)
  50. CreateMask(toNode, proc, decodedToks, vocab)
  51. fmt.Printf("created mask for %c\n", r)
  52. curState = toNode.State
  53. fmt.Println("next state", curState)
  54. // TODO: theres an extra gen for " right now
  55. fromNode = toNode
  56. }
  57. }
  58. }
  59. runtime.ReadMemStats(&m)
  60. after := m.Alloc
  61. fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  62. fmt.Printf("Mask creation time = %v\n", time.Since(start))
  63. fmt.Println("--------------------------------")
  64. return so, nil
  65. }
  66. func (s *SOSampler) schemaToGraph() {
  67. schemaType := s.schema.EffectiveType()
  68. switch schemaType {
  69. case "object":
  70. // TODO: see if we need to connect these to the JSON graph
  71. // prevState := StateInObjectKey
  72. // prevNode := so.stateToNodeMap[prevState]
  73. // each prop is a key
  74. for _, prop := range s.schema.Properties {
  75. // name of key
  76. name := prop.Name
  77. // prevState := StateInObjectKey
  78. keyNode := &PDANode{
  79. State: StateInStructuredKey, // this is unchanging, will impact sampling
  80. TransitionEdges: make(map[rune]*PDANode),
  81. MaskTokenIDToNode: make(map[int32]*PDANode),
  82. }
  83. prevNode := keyNode
  84. for _, r := range name {
  85. runeNode := &PDANode{
  86. State: StateInStructuredKey, // this is unchanging, will impact sampling
  87. TransitionEdges: make(map[rune]*PDANode),
  88. MaskTokenIDToNode: make(map[int32]*PDANode),
  89. }
  90. fmt.Println("runeNode created", runeNode.State)
  91. fmt.Printf("runeNode created %c\n", r)
  92. // since alloc on heap connections wil still map
  93. prevNode.TransitionEdges[r] = runeNode
  94. prevNode = runeNode
  95. }
  96. // point to end of object key node after all chars are done
  97. prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
  98. // points to start of the key
  99. s.propStateMap[name] = keyNode
  100. fmt.Println("name", name, "keyNode", keyNode.State)
  101. }
  102. }
  103. }
  104. func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
  105. switch s.pdaSampler.curNode.State {
  106. // doesnt account for multi rune case
  107. case StateInObjectKey:
  108. // fmt.Println("in object key - structured outputs")
  109. // TODO: this tracking should probably be coming from a stack to track nested objects
  110. // simple case
  111. s.propIdx++
  112. prop := s.schema.Properties[s.propIdx]
  113. // fmt.Println("prop", prop.Name)
  114. s.pdaSampler.curNode = s.propStateMap[prop.Name]
  115. // fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
  116. logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
  117. if err != nil {
  118. return nil, err
  119. }
  120. return logits, nil
  121. default:
  122. return s.pdaSampler.Sample(logits)
  123. }
  124. }
  125. func (s *SOSampler) UpdateState(tokenSlice []int32) error {
  126. return s.pdaSampler.UpdateState(tokenSlice)
  127. }