structured_python.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. package sample
  2. import (
  3. "fmt"
  4. "math"
  5. "slices"
  6. "github.com/ollama/ollama/model"
  7. )
  // PythonState classifies a Node in the rune-level decoding graph used by
// PythonSampler.
//
// NOTE(review): the constants use three different prefixes (PythonState*,
// State*, PState*). They are exported, so the names are left unchanged.
type PythonState int

// Values are assigned sequentially with iota; PythonStates lists them in the
// same order.
const (
	PythonStateStart PythonState = iota
	StateInFunction
	StateInFunctionArgs
	StateInFunctionArgsType
	StateInFunctionEnd
	PStateInString
	PStateInStringEnd
	PStateInNumber
	PStateInList
	PStateInListEnd
	PStateInDict
	PStateInDictEnd
	PStateInTuple
	PStateInTupleEnd
	PStateTerminate
)

// String returns the identifier of a declared state, or "PythonState(n)" for
// any value outside the declared range (including negative values).
func (s PythonState) String() string {
	names := [...]string{
		PythonStateStart:        "PythonStateStart",
		StateInFunction:         "StateInFunction",
		StateInFunctionArgs:     "StateInFunctionArgs",
		StateInFunctionArgsType: "StateInFunctionArgsType",
		StateInFunctionEnd:      "StateInFunctionEnd",
		PStateInString:          "PStateInString",
		PStateInStringEnd:       "PStateInStringEnd",
		PStateInNumber:          "PStateInNumber",
		PStateInList:            "PStateInList",
		PStateInListEnd:         "PStateInListEnd",
		PStateInDict:            "PStateInDict",
		PStateInDictEnd:         "PStateInDictEnd",
		PStateInTuple:           "PStateInTuple",
		PStateInTupleEnd:        "PStateInTupleEnd",
		PStateTerminate:         "PStateTerminate",
	}
	if s < 0 || int(s) >= len(names) {
		return fmt.Sprintf("PythonState(%d)", s)
	}
	return names[s]
}

// PythonStates lists every declared PythonState in declaration order;
// BuildGraph creates one shared node for each entry.
var PythonStates = []PythonState{
	PythonStateStart,
	StateInFunction,
	StateInFunctionArgs,
	StateInFunctionArgsType,
	StateInFunctionEnd,
	PStateInString,
	PStateInStringEnd,
	PStateInNumber,
	PStateInList,
	PStateInListEnd,
	PStateInDict,
	PStateInDictEnd,
	PStateInTuple,
	PStateInTupleEnd,
	PStateTerminate,
}
  79. type Node struct {
  80. State PythonState
  81. TransitionEdges map[rune]*Node
  82. MaskTokenIDToNode map[int32]*Node
  83. }
  84. func NewNode(state PythonState) *Node {
  85. return &Node{
  86. State: state,
  87. TransitionEdges: make(map[rune]*Node),
  88. MaskTokenIDToNode: make(map[int32]*Node),
  89. }
  90. }
  // PythonFunction describes one function whose calls the sampler may emit.
