fast_json.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. package sample
  2. import (
  3. "errors"
  4. "fmt"
  5. "math"
  6. "github.com/ollama/ollama/model"
  7. )
  // JSONState identifies the kind of node the JSON sampler's state machine is
  // in; it is the type of a Node's State field. Values are assigned with iota
  // in declaration order, so reordering the constants changes their numbers.
  type JSONState int

  // States of the JSON sampling state machine.
  const (
  	StateStart JSONState = iota
  	StateInObject
  	StateInObjectKey
  	StateNewline
  	StateTab
  	StateSpace
  	// StateInString: Sample masks newline-bearing tokens before applying the
  	// node's transition mask.
  	StateInString
  	// StateInInt: Sample returns logits without masking them.
  	StateInInt
  	StateInFloat
  	StateInBool
  	StateInNull
  	StateInArray
  	StateInColon
  	StateInComma
  	StateInTab
  	StateInSpace
  	StateInNewline
  	StateInStringEnd
  	StateInObjectKeyEnd
  	// StateTerminate: Sample leaves only end-of-sequence tokens selectable.
  	StateTerminate
  	// StateInObjectEnd: UpdateState pops the sampler's stack instead of
  	// transitioning to another node.
  	StateInObjectEnd
  )
  32. func (s JSONState) String() string {
  33. switch s {
  34. case StateStart:
  35. return "StateStart"
  36. case StateInObject:
  37. return "StateInObject"
  38. case StateInObjectKey:
  39. return "StateInObjectKey"
  40. case StateInString:
  41. return "StateInString"
  42. case StateNewline:
  43. return "StateNewline"
  44. case StateTab:
  45. return "StateTab"
  46. case StateSpace:
  47. return "StateSpace"
  48. case StateInInt:
  49. return "StateInInt"
  50. case StateInFloat:
  51. return "StateInFloat"
  52. case StateInColon:
  53. return "StateInColon"
  54. case StateInBool:
  55. return "StateInBool"
  56. case StateInNull:
  57. return "StateInNull"
  58. case StateInArray:
  59. return "StateInArray"
  60. case StateInObjectEnd:
  61. return "StateInObjectEnd"
  62. case StateInComma:
  63. return "StateInComma"
  64. case StateInTab:
  65. return "StateInTab"
  66. case StateInObjectKeyEnd:
  67. return "StateInObjectKeyEnd"
  68. case StateInNewline:
  69. return "StateInNewline"
  70. case StateInSpace:
  71. return "StateInSpace"
  72. case StateTerminate:
  73. return "StateTerminate"
  74. case StateInStringEnd:
  75. return "StateInStringEnd"
  76. default:
  77. return fmt.Sprintf("Unknown state: %d", s)
  78. }
  79. }
  80. type JSONSampler struct {
  81. curNode *Node
  82. proc model.TextProcessor
  83. stack []*Node
  84. bracketCounter int
  85. }
  86. func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
  87. // fmt.Println("Creating new JSON sampler")
  88. startNode, err := buildStateMachine(proc)
  89. if err != nil {
  90. return nil, err
  91. }
  92. js := &JSONSampler{
  93. curNode: startNode,
  94. proc: proc,
  95. stack: []*Node{},
  96. bracketCounter: 0,
  97. }
  98. return js, nil
  99. }
  // isTokenSubset reports whether subset is contained in superset as a
  // multiset: every token id must occur in superset at least as many times as
  // it occurs in subset. Order is ignored, and an empty subset is always
  // contained.
  func isTokenSubset(subset, superset []int32) bool {
  	remaining := make(map[int32]int, len(superset))
  	for _, id := range superset {
  		remaining[id]++
  	}
  	for _, id := range subset {
  		remaining[id]--
  		if remaining[id] < 0 {
  			return false
  		}
  	}
  	return true
  }
  119. func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
  120. // fmt.Printf("Updating state with token: %v\n", tokenSlice)
  121. // fmt.Printf("Current state: %s\n", s.curNode.State)
  122. // fmt.Println("tokenSlice", tokenSlice)
  123. // todo: account for strings here
  124. objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
  125. if err != nil {
  126. return err
  127. }
  128. // only move to terminate state if stack is empty
  129. if s.curNode.State == StateInObjectEnd {
  130. fmt.Println("debug: node.State", s.curNode.State)
  131. if len(s.stack) > 0 {
  132. s.stack = s.stack[:len(s.stack)-1]
  133. fmt.Println("popped and cur state", s.curNode.State)
  134. return nil
  135. }
  136. return nil
  137. }
  138. for node, edge := range s.curNode.TransitionEdges {
  139. for _, validToken := range edge {
  140. if isTokenSubset(tokenSlice, validToken) {
  141. s.curNode = node
  142. for _, token := range objectTokens {
  143. if isTokenSubset(tokenSlice, token) {
  144. fmt.Println("Appending to stack", s.curNode.State)
  145. s.stack = append(s.stack, s.curNode)
  146. }
  147. }
  148. // fmt.Printf("Transitioned to state: %s\n", node.State)
  149. return nil
  150. }
  151. }
  152. }
  153. for node, edge := range s.curNode.TransitionEdges {
  154. for _, validToken := range edge {
  155. if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 {
  156. s.curNode = node
  157. // fmt.Printf("Accepting any token, staying in state: %s\n", node.State)
  158. return nil
  159. }
  160. }
  161. }
  162. fmt.Println("invalid token ", tokenSlice)
  163. dec, err := s.proc.Decode(tokenSlice)
  164. if err != nil {
  165. return err
  166. }
  167. fmt.Println("decoded token ", dec)
  168. return errors.New("invalid token")
  169. }
  170. func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
  171. fmt.Printf("Sampling in state: %s\n", s.curNode.State)
  172. var err error
  173. switch s.curNode.State {
  174. case StateTerminate:
  175. for i := range logits {
  176. if s.proc.Is(uint32(i), model.SpecialEOS) {
  177. logits[i] = 1.0
  178. } else {
  179. logits[i] = math.NaN()
  180. }
  181. }
  182. return logits, nil
  183. case StateInInt:
  184. validStates := []int32{}
  185. minus, err := s.proc.Encode("-")
  186. if err != nil {
  187. return nil, err
  188. }
  189. digits := make([][]int32, 10)
  190. for i := 0; i < 10; i++ {
  191. digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i))
  192. if err != nil {
  193. return nil, err
  194. }
  195. }
  196. // Allow "-" and digits 0-9 at start
  197. for i := range logits {
  198. for _, d := range digits {
  199. if len(d) == 1 && int32(i) == d[0] {
  200. validStates = append(validStates, int32(i))
  201. }
  202. }
  203. if len(minus) == 1 && int32(i) == minus[0] {
  204. validStates = append(validStates, int32(i))
  205. }
  206. }
  207. return logits, nil
  208. case StateInString:
  209. penalizeNewlineVariants := []string{"\n", " \"\n"}
  210. penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
  211. if err != nil {
  212. return nil, err
  213. }
  214. penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
  215. logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
  216. if err != nil {
  217. return nil, err
  218. }
  219. validStates := getValidStates(s.curNode)
  220. logits, err = s.maskLogits(logits, validStates)
  221. if err != nil {
  222. return nil, err
  223. }
  224. return logits, nil
  225. default:
  226. validStates := getValidStates(s.curNode)
  227. logits, err = s.maskLogits(logits, validStates)
  228. if err != nil {
  229. return nil, err
  230. }
  231. return logits, nil
  232. }
  233. }
  234. func getValidStates(node *Node) []int32 {
  235. validStates := []int32{}
  236. for _, edge := range node.TransitionEdges {
  237. for _, token := range edge {
  238. validStates = append(validStates, token...)
  239. }
  240. }
  241. return validStates
  242. }
  243. func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
  244. // fmt.Printf("Masking logits with valid states: %v\n", validStates)
  245. for i := range logits {
  246. isValid := false
  247. for _, token := range validStates {
  248. if token == -1 {
  249. // fmt.Println("Found sentinel token, returning unmasked logits")
  250. return logits, nil
  251. }
  252. if i == int(token) {
  253. // fmt.Printf("Found valid token: %d\n", token)
  254. isValid = true
  255. break
  256. }
  257. }
  258. if !isValid {
  259. logits[i] = math.NaN()
  260. }
  261. }
  262. return logits, nil
  263. }
  264. func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
  265. // fmt.Printf("Masking specific logits: %v\n", tokensToMask)
  266. for i := range logits {
  267. for _, token := range tokensToMask {
  268. for _, chunked := range token {
  269. if int(chunked) == i {
  270. logits[i] = math.NaN()
  271. }
  272. }
  273. }
  274. }
  275. return logits, nil
  276. }