structured_outputs.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. package sample
  2. import (
  3. "fmt"
  4. "runtime"
  5. "time"
  6. "github.com/ollama/ollama/model"
  7. )
  8. type JSONSampler struct {
  9. schema *Schema
  10. propIdx int
  11. propToNodeMap map[string]*PDA
  12. pdaSampler *PushdownSampler
  13. decodedToks []string
  14. }
  15. func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) {
  16. pdaSampler, err := NewPushdownSampler(proc)
  17. if err != nil {
  18. return nil, err
  19. }
  20. if schema == nil {
  21. return &JSONSampler{
  22. schema: nil,
  23. propIdx: -1,
  24. propToNodeMap: nil,
  25. pdaSampler: pdaSampler,
  26. }, nil
  27. }
  28. fmt.Println("schema not nil")
  29. so := &JSONSampler{
  30. schema: schema,
  31. propIdx: -1,
  32. propToNodeMap: make(map[string]*PDA),
  33. pdaSampler: pdaSampler,
  34. }
  35. so.schemaToGraph()
  36. // Benchmark token decoding
  37. start := time.Now()
  38. var m runtime.MemStats
  39. runtime.ReadMemStats(&m)
  40. before := m.Alloc
  41. vocab := proc.GetVocabulary()
  42. decodedToks := make([]string, len(vocab.Values))
  43. for i := range vocab.Values {
  44. token, err := proc.Decode([]int32{int32(i)})
  45. if err != nil {
  46. return nil, err
  47. }
  48. decodedToks[i] = token
  49. }
  50. so.decodedToks = decodedToks
  51. runtime.ReadMemStats(&m)
  52. after := m.Alloc
  53. fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  54. fmt.Printf("Token decode time = %v\n", time.Since(start))
  55. fmt.Println("--------------------------------")
  56. fmt.Println("SOSampler")
  57. fmt.Println("--------------------------------")
  58. // Benchmark this section
  59. start = time.Now()
  60. runtime.ReadMemStats(&m)
  61. before = m.Alloc
  62. // TODO: still messed up
  63. // TODO: recursion use case
  64. // key masks
  65. for _, prop := range so.schema.Properties {
  66. node := so.propToNodeMap[prop.Name]
  67. // propName -> node
  68. curState := node.State
  69. fromNode := node
  70. so.pdaSampler.CreateMask(fromNode)
  71. for curState == StateInStructuredKey {
  72. // there is only one edge
  73. for r, toNode := range fromNode.TransitionEdges {
  74. // fmt.Println("rune", r, "edge", toNode.State)
  75. so.pdaSampler.CreateMask(toNode)
  76. fmt.Printf("created mask for %c\n", r)
  77. curState = toNode.State
  78. fmt.Println("next state", curState)
  79. // TODO: theres an extra gen for " right now
  80. fromNode = toNode
  81. }
  82. }
  83. }
  84. runtime.ReadMemStats(&m)
  85. after = m.Alloc
  86. fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  87. fmt.Printf("Mask creation time = %v\n", time.Since(start))
  88. fmt.Println("--------------------------------")
  89. return so, nil
  90. }
  91. func (s *JSONSampler) schemaToGraph() {
  92. schemaType := s.schema.EffectiveType()
  93. switch schemaType {
  94. case "object":
  95. // TODO: see if we need to connect these to the JSON graph
  96. // each prop is a key
  97. for _, prop := range s.schema.Properties {
  98. // name of key
  99. name := prop.Name
  100. keyNode := &PDA{
  101. State: StateInStructuredKey, // this is unchanging, will impact sampling
  102. TransitionEdges: make(map[rune]*PDA),
  103. MaskTokenIDToNode: make(map[int32]*PDA),
  104. }
  105. prevNode := keyNode
  106. for _, r := range name {
  107. runeNode := &PDA{
  108. State: StateInStructuredKey, // this is unchanging, will impact sampling
  109. TransitionEdges: make(map[rune]*PDA),
  110. MaskTokenIDToNode: make(map[int32]*PDA),
  111. }
  112. fmt.Println("runeNode created", runeNode.State)
  113. fmt.Printf("runeNode created %c\n", r)
  114. // since alloc on heap connections wil still map
  115. prevNode.TransitionEdges[r] = runeNode
  116. prevNode = runeNode
  117. }
  118. // point to end of object key node after all chars are done
  119. prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
  120. // points to start of the key
  121. s.propToNodeMap[name] = keyNode
  122. fmt.Println("name", name, "keyNode", keyNode.State)
  123. }
  124. }
  125. // TODO: do values + recursion
  126. }
  127. func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
  128. if s.schema == nil {
  129. return s.pdaSampler.Apply(logits)
  130. }
  131. switch s.pdaSampler.curNode.State {
  132. // doesnt account for multi rune case
  133. case StateInObjectKey:
  134. if s.propIdx > len(s.schema.Properties)-1 {
  135. return nil, fmt.Errorf("propIdx out of bounds")
  136. }
  137. // fmt.Println("in object key - structured outputs")
  138. // TODO: this tracking should probably be coming from a stack to track nested objects
  139. // simple case
  140. s.propIdx++
  141. fmt.Println("propIdx", s.propIdx)
  142. prop := s.schema.Properties[s.propIdx]
  143. fmt.Println("prop", prop.Name)
  144. s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
  145. fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
  146. logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
  147. if err != nil {
  148. return nil, err
  149. }
  150. return logits, nil
  151. default:
  152. // Will only happen for the last prop - can also be precomputed.
  153. if s.propIdx == len(s.schema.Properties)-1 {
  154. // todo: if i incremenet propidx then i know im in last value as well
  155. switch s.pdaSampler.curNode.State {
  156. case StateInObjectEnd:
  157. fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
  158. s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
  159. s.pdaSampler.curNode = NewPDANode(StateTerminate)
  160. s.propIdx++
  161. // TODO: this needs to be optimized in some way, computing mask on the fly is expensive
  162. case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
  163. fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
  164. delete(s.pdaSampler.curNode.TransitionEdges, ',')
  165. s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
  166. s.pdaSampler.CreateMask(s.pdaSampler.curNode)
  167. s.propIdx++
  168. }
  169. }
  170. return s.pdaSampler.Apply(logits)
  171. }
  172. }
  173. func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
  174. err := s.pdaSampler.UpdateState(tokenSlice)
  175. if err != nil {
  176. return err
  177. }
  178. if s.schema == nil {
  179. // Don't need to update state for unconstrained JSON sampling
  180. return nil
  181. }
  182. switch s.pdaSampler.curNode.State {
  183. case StateInObjectKey:
  184. s.propIdx++
  185. fmt.Println("propIdx", s.propIdx)
  186. prop := s.schema.Properties[s.propIdx]
  187. fmt.Println("prop", prop.Name)
  188. s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
  189. str, err := s.pdaSampler.proc.Decode(tokenSlice)
  190. if err != nil {
  191. return err
  192. }
  193. fmt.Println("str", str)
  194. return nil
  195. default:
  196. return nil
  197. }
  198. }