// Args and Types are parallel slices: Types[i] is the value type of Args[i].
// Init wires values for the types "string" and "number".
type PythonFunction struct {
	Name  string   // function name, inserted into the graph rune by rune
	Args  []string // argument names
	Types []string // value type of each argument, indexed like Args
}
  96. type PythonSampler struct {
  97. stateToNodes map[PythonState]*Node
  98. proc model.TextProcessor
  99. decodedToks []string
  100. curNode *Node
  101. }
  102. func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
  103. s.proc = proc
  104. decodedToks := make([]string, len(proc.Vocab().Values))
  105. for i := range proc.Vocab().Values {
  106. token, err := proc.Decode([]int32{int32(i)})
  107. if err != nil {
  108. return err
  109. }
  110. decodedToks[i] = token
  111. }
  112. s.decodedToks = decodedToks
  113. s.BuildGraph()
  114. for _, function := range functions {
  115. prevNode := s.stateToNodes[PythonStateStart]
  116. for _, r := range function.Name {
  117. nextNode := NewNode(StateInFunction)
  118. prevNode.TransitionEdges[r] = nextNode
  119. if err := s.CreateMask(nextNode); err != nil {
  120. return err
  121. }
  122. fmt.Println("prevNode", prevNode.State)
  123. fmt.Printf("transition edge: %q\n", r)
  124. fmt.Println("nextNode", nextNode.State)
  125. prevNode = nextNode
  126. }
  127. prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs]
  128. s.CreateMask(prevNode)
  129. prevNode = s.stateToNodes[StateInFunctionArgs]
  130. for i, arg := range function.Args {
  131. for _, r := range arg {
  132. nextNode := NewNode(StateInFunctionArgs)
  133. prevNode.TransitionEdges[r] = nextNode
  134. s.CreateMask(prevNode)
  135. prevNode = nextNode
  136. }
  137. prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
  138. // prevNode = s.stateToNodes[StateInFunctionArgs]
  139. prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType)
  140. s.CreateMask(prevNode)
  141. prevNode = prevNode.TransitionEdges['=']
  142. switch function.Types[i] {
  143. case "string":
  144. prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString]
  145. s.CreateMask(prevNode.TransitionEdges['"'])
  146. case "number":
  147. prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber]
  148. s.CreateMask(prevNode.TransitionEdges['"'])
  149. }
  150. }
  151. }
  152. s.curNode = s.stateToNodes[PythonStateStart]
  153. fmt.Println("curNode", s.curNode.State)
  154. fmt.Println("transition edges", s.curNode.TransitionEdges)
  155. if err := s.CreateMask(s.curNode); err != nil {
  156. return err
  157. }
  158. fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode)
  159. for tokenID, node := range s.curNode.MaskTokenIDToNode {
  160. fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State)
  161. }
  162. return nil
  163. }
  164. func (s *PythonSampler) BuildGraph() error {
  165. s.stateToNodes = make(map[PythonState]*Node)
  166. for _, state := range PythonStates {
  167. s.stateToNodes[state] = NewNode(state)
  168. }
  169. for _, state := range s.stateToNodes {
  170. if err := s.CreateMask(state); err != nil {
  171. return err
  172. }
  173. }
  174. // String
  175. s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString]
  176. s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd]
  177. // String end
  178. s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
  179. s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
  180. // Number
  181. for _, r := range validNumberRunes {
  182. s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
  183. }
  184. s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
  185. s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
  186. s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs]
  187. return nil
  188. }
  189. func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) {
  190. if s.curNode.State == PStateTerminate {
  191. logits, err := finish(s, logits)
  192. if err != nil {
  193. return nil, err
  194. }
  195. return logits, nil
  196. }
  197. logits, err := s.maskLogits(logits, s.curNode)
  198. if err != nil {
  199. return nil, err
  200. }
  201. return logits, nil
  202. }
  203. func (s *PythonSampler) UpdateState(token int32) error {
  204. mappedString, err := s.proc.Decode([]int32{token})
  205. if err != nil {
  206. return err
  207. }
  208. fmt.Printf(">>> mappedString: %q\n", mappedString)
  209. if s.curNode.State == PStateTerminate {
  210. if s.proc.Is(token, model.SpecialEOS) {
  211. return nil
  212. }
  213. }
  214. nextNode, ok := s.curNode.MaskTokenIDToNode[token]
  215. if !ok {
  216. return fmt.Errorf("invalid token: %q", mappedString)
  217. }
  218. s.curNode = nextNode
  219. fmt.Println("curNode", s.curNode.State)
  220. for r, node := range s.curNode.TransitionEdges {
  221. fmt.Printf("transition edge: %q -> %v\n", r, node.State)
  222. }
  223. if err := s.CreateMask(s.curNode); err != nil {
  224. return err
  225. }
  226. return nil
  227. }
  228. func (s *PythonSampler) CreateMask(node *Node) error {
  229. if node == nil {
  230. return fmt.Errorf("node cannot be nil")
  231. }
  232. for i := range s.decodedToks {
  233. token := s.decodedToks[i]
  234. // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
  235. if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
  236. continue
  237. }
  238. curNode := node
  239. valid := true
  240. consumedSpecialRunes := make(map[rune]bool)
  241. for _, r := range token {
  242. curNode, valid = isRValid(r, curNode, consumedSpecialRunes)
  243. if curNode == nil || !valid {
  244. break
  245. }
  246. }
  247. if valid {
  248. if curNode.State == StateInFunction {
  249. // fmt.Println("cm curNode", curNode.State)
  250. // fmt.Println("cm token", s.decodedToks[i])
  251. }
  252. node.MaskTokenIDToNode[int32(i)] = curNode
  253. }
  254. }
  255. return nil
  256. }
  257. func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) {
  258. if consumedSpecialRunes[r] {
  259. return nil, false
  260. }
  261. specialRune := slices.Contains(stringInvalidRunes, r)
  262. if specialRune {
  263. if curNode.State == PStateInString || curNode.State == PStateInStringEnd {
  264. return nil, false
  265. }
  266. }
  267. // Check for specific rune transition
  268. if nextNode, ok := curNode.TransitionEdges[r]; ok {
  269. // fmt.Println("next node", nextNode)
  270. if specialRune {
  271. if curNode.State == nextNode.State {
  272. return nil, false
  273. }
  274. consumedSpecialRunes[r] = true
  275. }
  276. return nextNode, true
  277. }
  278. // Check for sentinel value - if present, any rune is valid
  279. if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
  280. return nextNode, true
  281. }
  282. return nil, false
  283. }
  284. func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) {
  285. // Create a new slice with same length as logits, initialized to -Inf
  286. maskedLogits := make([]float32, len(logits))
  287. for i := range maskedLogits {
  288. maskedLogits[i] = float32(math.Inf(-1))
  289. }
  290. // Only update values for valid token IDs from the mask map
  291. for tokenID := range node.MaskTokenIDToNode {
  292. if int(tokenID) < len(logits) {
  293. maskedLogits[tokenID] = logits[tokenID]
  294. }
  295. }
  296. return maskedLogits, nil
  297. }
  298. func finish(s *PythonSampler, logits []float32) ([]float32, error) {
  299. for i := range logits {
  300. if s.proc.Is(int32(i), model.SpecialEOS) {
  301. logits[i] = 1.0
  302. } else {
  303. logits[i] = float32(math.Inf(-1))
  304. }
  305. }
  306. return logits, nil
  307. }