123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289 |
- package sample
- import (
- "errors"
- "fmt"
- "math"
- "github.com/ollama/ollama/model"
- )
// JSONState identifies a state of the JSON state machine that JSONSampler
// walks; Sample and UpdateState branch on the current node's JSONState.
type JSONState int

// States of the JSON state machine. Values are assigned by iota, so changing
// the declaration order renumbers every state after the change.
const (
	StateStart JSONState = iota
	StateInObject
	StateInObjectKey
	StateNewline
	StateTab
	StateSpace
	StateInString
	StateInInt
	StateInFloat
	StateInBool
	StateInNull
	StateInArray
	StateInColon
	StateInComma
	StateInStringEnd
	StateInObjectKeyEnd
	StateTerminate
	StateEnd
)

// String implements fmt.Stringer. It returns the Go identifier of a declared
// state (for example "StateInArray") and "Unknown state: N" for any other
// value.
func (s JSONState) String() string {
	// Keyed by the constants themselves, so the table stays correct regardless
	// of the order in which entries are listed here.
	names := [...]string{
		StateStart:          "StateStart",
		StateInObject:       "StateInObject",
		StateInObjectKey:    "StateInObjectKey",
		StateNewline:        "StateNewline",
		StateTab:            "StateTab",
		StateSpace:          "StateSpace",
		StateInString:       "StateInString",
		StateInInt:          "StateInInt",
		StateInFloat:        "StateInFloat",
		StateInBool:         "StateInBool",
		StateInNull:         "StateInNull",
		StateInArray:        "StateInArray",
		StateInColon:        "StateInColon",
		StateInComma:        "StateInComma",
		StateInStringEnd:    "StateInStringEnd",
		StateInObjectKeyEnd: "StateInObjectKeyEnd",
		StateTerminate:      "StateTerminate",
		StateEnd:            "StateEnd",
	}
	if s < 0 || int(s) >= len(names) {
		// %d is an integer verb, so fmt does not call String() recursively.
		return fmt.Sprintf("Unknown state: %d", s)
	}
	return names[s]
}
- type JSONSampler struct {
- curNode *Node
- proc model.TextProcessor
- stack []*Node
- bracketCounter int
- }
- func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
- // fmt.Println("Creating new JSON sampler")
- startNode, err := buildStateMachine(proc)
- if err != nil {
- return nil, err
- }
- js := &JSONSampler{
- curNode: startNode,
- proc: proc,
- stack: []*Node{},
- bracketCounter: 0,
- }
- return js, nil
- }
// isTokenSubset reports whether subset is a multiset subset of superset:
// every token id in subset occurs in superset at least as many times as it
// occurs in subset. Order is ignored. An empty subset is a subset of anything,
// including an empty superset.
func isTokenSubset(subset, superset []int32) bool {
	// Count what superset offers, then consume one occurrence per token of
	// subset; running out of a token means subset needs more than is available.
	available := make(map[int32]int, len(superset))
	for _, tok := range superset {
		available[tok]++
	}
	for _, tok := range subset {
		if available[tok] == 0 {
			return false
		}
		available[tok]--
	}
	return true
}
- func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
- // fmt.Printf("Updating state with token: %v\n", tokenSlice)
- // fmt.Printf("Current state: %s\n", s.curNode.State)
- // fmt.Println("tokenSlice", tokenSlice)
- // todo: account for strings here
- objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
- if err != nil {
- return err
- }
- // only move to terminate state if stack is empty
- if s.curNode.State == StateEnd {
- fmt.Println("debug: node.State", s.curNode.State)
- if len(s.stack) > 0 {
- s.stack = s.stack[:len(s.stack)-1]
- fmt.Println("popped and cur state", s.curNode.State)
- return nil
- }
- return nil
- }
- for node, edge := range s.curNode.TransitionEdges {
- for _, validToken := range edge {
- if isTokenSubset(tokenSlice, validToken) {
- s.curNode = node
- for _, token := range objectTokens {
- if isTokenSubset(tokenSlice, token) {
- fmt.Println("Appending to stack", s.curNode.State)
- s.stack = append(s.stack, s.curNode)
- }
- }
- // fmt.Printf("Transitioned to state: %s\n", node.State)
- return nil
- }
- }
- }
- for node, edge := range s.curNode.TransitionEdges {
- for _, validToken := range edge {
- if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 {
- s.curNode = node
- // fmt.Printf("Accepting any token, staying in state: %s\n", node.State)
- return nil
- }
- }
- }
- fmt.Println("invalid token ", tokenSlice)
- dec, err := s.proc.Decode(tokenSlice)
- if err != nil {
- return err
- }
- fmt.Println("decoded token ", dec)
- return errors.New("invalid token")
- }
- func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
- fmt.Printf("Sampling in state: %s\n", s.curNode.State)
- var err error
- switch s.curNode.State {
- case StateTerminate:
- for i := range logits {
- if s.proc.Is(uint32(i), model.SpecialEOS) {
- logits[i] = 1.0
- } else {
- logits[i] = math.NaN()
- }
- }
- return logits, nil
- case StateInInt:
- validStates := []int32{}
- minus, err := s.proc.Encode("-")
- if err != nil {
- return nil, err
- }
- digits := make([][]int32, 10)
- for i := 0; i < 10; i++ {
- digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i))
- if err != nil {
- return nil, err
- }
- }
- // Allow "-" and digits 0-9 at start
- for i := range logits {
- for _, d := range digits {
- if len(d) == 1 && int32(i) == d[0] {
- validStates = append(validStates, int32(i))
- }
- }
- if len(minus) == 1 && int32(i) == minus[0] {
- validStates = append(validStates, int32(i))
- }
- }
- return logits, nil
- case StateInString:
- penalizeNewlineVariants := []string{"\n", " \"\n"}
- penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
- if err != nil {
- return nil, err
- }
- penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
- logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
- if err != nil {
- return nil, err
- }
- validStates := getValidStates(s.curNode)
- logits, err = s.maskLogits(logits, validStates)
- if err != nil {
- return nil, err
- }
- return logits, nil
- default:
- validStates := getValidStates(s.curNode)
- logits, err = s.maskLogits(logits, validStates)
- if err != nil {
- return nil, err
- }
- return logits, nil
- }
- }
- func getValidStates(node *Node) []int32 {
- validStates := []int32{}
- for _, edge := range node.TransitionEdges {
- for _, token := range edge {
- validStates = append(validStates, token...)
- }
- }
- return validStates
- }
- func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
- // fmt.Printf("Masking logits with valid states: %v\n", validStates)
- for i := range logits {
- isValid := false
- for _, token := range validStates {
- if token == -1 {
- // fmt.Println("Found sentinel token, returning unmasked logits")
- return logits, nil
- }
- if i == int(token) {
- // fmt.Printf("Found valid token: %d\n", token)
- isValid = true
- break
- }
- }
- if !isValid {
- logits[i] = math.NaN()
- }
- }
- return logits, nil
- }
- func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
- // fmt.Printf("Masking specific logits: %v\n", tokensToMask)
- for i := range logits {
- for _, token := range tokensToMask {
- for _, chunked := range token {
- if int(chunked) == i {
- logits[i] = math.NaN()
- }
- }
- }
- }
- return logits, nil
- }
|