fast_json.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. package sample
  2. import (
  3. "errors"
  4. "fmt"
  5. "math"
  6. "github.com/ollama/ollama/model"
  7. )
// JSONState identifies a node kind in the state machine that JSONSampler walks
// while constraining generation to JSON. Values are assigned by iota in
// declaration order, so inserting or reordering constants renumbers every
// state declared after the change.
type JSONState int

// States of the JSON state machine. Within this file, StateTerminate makes
// Sample allow only EOS tokens, StateInInt and StateInString get special
// handling in Sample, and StateInObjectEnd makes UpdateState pop the object
// stack; the other states are not referenced directly here (presumably they
// are used by the state-machine builder — confirm).
//
// NOTE(review): several names look like near-duplicates (StateNewline vs
// StateInNewline, StateTab vs StateInTab, StateSpace vs StateInSpace,
// StateListEnd vs StateInListEnd); their distinct roles cannot be determined
// from this file.
const (
	StateStart JSONState = iota
	StateInObject
	StateInObjectKey
	StateNewline
	StateTab
	StateSpace
	StateInString
	StateInInt
	StateInFloat
	StateInBool
	StateInNull
	StateInColon
	StateInComma
	StateInTab
	StateInSpace
	StateInObjSpace
	StateInList
	StateInListComma
	StateListEnd
	StateInValue
	StateInValueEnd
	StateInListEnd
	StateInListObjectEnd
	StateInNewline
	StateInNumber
	StateInNumberEnd
	StateInStringEnd
	StateInObjectKeyEnd
	StateTerminate
	StateInObjectEnd
	StateTransitioningToTerminate
)
  42. func (s JSONState) String() string {
  43. switch s {
  44. case StateStart:
  45. return "StateStart"
  46. case StateInObject:
  47. return "StateInObject"
  48. case StateInObjectKey:
  49. return "StateInObjectKey"
  50. case StateNewline:
  51. return "StateNewline"
  52. case StateTab:
  53. return "StateTab"
  54. case StateSpace:
  55. return "StateSpace"
  56. case StateInString:
  57. return "StateInString"
  58. case StateInInt:
  59. return "StateInInt"
  60. case StateInFloat:
  61. return "StateInFloat"
  62. case StateInBool:
  63. return "StateInBool"
  64. case StateInNull:
  65. return "StateInNull"
  66. case StateInColon:
  67. return "StateInColon"
  68. case StateInComma:
  69. return "StateInComma"
  70. case StateInTab:
  71. return "StateInTab"
  72. case StateInSpace:
  73. return "StateInSpace"
  74. case StateInObjSpace:
  75. return "StateInObjSpace"
  76. case StateInList:
  77. return "StateInList"
  78. case StateInListObjectEnd:
  79. return "StateInListObjectEnd"
  80. case StateInListComma:
  81. return "StateInListComma"
  82. case StateListEnd:
  83. return "StateListEnd"
  84. case StateInListEnd:
  85. return "StateInListEnd"
  86. case StateInNewline:
  87. return "StateInNewline"
  88. case StateInNumber:
  89. return "StateInNumber"
  90. case StateInNumberEnd:
  91. return "StateInNumberEnd"
  92. case StateInStringEnd:
  93. return "StateInStringEnd"
  94. case StateInObjectKeyEnd:
  95. return "StateInObjectKeyEnd"
  96. case StateTerminate:
  97. return "StateTerminate"
  98. case StateInObjectEnd:
  99. return "StateInObjectEnd"
  100. default:
  101. return fmt.Sprintf("Unknown state: %d", s)
  102. }
  103. }
