state_machine.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. package sample
  2. import (
  3. "fmt"
  4. "github.com/ollama/ollama/model"
  5. )
  6. type token []int32
  7. type Node struct {
  8. State JSONState
  9. TransitionEdges map[*Node][]token
  10. }
  11. func NewNode(state JSONState) *Node {
  12. return &Node{
  13. State: state,
  14. TransitionEdges: make(map[*Node][]token),
  15. }
  16. }
  17. var (
  18. // startToken token
  19. startTokenVariants []token
  20. // endToken token
  21. // stringToken token
  22. // objectKeyToken token
  23. tabToken token
  24. spaceToken token
  25. newlineToken token
  26. newlineSpace token
  27. // commaToken token
  28. // commaToken2 token
  29. // commaToken3 token
  30. // colonToken token
  31. // colonToken2 token
  32. colonTokenVariants []token
  33. commaTokenVariants []token
  34. stringTokenVariants []token
  35. endTokenVariants []token
  36. objectKeyTokenVariants []token
  37. objKeyToColonVariants []token
  38. stringToObjectKeyVariants []token
  39. stringToCommaVariants []token
  40. stringToObjectVariants []token
  41. stringEndToObjectEndVariants []token
  42. stringEndToCommaVariants []token
  43. )
  44. func ComputeTokenVariants(variants []string, proc model.TextProcessor) ([]token, error) {
  45. var allTokens token
  46. for _, variant := range variants {
  47. if t, err := proc.Encode(variant); err == nil {
  48. allTokens = append(allTokens, t...)
  49. }
  50. }
  51. if len(allTokens) == 0 {
  52. return nil, fmt.Errorf("no valid tokens found for variants")
  53. }
  54. return []token{allTokens}, nil
  55. }
  56. func initTokens(proc model.TextProcessor) error {
  57. var err error
  58. s, err := proc.Decode([]int32{761})
  59. fmt.Printf("761 decoded %q\n", s)
  60. // Compute start token variants
  61. startVariants := []string{"{", " {", "{\n", " {\n"}
  62. startTokenVariants, err = ComputeTokenVariants(startVariants, proc)
  63. if err != nil {
  64. return err
  65. }
  66. // Compute end token variants
  67. endVariants := []string{"}", " }", "}\n", " }\n"}
  68. endTokenVariants, err = ComputeTokenVariants(endVariants, proc)
  69. if err != nil {
  70. return err
  71. }
  72. // Compute string token variants
  73. // TODO: removed \n
  74. stringVariants := []string{"\"", " \""}
  75. stringTokenVariants, err = ComputeTokenVariants(stringVariants, proc)
  76. if err != nil {
  77. return err
  78. }
  79. stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\",\n"}, proc)
  80. if err != nil {
  81. return err
  82. }
  83. // objectKeyTokenVariants = []token{stringTokenVariants[0], stringTokenVariants[1]}
  84. objectKeyTokenVariants = stringTokenVariants
  85. // Compute whitespace tokens
  86. tabToken, err = proc.Encode("\t")
  87. if err != nil {
  88. return err
  89. }
  90. spaceToken, err = proc.Encode(" ")
  91. if err != nil {
  92. return err
  93. }
  94. newlineToken, err = proc.Encode("\n")
  95. if err != nil {
  96. return err
  97. }
  98. newlineSpace, err = proc.Encode(" \n")
  99. if err != nil {
  100. return err
  101. }
  102. // Compute colon variants
  103. colonVariants := []string{":"}
  104. colonTokenVariants, err = ComputeTokenVariants(colonVariants, proc)
  105. if err != nil {
  106. return err
  107. }
  108. objKeyToColonVariants, err = ComputeTokenVariants([]string{"\":"}, proc)
  109. if err != nil {
  110. return err
  111. }
  112. // Compute comma variants
  113. commaVariants := []string{",", " ,", ",\n", "\",", "\", "}
  114. commaTokenVariants, err = ComputeTokenVariants(commaVariants, proc)
  115. if err != nil {
  116. return err
  117. }
  118. fmt.Printf("commaTokenVariants: %v\n", commaTokenVariants)
  119. stringToCommaVariants, err = ComputeTokenVariants([]string{"\",", "\","}, proc)
  120. if err != nil {
  121. return err
  122. }
  123. stringEndToCommaVariants, err = ComputeTokenVariants([]string{",", ",\n"}, proc)
  124. stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\","}, proc)
  125. stringToObjectVariants, err = ComputeTokenVariants([]string{"\",\n"}, proc)
  126. stringEndToObjectEndVariants, err = ComputeTokenVariants([]string{"\n"}, proc)
  127. return nil
  128. }
  129. func buildStateMachine(proc model.TextProcessor) (*Node, error) {
  130. if err := initTokens(proc); err != nil {
  131. return nil, err
  132. }
  133. startNode := NewNode(StateStart)
  134. objectNode := NewNode(StateInObject)
  135. objectKeyNode := NewNode(StateInObjectKey)
  136. objectKeyEndNode := NewNode(StateInObjectKeyEnd)
  137. stringNode := NewNode(StateInString)
  138. // intNode := NewNode(StateInInt)
  139. commaNode := NewNode(StateInComma)
  140. colonNode := NewNode(StateInColon)
  141. stringEndNode := NewNode(StateInStringEnd)
  142. endNode := NewNode(StateEnd)
  143. terminateNode := NewNode(StateTerminate)
  144. sentinelToken := token([]int32{-1})
  145. // intSentinelToken := token([]int32{-2})
  146. // TODO: cleanup connections of rules
  147. startNode.TransitionEdges[objectNode] = startTokenVariants
  148. objectNode.TransitionEdges[objectKeyNode] = stringTokenVariants
  149. objectNode.TransitionEdges[objectNode] = []token{newlineToken}
  150. objectNode.TransitionEdges[objectNode] = []token{spaceToken}
  151. // objectNode.TransitionEdges[objectNode] = []token{newlineToken}
  152. // objectNode.TransitionEdges[objectNode] = []token{spaceToken}
  153. objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
  154. // characterize end of object key
  155. objectKeyNode.TransitionEdges[objectKeyEndNode] = stringTokenVariants
  156. objectKeyNode.TransitionEdges[colonNode] = objKeyToColonVariants
  157. // TODO: enable this - key -> object
  158. // objectKeyNode.TransitionEdges[objectNode] = startTokenVariants
  159. // objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}
  160. // intNode.TransitionEdges[intNode] = []token{intSentinelToken}
  161. // intNode.TransitionEdges[commaNode] = commaTokenVariants
  162. // TODO: handle
  163. // intNode.TransitionEdges[terminateNode] = endTokenVariants
  164. commaNode.TransitionEdges[objectKeyNode] = stringTokenVariants
  165. // commaNode.TransitionEdges[objectNode] = startTokenVariants
  166. colonNode.TransitionEdges[stringNode] = stringTokenVariants
  167. //TODO: enable
  168. // colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
  169. colonNode.TransitionEdges[objectNode] = startTokenVariants
  170. stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
  171. stringNode.TransitionEdges[stringEndNode] = stringTokenVariants
  172. // TODO: "\""," Case not accounted for
  173. stringNode.TransitionEdges[commaNode] = stringToCommaVariants
  174. // TODO: "\"",\"" Case not accounted for
  175. stringNode.TransitionEdges[objectNode] = stringToObjectVariants
  176. stringEndNode.TransitionEdges[commaNode] = stringEndToCommaVariants
  177. stringEndNode.TransitionEdges[objectNode] = stringToObjectKeyVariants
  178. stringEndNode.TransitionEdges[endNode] = stringEndToObjectEndVariants
  179. // stringEndNode.TransitionEdges[terminateNode] = endTokenVariants
  180. // Should be obj end
  181. // TODO: handle
  182. endNode.TransitionEdges[terminateNode] = []token{}
  183. endNode.TransitionEdges[commaNode] = commaTokenVariants
  184. terminateNode.TransitionEdges[terminateNode] = []token{}
  185. return startNode, nil
  186. }