fast_json.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. package sample
  2. import (
  3. "errors"
  4. "fmt"
  5. "math"
  6. "github.com/ollama/ollama/model"
  7. )
  8. type JSONState int
  9. const (
  10. StateStart JSONState = iota
  11. StateInObject
  12. StateInObjectKey
  13. StateNewline
  14. StateTab
  15. StateSpace
  16. StateInString
  17. StateInInt
  18. StateInFloat
  19. StateInBool
  20. StateInNull
  21. StateInColon
  22. StateInComma
  23. StateInTab
  24. StateInSpace
  25. StateInObjSpace
  26. StateInList
  27. StateInListComma
  28. StateInValue
  29. StateInValueEnd
  30. StateInListEnd
  31. StateInListObjectEnd
  32. StateInNewline
  33. StateInNumber
  34. StateInNumberEnd
  35. StateInStringEnd
  36. StateInObjectKeyEnd
  37. StateTerminate
  38. StateInObjectEnd
  39. StateTransitioningToTerminate
  40. )
  41. var JSONStates = []JSONState{
  42. StateStart,
  43. StateInObject,
  44. StateInObjectKey,
  45. StateNewline,
  46. StateTab,
  47. StateSpace,
  48. StateInString,
  49. StateInInt,
  50. StateInFloat,
  51. StateInBool,
  52. StateInNull,
  53. StateInColon,
  54. StateInComma,
  55. StateInTab,
  56. StateInSpace,
  57. StateInObjSpace,
  58. StateInList,
  59. StateInListComma,
  60. StateInValue,
  61. StateInValueEnd,
  62. StateInListEnd,
  63. StateInListObjectEnd,
  64. StateInNewline,
  65. StateInNumber,
  66. StateInNumberEnd,
  67. StateInStringEnd,
  68. StateInObjectKeyEnd,
  69. StateTerminate,
  70. StateInObjectEnd,
  71. StateTransitioningToTerminate,
  72. }
  73. func (s JSONState) String() string {
  74. switch s {
  75. case StateStart:
  76. return "StateStart"
  77. case StateInObject:
  78. return "StateInObject"
  79. case StateInObjectKey:
  80. return "StateInObjectKey"
  81. case StateNewline:
  82. return "StateNewline"
  83. case StateTab:
  84. return "StateTab"
  85. case StateSpace:
  86. return "StateSpace"
  87. case StateInString:
  88. return "StateInString"
  89. case StateInInt:
  90. return "StateInInt"
  91. case StateInFloat:
  92. return "StateInFloat"
  93. case StateInBool:
  94. return "StateInBool"
  95. case StateInNull:
  96. return "StateInNull"
  97. case StateInColon:
  98. return "StateInColon"
  99. case StateInComma:
  100. return "StateInComma"
  101. case StateInTab:
  102. return "StateInTab"
  103. case StateInSpace:
  104. return "StateInSpace"
  105. case StateInObjSpace:
  106. return "StateInObjSpace"
  107. case StateInList:
  108. return "StateInList"
  109. case StateInListObjectEnd:
  110. return "StateInListObjectEnd"
  111. case StateInListComma:
  112. return "StateInListComma"
  113. case StateInListEnd:
  114. return "StateInListEnd"
  115. case StateInNewline:
  116. return "StateInNewline"
  117. case StateInNumber:
  118. return "StateInNumber"
  119. case StateInNumberEnd:
  120. return "StateInNumberEnd"
  121. case StateInStringEnd:
  122. return "StateInStringEnd"
  123. case StateInObjectKeyEnd:
  124. return "StateInObjectKeyEnd"
  125. case StateTerminate:
  126. return "StateTerminate"
  127. case StateInObjectEnd:
  128. return "StateInObjectEnd"
  129. default:
  130. return fmt.Sprintf("Unknown state: %d", s)
  131. }
  132. }
  133. type JSONSampler struct {
  134. curNode *Node
  135. proc model.TextProcessor
  136. stack []*Node
  137. bracketCounter int
  138. }
  139. func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
  140. // fmt.Println("Creating new JSON sampler")
  141. startNode, err := buildStateMachine(proc)
  142. if err != nil {
  143. return nil, err
  144. }
  145. js := &JSONSampler{
  146. curNode: startNode,
  147. proc: proc,
  148. stack: []*Node{},
  149. bracketCounter: 0,
  150. }
  151. return js, nil
  152. }
  153. func isTokenSubset(subset, superset []int32) bool {
  154. freq1 := make(map[int32]int)
  155. freq2 := make(map[int32]int)
  156. for _, v := range subset {
  157. freq1[v]++
  158. }
  159. for _, v := range superset {
  160. freq2[v]++
  161. }
  162. isSubset := true
  163. for k, count1 := range freq1 {
  164. count2 := freq2[k]
  165. if count1 > count2 {
  166. isSubset = false
  167. break
  168. }
  169. }
  170. return isSubset
  171. }
  172. func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
  173. // fmt.Printf("Updating state with token: %v\n", tokenSlice)
  174. // fmt.Printf("Current state: %s\n", s.curNode.State)
  175. // fmt.Println("tokenSlice", tokenSlice)
  176. // todo: account for strings here
  177. objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
  178. if err != nil {
  179. return err
  180. }
  181. // only move to terminate state if stack is empty
  182. if s.curNode.State == StateInObjectEnd {
  183. fmt.Println("debug: node.State", s.curNode.State)
  184. if len(s.stack) > 0 {
  185. s.stack = s.stack[:len(s.stack)-1]
  186. fmt.Println("popped and cur state", s.curNode.State)
  187. return nil
  188. }
  189. return nil
  190. }
  191. for node, edge := range s.curNode.TransitionEdges {
  192. for _, validToken := range edge {
  193. if isTokenSubset(tokenSlice, validToken) {
  194. s.curNode = node
  195. for _, token := range objectTokens {
  196. if isTokenSubset(tokenSlice, token) {
  197. fmt.Println("Appending to stack", s.curNode.State)
  198. s.stack = append(s.stack, s.curNode)
  199. }
  200. }
  201. // fmt.Printf("Transitioned to state: %s\n", node.State)
  202. return nil
  203. }
  204. }
  205. }
  206. for node, edge := range s.curNode.TransitionEdges {
  207. for _, validToken := range edge {
  208. if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 {
  209. s.curNode = node
  210. // fmt.Printf("Accepting any token, staying in state: %s\n", node.State)
  211. return nil
  212. }
  213. }
  214. }
  215. fmt.Println("invalid token ", tokenSlice)
  216. dec, err := s.proc.Decode(tokenSlice)
  217. if err != nil {
  218. return err
  219. }
  220. fmt.Println("decoded token ", dec)
  221. return errors.New("invalid token")
  222. }
  223. func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
  224. fmt.Printf("Sampling in state: %s\n", s.curNode.State)
  225. var err error
  226. switch s.curNode.State {
  227. case StateTerminate:
  228. for i := range logits {
  229. if s.proc.Is(uint32(i), model.SpecialEOS) {
  230. logits[i] = 1.0
  231. } else {
  232. logits[i] = math.NaN()
  233. }
  234. }
  235. return logits, nil
  236. case StateInInt:
  237. validStates := []int32{}
  238. minus, err := s.proc.Encode("-")
  239. if err != nil {
  240. return nil, err
  241. }
  242. digits := make([][]int32, 10)
  243. for i := 0; i < 10; i++ {
  244. digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i))
  245. if err != nil {
  246. return nil, err
  247. }
  248. }
  249. // Allow "-" and digits 0-9 at start
  250. for i := range logits {
  251. for _, d := range digits {
  252. if len(d) == 1 && int32(i) == d[0] {
  253. validStates = append(validStates, int32(i))
  254. }
  255. }
  256. if len(minus) == 1 && int32(i) == minus[0] {
  257. validStates = append(validStates, int32(i))
  258. }
  259. }
  260. return logits, nil
  261. case StateInString:
  262. penalizeNewlineVariants := []string{"\n", " \"\n"}
  263. penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
  264. if err != nil {
  265. return nil, err
  266. }
  267. penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
  268. logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
  269. if err != nil {
  270. return nil, err
  271. }
  272. validStates := getValidStates(s.curNode)
  273. logits, err = s.maskLogits(logits, validStates)
  274. if err != nil {
  275. return nil, err
  276. }
  277. return logits, nil
  278. default:
  279. validStates := getValidStates(s.curNode)
  280. logits, err = s.maskLogits(logits, validStates)
  281. if err != nil {
  282. return nil, err
  283. }
  284. return logits, nil
  285. }
  286. }
  287. func getValidStates(node *Node) []int32 {
  288. validStates := []int32{}
  289. for _, edge := range node.TransitionEdges {
  290. for _, token := range edge {
  291. validStates = append(validStates, token...)
  292. }
  293. }
  294. return validStates
  295. }
  296. func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
  297. // fmt.Printf("Masking logits with valid states: %v\n", validStates)
  298. // todo: this can prob be more efficient
  299. for i := range logits {
  300. isValid := false
  301. for _, token := range validStates {
  302. if token == -1 {
  303. // fmt.Println("Found sentinel token, returning unmasked logits")
  304. return logits, nil
  305. }
  306. if i == int(token) {
  307. // fmt.Printf("Found valid token: %d\n", token)
  308. isValid = true
  309. break
  310. }
  311. }
  312. if !isValid {
  313. logits[i] = math.NaN()
  314. }
  315. }
  316. return logits, nil
  317. }
  318. func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
  319. // fmt.Printf("Masking specific logits: %v\n", tokensToMask)
  320. for i := range logits {
  321. for _, token := range tokensToMask {
  322. for _, chunked := range token {
  323. if int(chunked) == i {
  324. logits[i] = math.NaN()
  325. }
  326. }
  327. }
  328. }
  329. return logits, nil
  330. }