ParthSareen 3 місяців тому
батько
коміт
6ba557f25b
2 змінених файлів з 198 додано та 56 видалено
  1. 89 7
      sample/fast_json.go
  2. 109 49
      sample/state_machine.go

+ 89 - 7
sample/fast_json.go

@@ -4,7 +4,6 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"math"
 	"math"
-	"slices"
 
 
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model"
 )
 )
@@ -76,9 +75,10 @@ func (s JSONState) String() string {
 }
 }
 
 
 type JSONSampler struct {
 type JSONSampler struct {
-	curNode *Node
-	proc    model.TextProcessor
-	stack   []*Node
+	curNode        *Node
+	proc           model.TextProcessor
+	stack          []*Node
+	bracketCounter int
 }
 }
 
 
 func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
 func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
@@ -88,23 +88,68 @@ func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 	js := &JSONSampler{
 	js := &JSONSampler{
-		curNode: startNode,
-		proc:    proc,
+		curNode:        startNode,
+		proc:           proc,
+		stack:          []*Node{},
+		bracketCounter: 0,
 	}
 	}
 
 
 	return js, nil
 	return js, nil
 }
 }
 
 
// isTokenSubset reports whether subset is contained in superset when both are
// treated as multisets: every value must occur in superset at least as many
// times as it occurs in subset. Order is irrelevant.
//
// An empty (or nil) subset is contained in any superset, including an empty one.
func isTokenSubset(subset, superset []int32) bool {
	// A longer slice necessarily has some value occurring more often than in
	// the shorter one, so it can never be contained in it.
	if len(subset) > len(superset) {
		return false
	}

	// remaining[v] is how many occurrences of v in superset are still unmatched.
	remaining := make(map[int32]int, len(superset))
	for _, v := range superset {
		remaining[v]++
	}

	for _, v := range subset {
		if remaining[v] == 0 {
			return false
		}
		remaining[v]--
	}
	return true
}
+
 func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
 func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
 	// fmt.Printf("Updating state with token: %v\n", tokenSlice)
 	// fmt.Printf("Updating state with token: %v\n", tokenSlice)
 	// fmt.Printf("Current state: %s\n", s.curNode.State)
 	// fmt.Printf("Current state: %s\n", s.curNode.State)
 
 
 	// fmt.Println("tokenSlice", tokenSlice)
 	// fmt.Println("tokenSlice", tokenSlice)
 	// todo: account for strings here
 	// todo: account for strings here
+	objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
+	if err != nil {
+		return err
+	}
+
+	// only move to terminate state if stack is empty
+	if s.curNode.State == StateEnd {
+		fmt.Println("debug: node.State", s.curNode.State)
+		if len(s.stack) > 0 {
+			s.stack = s.stack[:len(s.stack)-1]
+			fmt.Println("popped and cur state", s.curNode.State)
+			return nil
+		}
+		return nil
+	}
+
 	for node, edge := range s.curNode.TransitionEdges {
 	for node, edge := range s.curNode.TransitionEdges {
 		for _, validToken := range edge {
 		for _, validToken := range edge {
-			if slices.Equal(tokenSlice, validToken) {
+			if isTokenSubset(tokenSlice, validToken) {
 				s.curNode = node
 				s.curNode = node
+				for _, token := range objectTokens {
+					if isTokenSubset(tokenSlice, token) {
+						fmt.Println("Appending to stack", s.curNode.State)
+						s.stack = append(s.stack, s.curNode)
+					}
+				}
 				// fmt.Printf("Transitioned to state: %s\n", node.State)
 				// fmt.Printf("Transitioned to state: %s\n", node.State)
 				return nil
 				return nil
 			}
 			}
@@ -120,6 +165,11 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
 		}
 		}
 	}
 	}
 	fmt.Println("invalid token ", tokenSlice)
 	fmt.Println("invalid token ", tokenSlice)