// JSONSampler constrains token sampling so generated text follows the JSON
// state machine built by buildStateMachine. Sample masks logits according to
// the current state; UpdateState advances the state after tokens are chosen.
type JSONSampler struct {
	// curNode is the current position in the state machine.
	curNode *Node

	// proc is used to encode/decode text and to classify tokens (e.g. EOS).
	proc model.TextProcessor

	// stack receives the node reached whenever an opening-brace token is
	// consumed; UpdateState pops one entry while in StateInObjectEnd.
	stack []*Node

	// bracketCounter is set to 0 by NewJSONSampler and not otherwise
	// referenced in this file. NOTE(review): possibly superseded by stack —
	// confirm before removing.
	bracketCounter int
}
  110. func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
  111. // fmt.Println("Creating new JSON sampler")
  112. startNode, err := buildStateMachine(proc)
  113. if err != nil {
  114. return nil, err
  115. }
  116. js := &JSONSampler{
  117. curNode: startNode,
  118. proc: proc,
  119. stack: []*Node{},
  120. bracketCounter: 0,
  121. }
  122. return js, nil
  123. }
  124. func isTokenSubset(subset, superset []int32) bool {
  125. freq1 := make(map[int32]int)
  126. freq2 := make(map[int32]int)
  127. for _, v := range subset {
  128. freq1[v]++
  129. }
  130. for _, v := range superset {
  131. freq2[v]++
  132. }
  133. isSubset := true
  134. for k, count1 := range freq1 {
  135. count2 := freq2[k]
  136. if count1 > count2 {
  137. isSubset = false
  138. break
  139. }
  140. }
  141. return isSubset
  142. }
  143. func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
  144. // fmt.Printf("Updating state with token: %v\n", tokenSlice)
  145. // fmt.Printf("Current state: %s\n", s.curNode.State)
  146. // fmt.Println("tokenSlice", tokenSlice)
  147. // todo: account for strings here
  148. objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
  149. if err != nil {
  150. return err
  151. }
  152. // only move to terminate state if stack is empty
  153. if s.curNode.State == StateInObjectEnd {
  154. fmt.Println("debug: node.State", s.curNode.State)
  155. if len(s.stack) > 0 {
  156. s.stack = s.stack[:len(s.stack)-1]
  157. fmt.Println("popped and cur state", s.curNode.State)
  158. return nil
  159. }
  160. return nil
  161. }
  162. for node, edge := range s.curNode.TransitionEdges {
  163. for _, validToken := range edge {
  164. if isTokenSubset(tokenSlice, validToken) {
  165. s.curNode = node
  166. for _, token := range objectTokens {
  167. if isTokenSubset(tokenSlice, token) {
  168. fmt.Println("Appending to stack", s.curNode.State)
  169. s.stack = append(s.stack, s.curNode)
  170. }
  171. }
  172. // fmt.Printf("Transitioned to state: %s\n", node.State)
  173. return nil
  174. }
  175. }
  176. }
  177. for node, edge := range s.curNode.TransitionEdges {
  178. for _, validToken := range edge {
  179. if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 {
  180. s.curNode = node
  181. // fmt.Printf("Accepting any token, staying in state: %s\n", node.State)
  182. return nil
  183. }
  184. }
  185. }
  186. fmt.Println("invalid token ", tokenSlice)
  187. dec, err := s.proc.Decode(tokenSlice)
  188. if err != nil {
  189. return err
  190. }
  191. fmt.Println("decoded token ", dec)
  192. return errors.New("invalid token")
  193. }
  194. func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
  195. fmt.Printf("Sampling in state: %s\n", s.curNode.State)
  196. var err error
  197. switch s.curNode.State {
  198. case StateTerminate:
  199. for i := range logits {
  200. if s.proc.Is(uint32(i), model.SpecialEOS) {
  201. logits[i] = 1.0
  202. } else {
  203. logits[i] = math.NaN()
  204. }
  205. }
  206. return logits, nil
  207. case StateInInt:
  208. validStates := []int32{}
  209. minus, err := s.proc.Encode("-")
  210. if err != nil {
  211. return nil, err
  212. }
  213. digits := make([][]int32, 10)
  214. for i := 0; i < 10; i++ {
  215. digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i))
  216. if err != nil {
  217. return nil, err
  218. }
  219. }
  220. // Allow "-" and digits 0-9 at start
  221. for i := range logits {
  222. for _, d := range digits {
  223. if len(d) == 1 && int32(i) == d[0] {
  224. validStates = append(validStates, int32(i))
  225. }
  226. }
  227. if len(minus) == 1 && int32(i) == minus[0] {
  228. validStates = append(validStates, int32(i))
  229. }
  230. }
  231. return logits, nil
  232. case StateInString:
  233. penalizeNewlineVariants := []string{"\n", " \"\n"}
  234. penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
  235. if err != nil {
  236. return nil, err
  237. }
  238. penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
  239. logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
  240. if err != nil {
  241. return nil, err
  242. }
  243. validStates := getValidStates(s.curNode)
  244. logits, err = s.maskLogits(logits, validStates)
  245. if err != nil {
  246. return nil, err
  247. }
  248. return logits, nil
  249. default:
  250. validStates := getValidStates(s.curNode)
  251. logits, err = s.maskLogits(logits, validStates)
  252. if err != nil {
  253. return nil, err
  254. }
  255. return logits, nil
  256. }
  257. }
  258. func getValidStates(node *Node) []int32 {
  259. validStates := []int32{}
  260. for _, edge := range node.TransitionEdges {
  261. for _, token := range edge {
  262. validStates = append(validStates, token...)
  263. }
  264. }
  265. return validStates
  266. }
  267. func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
  268. // fmt.Printf("Masking logits with valid states: %v\n", validStates)
  269. // todo: this can prob be more efficient
  270. for i := range logits {
  271. isValid := false
  272. for _, token := range validStates {
  273. if token == -1 {
  274. // fmt.Println("Found sentinel token, returning unmasked logits")
  275. return logits, nil
  276. }
  277. if i == int(token) {
  278. // fmt.Printf("Found valid token: %d\n", token)
  279. isValid = true
  280. break
  281. }
  282. }
  283. if !isValid {
  284. logits[i] = math.NaN()
  285. }
  286. }
  287. return logits, nil
  288. }
  289. func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
  290. // fmt.Printf("Masking specific logits: %v\n", tokensToMask)
  291. for i := range logits {
  292. for _, token := range tokensToMask {
  293. for _, chunked := range token {
  294. if int(chunked) == i {
  295. logits[i] = math.NaN()
  296. }
  297. }
  298. }
  299. }
  300. return logits, nil
  301. }