structured_outputs.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. package sample
  2. import (
  3. "fmt"
  4. "runtime"
  5. "time"
  6. "github.com/ollama/ollama/model"
  7. )
// JSONSampler constrains token sampling to JSON output and, when schema is
// non-nil, to the object keys declared by that schema. It wraps a
// PushdownSampler and redirects that sampler's current node onto a
// precomputed key chain whenever a new object key starts.
type JSONSampler struct {
	// schema is the schema being enforced; nil means unconstrained JSON
	// (Apply and UpdateState then delegate straight to pdaSampler).
	schema *Schema
	// propIdx indexes schema.Properties. It is -1 before the first key,
	// advanced each time a new key starts, and advanced once more after the
	// last property's value has been handled.
	propIdx int
	// propToNodeMap maps each property name to the first node of the key
	// chain built by schemaToGraph; nil when schema is nil.
	propToNodeMap map[string]*PDA
	// pdaSampler is the underlying JSON grammar sampler whose current node
	// this sampler steers.
	pdaSampler *PushdownSampler
	// decodedToks[i] is the decoded text of token id i; filled by
	// NewJSONSampler only when a schema is given.
	decodedToks []string
}
  15. func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) {
  16. if proc == nil {
  17. return nil, fmt.Errorf("TextProcessor cannot be nil")
  18. }
  19. pdaSampler, err := NewPushdownSampler(proc)
  20. if err != nil {
  21. return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
  22. }
  23. if schema == nil {
  24. return &JSONSampler{
  25. schema: nil,
  26. propIdx: -1,
  27. propToNodeMap: nil,
  28. pdaSampler: pdaSampler,
  29. }, nil
  30. }
  31. // fmt.Println("schema not nil")
  32. so := &JSONSampler{
  33. schema: schema,
  34. propIdx: -1,
  35. propToNodeMap: make(map[string]*PDA),
  36. pdaSampler: pdaSampler,
  37. }
  38. so.schemaToGraph()
  39. // Benchmark token decoding
  40. start := time.Now()
  41. var m runtime.MemStats
  42. runtime.ReadMemStats(&m)
  43. before := m.Alloc
  44. vocab := proc.GetVocabulary()
  45. decodedToks := make([]string, len(vocab.Values))
  46. for i := range vocab.Values {
  47. token, err := proc.Decode([]int32{int32(i)})
  48. if err != nil {
  49. return nil, err
  50. }
  51. decodedToks[i] = token
  52. }
  53. so.decodedToks = decodedToks
  54. runtime.ReadMemStats(&m)
  55. after := m.Alloc
  56. fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  57. fmt.Printf("Token decode time = %v\n", time.Since(start))
  58. fmt.Println("--------------------------------")
  59. fmt.Println("SOSampler")
  60. fmt.Println("--------------------------------")
  61. // Benchmark this section
  62. start = time.Now()
  63. runtime.ReadMemStats(&m)
  64. before = m.Alloc
  65. // TODO: still messed up
  66. // TODO: recursion use case
  67. // key masks
  68. for _, prop := range so.schema.Properties {
  69. node := so.propToNodeMap[prop.Name]
  70. // propName -> node
  71. curState := node.State
  72. fromNode := node
  73. so.pdaSampler.CreateMask(fromNode)
  74. for curState == StateInStructuredKey {
  75. // there is only one edge
  76. for r, toNode := range fromNode.TransitionEdges {
  77. fmt.Println("rune", r, "edge", toNode.State)
  78. so.pdaSampler.CreateMask(toNode)
  79. fmt.Printf("created mask for %c\n", r)
  80. curState = toNode.State
  81. fmt.Println("next state", curState)
  82. // TODO: theres an extra gen for " right now
  83. fromNode = toNode
  84. }
  85. }
  86. if curState != StateInColon {
  87. return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
  88. }
  89. // so.pdaSampler.CreateMask(fromNode)
  90. fromNode = fromNode.TransitionEdges[' ']
  91. so.pdaSampler.CreateMask(fromNode)
  92. curState = fromNode.State
  93. for _, toNode := range fromNode.TransitionEdges {
  94. fmt.Println("toNode", toNode.State)
  95. }
  96. }
  97. // runtime.ReadMemStats(&m)
  98. // after = m.Alloc
  99. // fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  100. // fmt.Printf("Mask creation time = %v\n", time.Since(start))
  101. // fmt.Println("--------------------------------")
  102. return so, nil
  103. }
  104. func (s *JSONSampler) schemaToGraph() {
  105. schemaType := s.schema.EffectiveType()
  106. switch schemaType {
  107. case "object":
  108. // TODO: see if we need to connect these to the JSON graph
  109. // each prop is a key
  110. for _, prop := range s.schema.Properties {
  111. // name of key
  112. name := prop.Name
  113. keyNode := &PDA{
  114. State: StateInStructuredKey, // this is unchanging, will impact sampling
  115. TransitionEdges: make(map[rune]*PDA),
  116. MaskTokenIDToNode: make(map[int32]*PDA),
  117. }
  118. prevNode := keyNode
  119. for _, r := range name {
  120. runeNode := &PDA{
  121. State: StateInStructuredKey, // this is unchanging, will impact sampling
  122. TransitionEdges: make(map[rune]*PDA),
  123. MaskTokenIDToNode: make(map[int32]*PDA),
  124. }
  125. // fmt.Println("runeNode created", runeNode.State)
  126. // fmt.Printf("runeNode created %c\n", r)
  127. // since alloc on heap connections wil still map
  128. prevNode.TransitionEdges[r] = runeNode
  129. prevNode = runeNode
  130. }
  131. // point to end of object key node after all chars are done
  132. // prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
  133. // link to value node
  134. // Create a node for the end of the key (after the closing quote)
  135. stringEndNode := &PDA{
  136. State: StateInStructuredKey,
  137. TransitionEdges: make(map[rune]*PDA),
  138. MaskTokenIDToNode: make(map[int32]*PDA),
  139. }
  140. prevNode.TransitionEdges['"'] = stringEndNode
  141. prevNode = stringEndNode
  142. // Add transition for colon after key
  143. colonNode := &PDA{
  144. State: StateInColon,
  145. TransitionEdges: make(map[rune]*PDA),
  146. MaskTokenIDToNode: make(map[int32]*PDA),
  147. }
  148. prevNode.TransitionEdges[':'] = colonNode
  149. prevNode = colonNode
  150. // Add transition for space after colon
  151. spaceNode := &PDA{
  152. State: StateInSpaceToValue,
  153. TransitionEdges: make(map[rune]*PDA),
  154. MaskTokenIDToNode: make(map[int32]*PDA),
  155. }
  156. prevNode.TransitionEdges[' '] = spaceNode
  157. prevNode = spaceNode
  158. value := prop.Type
  159. switch value {
  160. case "object":
  161. fmt.Println("object under key: ", name)
  162. case "array":
  163. fmt.Println("array under key: ", name)
  164. case "string":
  165. fmt.Println("string under key: ", name)
  166. prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
  167. case "number":
  168. fmt.Println("number under key: ", name)
  169. for _, r := range validNumberRunes {
  170. prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
  171. }
  172. case "boolean":
  173. fmt.Println("boolean under key: ", name)
  174. prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
  175. prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
  176. prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
  177. }
  178. // points to start of the key
  179. s.propToNodeMap[name] = keyNode
  180. fmt.Println("name", name, "keyNode", keyNode.State)
  181. }
  182. }
  183. // TODO: do values + recursion
  184. }
  185. func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
  186. if s.schema == nil {
  187. return s.pdaSampler.Apply(logits)
  188. }
  189. switch s.pdaSampler.curNode.State {
  190. // TODO: doesnt account for multi rune case
  191. case StateInObjectKey:
  192. if s.propIdx > len(s.schema.Properties)-1 {
  193. return nil, fmt.Errorf("propIdx out of bounds")
  194. }
  195. // fmt.Println("in object key - structured outputs")
  196. // TODO: this tracking should probably be coming from a stack to track nested objects
  197. // simple case
  198. s.propIdx++
  199. fmt.Println("propIdx", s.propIdx)
  200. prop := s.schema.Properties[s.propIdx]
  201. fmt.Println("prop", prop.Name)
  202. s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
  203. fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
  204. logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
  205. if err != nil {
  206. return nil, err
  207. }
  208. return logits, nil
  209. default:
  210. // Will only happen for the last prop - can also be precomputed.
  211. if s.propIdx == len(s.schema.Properties)-1 {
  212. // todo: if i incremenet propidx then i know im in last value as well
  213. switch s.pdaSampler.curNode.State {
  214. case StateInObjectEnd:
  215. fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
  216. s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
  217. s.pdaSampler.curNode = NewPDANode(StateTerminate)
  218. s.propIdx++
  219. // TODO: this needs to be optimized in some way, computing mask on the fly is expensive
  220. case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
  221. fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
  222. delete(s.pdaSampler.curNode.TransitionEdges, ',')
  223. s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
  224. s.pdaSampler.CreateMask(s.pdaSampler.curNode)
  225. s.propIdx++
  226. }
  227. }
  228. return s.pdaSampler.Apply(logits)
  229. }
  230. }
  231. func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
  232. tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
  233. if err != nil {
  234. return nil, err
  235. }
  236. if s.schema == nil {
  237. // Don't need to update state for unconstrained JSON sampling
  238. return tokenSlice, nil
  239. }
  240. switch s.pdaSampler.curNode.State {
  241. case StateInObjectKey:
  242. s.propIdx++
  243. fmt.Println("propIdx", s.propIdx)
  244. prop := s.schema.Properties[s.propIdx]
  245. fmt.Println("prop", prop.Name)
  246. s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
  247. // TODO: this does not work - mike
  248. // str, err := s.pdaSampler.proc.Decode(tokenSlice)
  249. // if err != nil {
  250. // return nil, err
  251. // }
  252. // fmt.Println("str", str)
  253. return tokenSlice, nil
  254. default:
  255. return tokenSlice, nil
  256. }
  257. }