+	dec, err := s.proc.Decode(tokenSlice)
+	if err != nil {
+		return err
+	}
+	fmt.Println("decoded token ", dec)
 	return errors.New("invalid token")
 	return errors.New("invalid token")
 }
 }
 
 
@@ -164,6 +214,24 @@ func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
 		}
 		}
 		return logits, nil
 		return logits, nil
 
 
+	case StateInString:
+		penalizeNewlineVariants := []string{"\n", " \"\n"}
+		penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
+		if err != nil {
+			return nil, err
+		}
+		penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
+		logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
+		if err != nil {
+			return nil, err
+		}
+		validStates := getValidStates(s.curNode)
+		logits, err = s.maskLogits(logits, validStates)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+
 	default:
 	default:
 		validStates := getValidStates(s.curNode)
 		validStates := getValidStates(s.curNode)
 		logits, err = s.maskLogits(logits, validStates)
 		logits, err = s.maskLogits(logits, validStates)
@@ -205,3 +273,17 @@ func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float
 	}
 	}
 	return logits, nil
 	return logits, nil
 }
 }
+
+func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
+	// fmt.Printf("Masking specific logits: %v\n", tokensToMask)
+	for i := range logits {
+		for _, token := range tokensToMask {
+			for _, chunked := range token {
+				if int(chunked) == i {
+					logits[i] = math.NaN()
+				}
+			}
+		}
+	}
+	return logits, nil
+}

+ 109 - 49
sample/state_machine.go

@@ -21,39 +21,78 @@ func NewNode(state JSONState) *Node {
 }
 }
 
 
 var (
 var (
-	startToken     token
-	endToken       token
-	stringToken    token
-	objectKeyToken token
-	tabToken       token
-	spaceToken     token
-	newlineToken   token
-	newlineSpace   token
-	commaToken     token
-	commaToken2    token
-	commaToken3    token
-	colonToken     token
-	colonToken2    token
+	// startToken             token
+	startTokenVariants []token
+	// endToken               token
+	// stringToken            token
+	// objectKeyToken         token
+	tabToken     token
+	spaceToken   token
+	newlineToken token
+	newlineSpace token
+	// commaToken             token
+	// commaToken2            token
+	// commaToken3            token
+	// colonToken             token
+	// colonToken2            token
+	colonTokenVariants           []token
+	commaTokenVariants           []token
+	stringTokenVariants          []token
+	endTokenVariants             []token
+	objectKeyTokenVariants       []token
+	objKeyToColonVariants        []token
+	stringToObjectKeyVariants    []token
+	stringToCommaVariants        []token
+	stringToObjectVariants       []token
+	stringEndToObjectEndVariants []token
+	stringEndToCommaVariants     []token
 )
 )
 
 
