structured_python.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "slices"
  6. "github.com/ollama/ollama/model"
  7. )
  8. type PythonState int
  9. const (
  10. PythonStateStart PythonState = iota
  11. StateInFunction
  12. StateInFunctionArgs
  13. StateInFunctionArgsType
  14. StateInFunctionEnd
  15. PStateInString
  16. PStateInStringEnd
  17. PStateInNumber
  18. PStateInList
  19. PStateInListEnd
  20. PStateInDict
  21. PStateInDictEnd
  22. PStateInTuple
  23. PStateInTupleEnd
  24. PStateTerminate
  25. )
  26. func (s PythonState) String() string {
  27. switch s {
  28. case PythonStateStart:
  29. return "PythonStateStart"
  30. case StateInFunction:
  31. return "StateInFunction"
  32. case StateInFunctionArgs:
  33. return "StateInFunctionArgs"
  34. case StateInFunctionArgsType:
  35. return "StateInFunctionArgsType"
  36. case StateInFunctionEnd:
  37. return "StateInFunctionEnd"
  38. case PStateInString:
  39. return "PStateInString"
  40. case PStateInStringEnd:
  41. return "PStateInStringEnd"
  42. case PStateInNumber:
  43. return "PStateInNumber"
  44. case PStateInList:
  45. return "PStateInList"
  46. case PStateInListEnd:
  47. return "PStateInListEnd"
  48. case PStateInDict:
  49. return "PStateInDict"
  50. case PStateInDictEnd:
  51. return "PStateInDictEnd"
  52. case PStateInTuple:
  53. return "PStateInTuple"
  54. case PStateInTupleEnd:
  55. return "PStateInTupleEnd"
  56. case PStateTerminate:
  57. return "PStateTerminate"
  58. default:
  59. return fmt.Sprintf("PythonState(%d)", s)
  60. }
  61. }
  62. var PythonStates = []PythonState{
  63. PythonStateStart,
  64. StateInFunction,
  65. StateInFunctionArgs,
  66. StateInFunctionArgsType,
  67. StateInFunctionEnd,
  68. PStateInString,
  69. PStateInStringEnd,
  70. PStateInNumber,
  71. PStateInList,
  72. PStateInListEnd,
  73. PStateInDict,
  74. PStateInDictEnd,
  75. PStateInTuple,
  76. PStateInTupleEnd,
  77. PStateTerminate,
  78. }
  79. type Node struct {
  80. State PythonState
  81. TransitionEdges map[rune]*Node
  82. MaskTokenIDToNode map[int32]*Node
  83. }
  84. func NewNode(state PythonState) *Node {
  85. return &Node{
  86. State: state,
  87. TransitionEdges: make(map[rune]*Node),
  88. MaskTokenIDToNode: make(map[int32]*Node),
  89. }
  90. }
  91. type PythonFunction struct {
  92. Name string
  93. Args []string
  94. Types []string
  95. }
  96. type PythonSampler struct {
  97. stateToNodes map[PythonState]*Node
  98. proc model.TextProcessor
  99. decodedToks []string
  100. curNode *Node
  101. completed int
  102. functions []PythonFunction
  103. }
  104. func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
  105. s.proc = proc
  106. s.functions = functions
  107. decodedToks := make([]string, len(proc.Vocab().Values))
  108. for i := range proc.Vocab().Values {
  109. token, err := proc.Decode([]int32{int32(i)})
  110. if err != nil {
  111. return err
  112. }
  113. decodedToks[i] = token
  114. }
  115. s.decodedToks = decodedToks
  116. s.BuildGraph()
  117. for _, function := range functions {
  118. prevNode := s.stateToNodes[PythonStateStart]
  119. for _, r := range function.Name {
  120. nextNode := NewNode(StateInFunction)
  121. prevNode.TransitionEdges[r] = nextNode
  122. if err := s.CreateMask(nextNode); err != nil {
  123. return err
  124. }
  125. fmt.Println("prevNode", prevNode.State)
  126. fmt.Printf("transition edge: %q\n", r)
  127. fmt.Println("nextNode", nextNode.State)
  128. prevNode = nextNode
  129. }
  130. prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs]
  131. s.CreateMask(prevNode)
  132. prevNode = s.stateToNodes[StateInFunctionArgs]
  133. for i, arg := range function.Args {
  134. for _, r := range arg {
  135. nextNode := NewNode(StateInFunctionArgs)
  136. prevNode.TransitionEdges[r] = nextNode
  137. s.CreateMask(prevNode)
  138. prevNode = nextNode
  139. }
  140. prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
  141. // prevNode = s.stateToNodes[StateInFunctionArgs]
  142. prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType)
  143. s.CreateMask(prevNode)
  144. prevNode = prevNode.TransitionEdges['=']
  145. switch function.Types[i] {
  146. case "string":
  147. prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString]
  148. s.CreateMask(prevNode.TransitionEdges['"'])
  149. case "number":
  150. prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber]
  151. s.CreateMask(prevNode.TransitionEdges['"'])
  152. }
  153. }
  154. }
  155. s.curNode = s.stateToNodes[PythonStateStart]
  156. fmt.Println("curNode", s.curNode.State)
  157. fmt.Println("transition edges", s.curNode.TransitionEdges)
  158. if err := s.CreateMask(s.curNode); err != nil {
  159. return err
  160. }
  161. fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode)
  162. for tokenID, node := range s.curNode.MaskTokenIDToNode {
  163. fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State)
  164. }
  165. return nil
  166. }
  167. func (s *PythonSampler) BuildGraph() error {
  168. s.stateToNodes = make(map[PythonState]*Node)
  169. for _, state := range PythonStates {
  170. s.stateToNodes[state] = NewNode(state)
  171. }
  172. for _, state := range s.stateToNodes {
  173. if err := s.CreateMask(state); err != nil {
  174. return err
  175. }
  176. }
  177. // String
  178. s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString]
  179. s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd]
  180. // String end
  181. s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
  182. // s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
  183. // Number
  184. for _, r := range validNumberRunes {
  185. s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
  186. }
  187. s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
  188. s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
  189. s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs]
  190. return nil
  191. }
  192. func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) {
  193. if s.curNode.State == PStateTerminate {
  194. logits, err := finish(s, logits)
  195. if err != nil {
  196. return nil, err
  197. }
  198. return logits, nil
  199. }
  200. logits, err := s.maskLogits(logits, s.curNode)
  201. if err != nil {
  202. return nil, err
  203. }
  204. return logits, nil
  205. }
  206. func (s *PythonSampler) UpdateState(token int32) error {
  207. mappedString, err := s.proc.Decode([]int32{token})
  208. if err != nil {
  209. return err
  210. }
  211. fmt.Printf(">>> mappedString: %q\n", mappedString)
  212. if s.curNode.State == PStateTerminate {
  213. if s.proc.Is(token, model.SpecialEOS) {
  214. return nil
  215. }
  216. }
  217. nextNode, ok := s.curNode.MaskTokenIDToNode[token]
  218. if !ok {
  219. return fmt.Errorf("invalid token: %q", mappedString)
  220. }
  221. if mappedString == "\"" {
  222. if s.curNode.State == PStateInStringEnd {
  223. s.completed++
  224. }
  225. if s.completed == len(s.functions) {
  226. s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
  227. s.CreateMask(s.curNode)
  228. }
  229. }
  230. s.curNode = nextNode
  231. fmt.Println("curNode", s.curNode.State)
  232. for r, node := range s.curNode.TransitionEdges {
  233. fmt.Printf("transition edge: %q -> %v\n", r, node.State)
  234. }
  235. if err := s.CreateMask(s.curNode); err != nil {
  236. return err
  237. }
  238. return nil
  239. }
  240. func (s *PythonSampler) CreateMask(node *Node) error {
  241. if node == nil {
  242. return fmt.Errorf("node cannot be nil")
  243. }
  244. for i := range s.decodedToks {
  245. token := s.decodedToks[i]
  246. // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
  247. if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
  248. continue
  249. }
  250. curNode := node
  251. valid := true
  252. consumedSpecialRunes := make(map[rune]bool)
  253. for _, r := range token {
  254. curNode, valid = isRValid(r, curNode, consumedSpecialRunes)
  255. if curNode == nil || !valid {
  256. break
  257. }
  258. }
  259. if valid {
  260. if curNode.State == StateInFunction {
  261. // fmt.Println("cm curNode", curNode.State)
  262. // fmt.Println("cm token", s.decodedToks[i])
  263. }
  264. node.MaskTokenIDToNode[int32(i)] = curNode
  265. }
  266. }
  267. return nil
  268. }
  269. func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) {
  270. if consumedSpecialRunes[r] {
  271. return nil, false
  272. }
  273. specialRune := slices.Contains(stringInvalidRunes, r)
  274. if specialRune {
  275. if curNode.State == PStateInString || curNode.State == PStateInStringEnd {
  276. return nil, false
  277. }
  278. }
  279. // Check for specific rune transition
  280. if nextNode, ok := curNode.TransitionEdges[r]; ok {
  281. // fmt.Println("next node", nextNode)
  282. if specialRune {
  283. if curNode.State == nextNode.State {
  284. return nil, false
  285. }
  286. consumedSpecialRunes[r] = true
  287. }
  288. return nextNode, true
  289. }
  290. // Check for sentinel value - if present, any rune is valid
  291. if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
  292. return nextNode, true
  293. }
  294. return nil, false
  295. }
  296. func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) {
  297. // Create a new slice with same length as logits, initialized to -Inf
  298. maskedLogits := make([]float32, len(logits))
  299. for i := range maskedLogits {
  300. maskedLogits[i] = float32(math.Inf(-1))
  301. }
  302. // Only update values for valid token IDs from the mask map
  303. for tokenID := range node.MaskTokenIDToNode {
  304. if int(tokenID) < len(logits) {
  305. maskedLogits[tokenID] = logits[tokenID]
  306. }
  307. }
  308. return maskedLogits, nil
  309. }
  310. func finish(s *PythonSampler, logits []float32) ([]float32, error) {
  311. for i := range logits {
  312. if s.proc.Is(int32(i), model.SpecialEOS) {
  313. logits[i] = 1.0
  314. } else {
  315. logits[i] = float32(math.Inf(-1))
  316. }
  317. }
  318. return logits, nil
  319. }