structured_outputs.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. package sample
  2. import (
  3. "fmt"
  4. "log/slog"
  5. "runtime"
  6. "time"
  7. "github.com/ollama/ollama/grammar/jsonschema"
  8. "github.com/ollama/ollama/model"
  9. )
  10. type JSONSampler struct {
  11. schema *jsonschema.Schema
  12. propIdx int
  13. propToNodeMap map[string]*PDA
  14. pdaSampler *PushdownSampler
  15. decodedToks []string
  16. }
  17. func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
  18. slog.Info("NewJSONSampler", "schema", schema)
  19. if proc == nil {
  20. return nil, fmt.Errorf("TextProcessor cannot be nil")
  21. }
  22. pdaSampler, err := NewPushdownSampler(proc)
  23. if err != nil {
  24. return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
  25. }
  26. if schema == nil {
  27. return &JSONSampler{
  28. schema: nil,
  29. propIdx: -1,
  30. propToNodeMap: nil,
  31. pdaSampler: pdaSampler,
  32. }, nil
  33. }
  34. // fmt.Println("schema not nil")
  35. so := &JSONSampler{
  36. schema: schema,
  37. propIdx: -1,
  38. propToNodeMap: make(map[string]*PDA),
  39. pdaSampler: pdaSampler,
  40. }
  41. so.schemaToGraph()
  42. // Benchmark token decoding
  43. start := time.Now()
  44. var m runtime.MemStats
  45. runtime.ReadMemStats(&m)
  46. before := m.Alloc
  47. vocab := proc.Vocab()
  48. decodedToks := make([]string, len(vocab.Values))
  49. for i := range vocab.Values {
  50. token, err := proc.Decode([]int32{int32(i)})
  51. if err != nil {
  52. return nil, err
  53. }
  54. decodedToks[i] = token
  55. }
  56. so.decodedToks = decodedToks
  57. runtime.ReadMemStats(&m)
  58. after := m.Alloc
  59. fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  60. fmt.Printf("Token decode time = %v\n", time.Since(start))
  61. fmt.Println("--------------------------------")
  62. fmt.Println("SOSampler")
  63. fmt.Println("--------------------------------")
  64. // Benchmark this section
  65. start = time.Now()
  66. runtime.ReadMemStats(&m)
  67. before = m.Alloc
  68. // TODO: still messed up
  69. // TODO: recursion use case
  70. // key masks
  71. for _, prop := range so.schema.Properties {
  72. node := so.propToNodeMap[prop.Name]
  73. // propName -> node
  74. curState := node.State
  75. fromNode := node
  76. so.pdaSampler.CreateMask(fromNode)
  77. for curState == StateInStructuredKey {
  78. // there is only one edge
  79. for r, toNode := range fromNode.TransitionEdges {
  80. fmt.Println("rune", r, "edge", toNode.State)
  81. so.pdaSampler.CreateMask(toNode)
  82. fmt.Printf("created mask for %c\n", r)
  83. curState = toNode.State
  84. fmt.Println("next state", curState)
  85. // TODO: theres an extra gen for " right now
  86. fromNode = toNode
  87. }
  88. }
  89. if curState != StateInColon {
  90. return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
  91. }
  92. // so.pdaSampler.CreateMask(fromNode)
  93. fromNode = fromNode.TransitionEdges[' ']
  94. so.pdaSampler.CreateMask(fromNode)
  95. curState = fromNode.State
  96. for _, toNode := range fromNode.TransitionEdges {
  97. fmt.Println("toNode", toNode.State)
  98. }
  99. }
  100. // runtime.ReadMemStats(&m)
  101. // after = m.Alloc
  102. // fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
  103. // fmt.Printf("Mask creation time = %v\n", time.Since(start))
  104. // fmt.Println("--------------------------------")
  105. return so, nil
  106. }
  107. func (s *JSONSampler) schemaToGraph() {
  108. schemaType := s.schema.EffectiveType()
  109. switch schemaType {
  110. case "object":
  111. // TODO: see if we need to connect these to the JSON graph
  112. // each prop is a key
  113. for _, prop := range s.schema.Properties {
  114. // name of key
  115. name := prop.Name
  116. keyNode := &PDA{
  117. State: StateInStructuredKey, // this is unchanging, will impact sampling
  118. TransitionEdges: make(map[rune]*PDA),
  119. MaskTokenIDToNode: make(map[int32]*PDA),
  120. }
  121. prevNode := keyNode
  122. for _, r := range name {
  123. runeNode := &PDA{
  124. State: StateInStructuredKey, // this is unchanging, will impact sampling
  125. TransitionEdges: make(map[rune]*PDA),
  126. MaskTokenIDToNode: make(map[int32]*PDA),
  127. }
  128. // fmt.Println("runeNode created", runeNode.State)
  129. // fmt.Printf("runeNode created %c\n", r)
  130. // since alloc on heap connections wil still map
  131. prevNode.TransitionEdges[r] = runeNode
  132. prevNode = runeNode
  133. }
  134. // point to end of object key node after all chars are done
  135. // prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
  136. // link to value node
  137. // Create a node for the end of the key (after the closing quote)
  138. stringEndNode := &PDA{
  139. State: StateInStructuredKey,
  140. TransitionEdges: make(map[rune]*PDA),
  141. MaskTokenIDToNode: make(map[int32]*PDA),
  142. }
  143. prevNode.TransitionEdges['"'] = stringEndNode
  144. prevNode = stringEndNode
  145. // Add transition for colon after key
  146. colonNode := &PDA{
  147. State: StateInColon,
  148. TransitionEdges: make(map[rune]*PDA),
  149. MaskTokenIDToNode: make(map[int32]*PDA),
  150. }
  151. prevNode.TransitionEdges[':'] = colonNode
  152. prevNode = colonNode
  153. // Add transition for space after colon
  154. spaceNode := &PDA{
  155. State: StateInSpaceToValue,
  156. TransitionEdges: make(map[rune]*PDA),
  157. MaskTokenIDToNode: make(map[int32]*PDA),
  158. }
  159. prevNode.TransitionEdges[' '] = spaceNode
  160. prevNode = spaceNode
  161. value := prop.Type
  162. switch value {
  163. case "object":
  164. fmt.Println("object under key: ", name)
  165. case "array":
  166. fmt.Println("array under key: ", name)
  167. case "string":
  168. fmt.Println("string under key: ", name)
  169. prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
  170. case "number":
  171. fmt.Println("number under key: ", name)
  172. for _, r := range validNumberRunes {
  173. prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
  174. }
  175. case "boolean":
  176. fmt.Println("boolean under key: ", name)
  177. prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
  178. prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
  179. prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
  180. }
  181. // points to start of the key
  182. s.propToNodeMap[name] = keyNode
  183. fmt.Println("name", name, "keyNode", keyNode.State)
  184. }
  185. }
  186. // TODO: do values + recursion
  187. }
  188. func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
  189. if s.schema == nil {
  190. return s.pdaSampler.Apply(logits)
  191. }
  192. switch s.pdaSampler.curNode.State {
  193. // TODO: doesnt account for multi rune case
  194. case StateInObjectKey:
  195. if s.propIdx > len(s.schema.Properties)-1 {
  196. return nil, fmt.Errorf("propIdx out of bounds")
  197. }
  198. // fmt.Println("in object key - structured outputs")
  199. // TODO: this tracking should probably be coming from a stack to track nested objects
  200. // simple case
  201. s.propIdx++
  202. fmt.Println("propIdx", s.propIdx)
  203. prop := s.schema.Properties[s.propIdx]
  204. fmt.Println("prop", prop.Name)
  205. s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
  206. fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
  207. logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
  208. if err != nil {
  209. return nil, err
  210. }
  211. return logits, nil
  212. default:
  213. // Will only happen for the last prop - can also be precomputed.
  214. if s.propIdx == len(s.schema.Properties)-1 {
  215. // todo: if i incremenet propidx then i know im in last value as well
  216. switch s.pdaSampler.curNode.State {
  217. case StateInObjectEnd:
  218. fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
  219. s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
  220. s.pdaSampler.curNode = NewPDANode(StateTerminate)
  221. s.propIdx++
  222. // TODO: this needs to be optimized in some way, computing mask on the fly is expensive
  223. case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
  224. fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
  225. delete(s.pdaSampler.curNode.TransitionEdges, ',')
  226. s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
  227. s.pdaSampler.CreateMask(s.pdaSampler.curNode)
  228. s.propIdx++
  229. }
  230. }
  231. return s.pdaSampler.Apply(logits)
  232. }
  233. }
  234. func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
  235. tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
  236. if err != nil {
  237. return nil, err
  238. }
  239. if s.schema == nil {
  240. // Don't need to update state for unconstrained JSON sampling
  241. return tokenSlice, nil
  242. }
  243. switch s.pdaSampler.curNode.State {
  244. case StateInObjectKey:
  245. s.propIdx++
  246. fmt.Println("propIdx", s.propIdx)
  247. prop := s.schema.Properties[s.propIdx]
  248. fmt.Println("prop", prop.Name)
  249. s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
  250. // TODO: this does not work - mike
  251. // str, err := s.pdaSampler.proc.Decode(tokenSlice)
  252. // if err != nil {
  253. // return nil, err
  254. // }
  255. // fmt.Println("str", str)
  256. return tokenSlice, nil
  257. default:
  258. return tokenSlice, nil
  259. }
  260. }