+func ComputeTokenVariants(variants []string, proc model.TextProcessor) ([]token, error) {
+	var allTokens token
+	for _, variant := range variants {
+		if t, err := proc.Encode(variant); err == nil {
+			allTokens = append(allTokens, t...)
+		}
+	}
+	if len(allTokens) == 0 {
+		return nil, fmt.Errorf("no valid tokens found for variants")
+	}
+	return []token{allTokens}, nil
+}
 func initTokens(proc model.TextProcessor) error {
 func initTokens(proc model.TextProcessor) error {
 	var err error
 	var err error
-	startToken, err = proc.Encode("{")
+
+	s, err := proc.Decode([]int32{761})
+	fmt.Printf("761 decoded %q\n", s)
+
+	// Compute start token variants
+	startVariants := []string{"{", " {", "{\n", " {\n"}
+	startTokenVariants, err = ComputeTokenVariants(startVariants, proc)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	endToken, err = proc.Encode("}")
+	// Compute end token variants
+	endVariants := []string{"}", " }", "}\n", " }\n"}
+	endTokenVariants, err = ComputeTokenVariants(endVariants, proc)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	stringToken, err = proc.Encode("\"")
+
+	// Compute string token variants
+	// TODO: removed \n
+	stringVariants := []string{"\"", " \""}
+	stringTokenVariants, err = ComputeTokenVariants(stringVariants, proc)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	objectKeyToken, err = proc.Encode("\"")
+	stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\",\n"}, proc)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+	// objectKeyTokenVariants = []token{stringTokenVariants[0], stringTokenVariants[1]}
+	objectKeyTokenVariants = stringTokenVariants
+	// Compute whitespace tokens
 	tabToken, err = proc.Encode("\t")
 	tabToken, err = proc.Encode("\t")
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -70,29 +109,35 @@ func initTokens(proc model.TextProcessor) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	// TODO: figure out how to encode colon correctly
-	colonToken, err = proc.Encode("\":")
-	if err != nil {
-		return err
-	}
-	fmt.Println("colonToken", colonToken)
-	colonToken2, err = proc.Encode(":")
+
+	// Compute colon variants
+	colonVariants := []string{":"}
+	colonTokenVariants, err = ComputeTokenVariants(colonVariants, proc)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	commaToken, err = proc.Encode(",")
+	objKeyToColonVariants, err = ComputeTokenVariants([]string{"\":"}, proc)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	commaToken2, err = proc.Encode("\",")
+
+	// Compute comma variants
+	commaVariants := []string{",", " ,", ",\n", "\",", "\", "}
+	commaTokenVariants, err = ComputeTokenVariants(commaVariants, proc)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	fmt.Println("commaToken2", commaToken2)
-	commaToken3, err = proc.Encode("\",\"")
+	fmt.Printf("commaTokenVariants: %v\n", commaTokenVariants)
+	stringToCommaVariants, err = ComputeTokenVariants([]string{"\",", "\","}, proc)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+
+	stringEndToCommaVariants, err = ComputeTokenVariants([]string{",", ",\n"}, proc)
+	stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\","}, proc)
+	stringToObjectVariants, err = ComputeTokenVariants([]string{"\",\n"}, proc)
+	stringEndToObjectEndVariants, err = ComputeTokenVariants([]string{"\n"}, proc)
+
 	return nil
 	return nil
 }
 }
 
 
