structured_outputs.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package sample
  2. import (
  3. "fmt"
  4. "runtime"
  5. "time"
  6. "github.com/ollama/ollama/model"
  7. )
  8. type SOSampler struct {
  9. schema *Schema
  10. propIdx int
  11. propToNodeMap map[string]*PDANode
  12. pdaSampler *PushdownSampler
  13. decodedToks []string
  14. }
  15. func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
  16. pdaSampler := NewPushdownSampler(proc)
  17. so := &SOSampler{
  18. schema: schema,
  19. propIdx: -1,
  20. propToNodeMap: make(map[string]*PDANode),
  21. pdaSampler: pdaSampler,
  22. }
  23. so.schemaToGraph()
  24. // This is prob slow
  25. vocab := proc.GetVocabulary()
  26. decodedToks := make([]string, len(vocab.Values))
  27. for i := range vocab.Values {
  28. token, err := proc.Decode([]int32{int32(i)})
  29. if err != nil {
  30. return nil, err
  31. }
  32. decodedToks[i] = token
  33. }
  34. so.decodedToks = decodedToks
  35. fmt.Println("--------------------------------")
  36. fmt.Println("SOSampler")
  37. fmt.Println("--------------------------------")
  38. // Benchmark this section
  39. start := time.Now()
  40. var m runtime.MemStats
  41. runtime.ReadMemStats(&m)
  42. before := m.Alloc
  43. // TODO: still messed up
  44. // TODO: recursion use case
  45. // key masks
  46. for _, prop := range so.schema.Properties {
  47. node := so.propToNodeMap[prop.Name]
  48. // propName -> node
  49. curState := node.State
  50. fromNode := node
  51. CreateMask(fromNode, proc, decodedToks)
  52. for curState == StateInStructuredKey {
  53. // there is only one edge
  54. for r, toNode := range fromNode.TransitionEdges {
  55. // fmt.Println("rune", r, "edge", toNode.State)
  56. CreateMask(toNode, proc, decodedToks)
  57. fmt.Printf("created mask for %c\n", r)
  58. curState = toNode.State
  59. fmt.Println("next state", curState)
  60. // TODO: theres an extra gen for " right now
  61. fromNode = toNode
  62. }
  63. }
  64. }
  65. runtime.ReadMemStats(&m)
  66. after := m.Alloc
  67. fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  68. fmt.Printf("Mask creation time = %v\n", time.Since(start))
  69. fmt.Println("--------------------------------")
  70. return so, nil
  71. }
  72. func (s *SOSampler) schemaToGraph() {
  73. schemaType := s.schema.EffectiveType()
  74. switch schemaType {
  75. case "object":
  76. // TODO: see if we need to connect these to the JSON graph
  77. // each prop is a key
  78. for _, prop := range s.schema.Properties {
  79. // name of key
  80. name := prop.Name
  81. keyNode := &PDANode{
  82. State: StateInStructuredKey, // this is unchanging, will impact sampling
  83. TransitionEdges: make(map[rune]*PDANode),
  84. MaskTokenIDToNode: make(map[int32]*PDANode),
  85. }
  86. prevNode := keyNode
  87. for _, r := range name {
  88. runeNode := &PDANode{
  89. State: StateInStructuredKey, // this is unchanging, will impact sampling
  90. TransitionEdges: make(map[rune]*PDANode),
  91. MaskTokenIDToNode: make(map[int32]*PDANode),
  92. }
  93. fmt.Println("runeNode created", runeNode.State)
  94. fmt.Printf("runeNode created %c\n", r)
  95. // since alloc on heap connections wil still map
  96. prevNode.TransitionEdges[r] = runeNode
  97. prevNode = runeNode
  98. }
  99. // point to end of object key node after all chars are done
  100. prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
  101. // points to start of the key
  102. s.propToNodeMap[name] = keyNode
  103. fmt.Println("name", name, "keyNode", keyNode.State)
  104. }
  105. }
  106. }
  107. func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
  108. switch s.pdaSampler.curNode.State {
  109. // doesnt account for multi rune case
  110. case StateInObjectKey:
  111. if s.propIdx > len(s.schema.Properties)-1 {
  112. return nil, fmt.Errorf("propIdx out of bounds")
  113. }
  114. // fmt.Println("in object key - structured outputs")
  115. // TODO: this tracking should probably be coming from a stack to track nested objects
  116. // simple case
  117. s.propIdx++
  118. fmt.Println("propIdx", s.propIdx)
  119. prop := s.schema.Properties[s.propIdx]
  120. fmt.Println("prop", prop.Name)
  121. s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
  122. fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
  123. logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
  124. if err != nil {
  125. return nil, err
  126. }
  127. return logits, nil
  128. default:
  129. // Will only happen for the last prop - can also be precomputed.
  130. if s.propIdx == len(s.schema.Properties)-1 {
  131. // todo: if i incremenet propidx then i know im in last value as well
  132. switch s.pdaSampler.curNode.State {
  133. case StateInObjectEnd:
  134. fmt.Println("<<<<< in obj end- generating mask for", s.pdaSampler.curNode.State)
  135. s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDANode)
  136. s.pdaSampler.curNode = NewPDANode(StateTerminate)
  137. s.propIdx++
  138. case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
  139. fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
  140. delete(s.pdaSampler.curNode.TransitionEdges, ',')
  141. s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDANode)
  142. CreateMask(s.pdaSampler.curNode, s.pdaSampler.proc, s.decodedToks)
  143. s.propIdx++
  144. }
  145. }
  146. return s.pdaSampler.Apply(logits)
  147. }
  148. }
  149. func (s *SOSampler) UpdateState(tokenSlice []int32) error {
  150. err := s.pdaSampler.UpdateState(tokenSlice)
  151. if err != nil {
  152. return err
  153. }
  154. switch s.pdaSampler.curNode.State {
  155. case StateInObjectKey:
  156. s.propIdx++
  157. fmt.Println("propIdx", s.propIdx)
  158. prop := s.schema.Properties[s.propIdx]
  159. fmt.Println("prop", prop.Name)
  160. s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
  161. str, err := s.pdaSampler.proc.Decode(tokenSlice)
  162. if err != nil {
  163. return err
  164. }
  165. fmt.Println("str", str)
  166. return nil
  167. default:
  168. return nil
  169. }
  170. }