package sample import ( "errors" "fmt" "math" "github.com/ollama/ollama/model" ) type JSONState int const ( StateStart JSONState = iota StateInObject StateInObjectKey StateNewline StateTab StateSpace StateInString StateInInt StateInFloat StateInBool StateInNull StateInArray StateInColon StateInComma StateInStringEnd StateInObjectKeyEnd StateTerminate StateEnd ) func (s JSONState) String() string { switch s { case StateStart: return "StateStart" case StateInObject: return "StateInObject" case StateInObjectKey: return "StateInObjectKey" case StateInString: return "StateInString" case StateNewline: return "StateNewline" case StateTab: return "StateTab" case StateSpace: return "StateSpace" case StateInInt: return "StateInInt" case StateInFloat: return "StateInFloat" case StateInColon: return "StateInColon" case StateInBool: return "StateInBool" case StateInNull: return "StateInNull" case StateInArray: return "StateInArray" case StateEnd: return "StateEnd" case StateInComma: return "StateInComma" case StateInObjectKeyEnd: return "StateInObjectKeyEnd" case StateTerminate: return "StateTerminate" case StateInStringEnd: return "StateInStringEnd" default: return fmt.Sprintf("Unknown state: %d", s) } } type JSONSampler struct { curNode *Node proc model.TextProcessor stack []*Node bracketCounter int } func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) { // fmt.Println("Creating new JSON sampler") startNode, err := buildStateMachine(proc) if err != nil { return nil, err } js := &JSONSampler{ curNode: startNode, proc: proc, stack: []*Node{}, bracketCounter: 0, } return js, nil } func isTokenSubset(subset, superset []int32) bool { freq1 := make(map[int32]int) freq2 := make(map[int32]int) for _, v := range subset { freq1[v]++ } for _, v := range superset { freq2[v]++ } isSubset := true for k, count1 := range freq1 { count2 := freq2[k] if count1 > count2 { isSubset = false break } } return isSubset } func (s *JSONSampler) UpdateState(tokenSlice []int32) error { // fmt.Printf("Updating state with token: %v\n", tokenSlice) // fmt.Printf("Current state: %s\n", s.curNode.State) // fmt.Println("tokenSlice", tokenSlice) // todo: account for strings here objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc) if err != nil { return err } // only move to terminate state if stack is empty if s.curNode.State == StateEnd { fmt.Println("debug: node.State", s.curNode.State) if len(s.stack) > 0 { s.stack = s.stack[:len(s.stack)-1] fmt.Println("popped and cur state", s.curNode.State) return nil } return nil } for node, edge := range s.curNode.TransitionEdges { for _, validToken := range edge { if isTokenSubset(tokenSlice, validToken) { s.curNode = node for _, token := range objectTokens { if isTokenSubset(tokenSlice, token) { fmt.Println("Appending to stack", s.curNode.State) s.stack = append(s.stack, s.curNode) } } // fmt.Printf("Transitioned to state: %s\n", node.State) return nil } } } for node, edge := range s.curNode.TransitionEdges { for _, validToken := range edge { if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 { s.curNode = node // fmt.Printf("Accepting any token, staying in state: %s\n", node.State) return nil } } } fmt.Println("invalid token ", tokenSlice) dec, err := s.proc.Decode(tokenSlice) if err != nil { return err } fmt.Println("decoded token ", dec) return errors.New("invalid token") } func (s *JSONSampler) Sample(logits []float64) ([]float64, error) { fmt.Printf("Sampling in state: %s\n", s.curNode.State) var err error switch s.curNode.State { case StateTerminate: for i := range logits { if s.proc.Is(uint32(i), model.SpecialEOS) { logits[i] = 1.0 } else { logits[i] = math.NaN() } } return logits, nil case StateInInt: validStates := []int32{} minus, err := s.proc.Encode("-") if err != nil { return nil, err } digits := make([][]int32, 10) for i := 0; i < 10; i++ { digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i)) if err != nil { return nil, err } } // Allow "-" and digits 0-9 at start for i := range logits { for _, d := range digits { if len(d) == 1 && int32(i) == d[0] { validStates = append(validStates, int32(i)) } } if len(minus) == 1 && int32(i) == minus[0] { validStates = append(validStates, int32(i)) } } return logits, nil case StateInString: penalizeNewlineVariants := []string{"\n", " \"\n"} penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc) if err != nil { return nil, err } penalizeNewlineToks = append(penalizeNewlineToks, []int32{702}) logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks) if err != nil { return nil, err } validStates := getValidStates(s.curNode) logits, err = s.maskLogits(logits, validStates) if err != nil { return nil, err } return logits, nil default: validStates := getValidStates(s.curNode) logits, err = s.maskLogits(logits, validStates) if err != nil { return nil, err } return logits, nil } } func getValidStates(node *Node) []int32 { validStates := []int32{} for _, edge := range node.TransitionEdges { for _, token := range edge { validStates = append(validStates, token...) } } return validStates } func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) { // fmt.Printf("Masking logits with valid states: %v\n", validStates) for i := range logits { isValid := false for _, token := range validStates { if token == -1 { // fmt.Println("Found sentinel token, returning unmasked logits") return logits, nil } if i == int(token) { // fmt.Printf("Found valid token: %d\n", token) isValid = true break } } if !isValid { logits[i] = math.NaN() } } return logits, nil } func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) { // fmt.Printf("Masking specific logits: %v\n", tokensToMask) for i := range logits { for _, token := range tokensToMask { for _, chunked := range token { if int(chunked) == i { logits[i] = math.NaN() } } } } return logits, nil }