@@ -106,7 +151,7 @@ func buildStateMachine(proc model.TextProcessor) (*Node, error) {
 	objectKeyNode := NewNode(StateInObjectKey)
 	objectKeyNode := NewNode(StateInObjectKey)
 	objectKeyEndNode := NewNode(StateInObjectKeyEnd)
 	objectKeyEndNode := NewNode(StateInObjectKeyEnd)
 	stringNode := NewNode(StateInString)
 	stringNode := NewNode(StateInString)
-	intNode := NewNode(StateInInt)
+	// intNode := NewNode(StateInInt)
 	commaNode := NewNode(StateInComma)
 	commaNode := NewNode(StateInComma)
 	colonNode := NewNode(StateInColon)
 	colonNode := NewNode(StateInColon)
 	stringEndNode := NewNode(StateInStringEnd)
 	stringEndNode := NewNode(StateInStringEnd)
@@ -114,44 +159,59 @@ func buildStateMachine(proc model.TextProcessor) (*Node, error) {
 	terminateNode := NewNode(StateTerminate)
 	terminateNode := NewNode(StateTerminate)
 
 
 	sentinelToken := token([]int32{-1})
 	sentinelToken := token([]int32{-1})
-	intSentinelToken := token([]int32{-2})
+	// intSentinelToken := token([]int32{-2})
+
+	// TODO: cleanup connections of rules
+	startNode.TransitionEdges[objectNode] = startTokenVariants
 
 
-	startNode.TransitionEdges[objectNode] = []token{startToken}
+	objectNode.TransitionEdges[objectKeyNode] = stringTokenVariants
+	objectNode.TransitionEdges[objectNode] = []token{newlineToken}
+	objectNode.TransitionEdges[objectNode] = []token{spaceToken}
 
 
-	objectNode.TransitionEdges[objectKeyNode] = []token{stringToken}
 	// objectNode.TransitionEdges[objectNode] = []token{newlineToken}
 	// objectNode.TransitionEdges[objectNode] = []token{newlineToken}
 	// objectNode.TransitionEdges[objectNode] = []token{spaceToken}
 	// objectNode.TransitionEdges[objectNode] = []token{spaceToken}
 
 
 	objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
 	objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
-	objectKeyNode.TransitionEdges[colonNode] = []token{colonToken, colonToken2}
 	// characterize end of object key
 	// characterize end of object key
-	objectKeyNode.TransitionEdges[objectKeyEndNode] = []token{stringToken}
+	objectKeyNode.TransitionEdges[objectKeyEndNode] = stringTokenVariants
+	objectKeyNode.TransitionEdges[colonNode] = objKeyToColonVariants
 
 
-	objectKeyEndNode.TransitionEdges[colonNode] = []token{colonToken}
+	// TODO: enable this - key -> object
+	// objectKeyNode.TransitionEdges[objectNode] = startTokenVariants
 
 
 	// objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}
 	// objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}
 
 
-	intNode.TransitionEdges[intNode] = []token{intSentinelToken}
-	intNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2}
-	intNode.TransitionEdges[terminateNode] = []token{endToken}
+	// intNode.TransitionEdges[intNode] = []token{intSentinelToken}
+	// intNode.TransitionEdges[commaNode] = commaTokenVariants
+	// TODO: handle
+	// intNode.TransitionEdges[terminateNode] = endTokenVariants
 
 
-	commaNode.TransitionEdges[objectKeyNode] = []token{newlineToken}
+	commaNode.TransitionEdges[objectKeyNode] = stringTokenVariants
+	// commaNode.TransitionEdges[objectNode] = startTokenVariants
 
 
-	colonNode.TransitionEdges[stringNode] = []token{stringToken}
-	colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
+	colonNode.TransitionEdges[stringNode] = stringTokenVariants
+	//TODO: enable
+	// colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
+	colonNode.TransitionEdges[objectNode] = startTokenVariants
 
 
 	stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
 	stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
-	stringNode.TransitionEdges[stringEndNode] = []token{stringToken}
-	// "\""," Case
-	stringNode.TransitionEdges[commaNode] = []token{commaToken2}
+	stringNode.TransitionEdges[stringEndNode] = stringTokenVariants
+	// TODO: "\""," Case not accounted for
+	stringNode.TransitionEdges[commaNode] = stringToCommaVariants
+
+	// TODO: "\"",\"" Case not accounted for
+	stringNode.TransitionEdges[objectNode] = stringToObjectVariants
 
 
-	// "\"",\"" Case
-	stringNode.TransitionEdges[objectKeyNode] = []token{commaToken3}
+	stringEndNode.TransitionEdges[commaNode] = stringEndToCommaVariants
+	stringEndNode.TransitionEdges[objectNode] = stringToObjectKeyVariants
+	stringEndNode.TransitionEdges[endNode] = stringEndToObjectEndVariants
+	// stringEndNode.TransitionEdges[terminateNode] = endTokenVariants
 
 
-	stringEndNode.TransitionEdges[commaNode] = []token{commaToken, commaToken2}
-	stringEndNode.TransitionEdges[terminateNode] = []token{endToken}
+	// Should be obj end
+	// TODO: handle
+	endNode.TransitionEdges[terminateNode] = []token{}
 
 
-	endNode.TransitionEdges[terminateNode] = []token{endToken}
+	endNode.TransitionEdges[commaNode] = commaTokenVariants
 
 
 	terminateNode.TransitionEdges[terminateNode] = []token{}
 	terminateNode.TransitionEdges[terminateNode] = []token{}
 	return startNode, nil
 	return startNode, nil