ParthSareen пре 3 месеци
родитељ
комит
a2a73ce5e0
3 измењених фајлова са 336 додато и 4 уклоњено
  1. 14 4
      sample/fast_json.go
  2. 175 0
      sample/pushdown_automata.go
  3. 147 0
      sample/pushdown_runner.go

+ 14 - 4
sample/fast_json.go

@@ -25,10 +25,13 @@ const (
 	StateInArray
 	StateInColon
 	StateInComma
+	StateInTab
+	StateInSpace
+	StateInNewline
 	StateInStringEnd
 	StateInObjectKeyEnd
 	StateTerminate
-	StateEnd
+	StateInObjectEnd
 )
 
 func (s JSONState) String() string {
@@ -59,12 +62,18 @@ func (s JSONState) String() string {
 		return "StateInNull"
 	case StateInArray:
 		return "StateInArray"
-	case StateEnd:
-		return "StateEnd"
+	case StateInObjectEnd:
+		return "StateInObjectEnd"
 	case StateInComma:
 		return "StateInComma"
+	case StateInTab:
+		return "StateInTab"
 	case StateInObjectKeyEnd:
 		return "StateInObjectKeyEnd"
+	case StateInNewline:
+		return "StateInNewline"
+	case StateInSpace:
+		return "StateInSpace"
 	case StateTerminate:
 		return "StateTerminate"
 	case StateInStringEnd:
@@ -124,13 +133,14 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
 
 	// fmt.Println("tokenSlice", tokenSlice)
 	// todo: account for strings here
+
 	objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
 	if err != nil {
 		return err
 	}
 
 	// only move to terminate state if stack is empty
-	if s.curNode.State == StateEnd {
+	if s.curNode.State == StateInObjectEnd {
 		fmt.Println("debug: node.State", s.curNode.State)
 		if len(s.stack) > 0 {
 			s.stack = s.stack[:len(s.stack)-1]

+ 175 - 0
sample/pushdown_automata.go

@@ -0,0 +1,175 @@
+package sample
+
+import (
+	"slices"
+
+	"github.com/ollama/ollama/model"
+)
+
+var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
+
+type PDANode struct {
+	State             JSONState
+	TransitionEdges   map[rune]*PDANode
+	MaskTokenIDToNode map[int32]JSONState
+}
+
+func NewPDANode(state JSONState) *PDANode {
+	return &PDANode{
+		State:             state,
+		TransitionEdges:   make(map[rune]*PDANode),
+		MaskTokenIDToNode: make(map[int32]JSONState),
+	}
+}
+
+func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
+	stateToNodeMap := make(map[JSONState]*PDANode)
+
+	startNode := NewPDANode(StateStart)
+	stateToNodeMap[StateStart] = startNode
+
+	objNode := NewPDANode(StateInObject)
+	stateToNodeMap[StateInObject] = objNode
+
+	objEndNode := NewPDANode(StateInObjectEnd)
+	stateToNodeMap[StateInObjectEnd] = objEndNode
+
+	objKeyNode := NewPDANode(StateInObjectKey)
+	stateToNodeMap[StateInObjectKey] = objKeyNode
+
+	objKeyEndNode := NewPDANode(StateInObjectKeyEnd)
+	stateToNodeMap[StateInObjectKeyEnd] = objKeyEndNode
+
+	colonNode := NewPDANode(StateInColon)
+	stateToNodeMap[StateInColon] = colonNode
+
+	commaNode := NewPDANode(StateInComma)
+	stateToNodeMap[StateInComma] = commaNode
+
+	newlineNode := NewPDANode(StateInNewline)
+	stateToNodeMap[StateInNewline] = newlineNode
+
+	spaceNode := NewPDANode(StateInSpace)
+	stateToNodeMap[StateInSpace] = spaceNode
+
+	tabNode := NewPDANode(StateInTab)
+	stateToNodeMap[StateInTab] = tabNode
+
+	stringNode := NewPDANode(StateInString)
+	stateToNodeMap[StateInString] = stringNode
+
+	stringEndNode := NewPDANode(StateInStringEnd)
+	stateToNodeMap[StateInStringEnd] = stringEndNode
+
+	// terminateNode := NewNode(StateTerminate)
+
+	// Connect nodes
+	// TODO: if all are single tokens then this can just be connected instead of defining the token
+	startNode.TransitionEdges['{'] = objNode
+
+	objNode.TransitionEdges['"'] = objKeyNode
+	objNode.TransitionEdges['\n'] = newlineNode
+
+	newlineNode.TransitionEdges['"'] = objKeyNode
+	newlineNode.TransitionEdges['\t'] = tabNode
+
+	tabNode.TransitionEdges['"'] = objKeyNode
+
+	spaceNode.TransitionEdges['"'] = stringNode
+
+	objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
+	objKeyNode.TransitionEdges['"'] = objKeyEndNode
+	objKeyNode.TransitionEdges[' '] = spaceNode
+	// objKeyNode.TransitionEdges['\t'] = tabNode
+
+	objKeyEndNode.TransitionEdges[':'] = colonNode
+
+	colonNode.TransitionEdges['"'] = stringNode
+	colonNode.TransitionEdges[' '] = spaceNode
+
+	stringNode.TransitionEdges[rune(-1)] = stringNode
+	stringNode.TransitionEdges['"'] = stringEndNode
+
+	stringEndNode.TransitionEdges[','] = commaNode
+	stringEndNode.TransitionEdges['}'] = objEndNode
+
+	commaNode.TransitionEdges['{'] = objNode
+	commaNode.TransitionEdges['\n'] = newlineNode
+	commaNode.TransitionEdges['\t'] = tabNode
+	commaNode.TransitionEdges['"'] = objKeyNode
+
+	return startNode, stateToNodeMap, nil
+}
+
+func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
+
+	vocab := proc.GetVocabulary()
+
+	decodedToks := make([]string, len(vocab.Values))
+	for i := range vocab.Values {
+		token, err := proc.Decode([]int32{int32(i)})
+		if err != nil {
+			return err
+		}
+		decodedToks[i] = token
+	}
+
+	var err error
+	for _, node := range stateToNodeMap {
+		for i := range vocab.Values {
+			token := decodedToks[i]
+			// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
+			if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
+				continue
+			}
+			valid := true
+			curNode := node
+			consumedSpecialRunes := make(map[rune]bool)
+			for _, r := range token {
+				valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
+				if err != nil {
+					return err
+				}
+				if !valid {
+					break
+				}
+			}
+			if valid {
+				node.MaskTokenIDToNode[int32(i)] = curNode.State
+			}
+		}
+	}
+	return nil
+}
+
+func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
+	if consumedSpecialRunes[r] {
+		return false, nil, nil
+	}
+
+	specialRune := slices.Contains(stringInvalidRunes, r)
+	if specialRune {
+		if curNode.State == StateInString || curNode.State == StateInObjectKey {
+			return false, nil, nil
+		}
+	}
+
+	// Check for specific rune transition
+	if nextNode, ok := curNode.TransitionEdges[r]; ok {
+		if specialRune {
+			if curNode.State == nextNode.State {
+				return false, nil, nil
+			}
+			// fmt.Println("special rune", r, "consumed")
+			consumedSpecialRunes[r] = true
+		}
+		return true, nextNode, nil
+	}
+
+	// Check for sentinel value - if present, any rune is valid
+	if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
+		return true, nextNode, nil
+	}
+
+	return false, nil, nil
+}

+ 147 - 0
sample/pushdown_runner.go

@@ -0,0 +1,147 @@
+package sample
+
+import (
+	"fmt"
+	"math"
+
+	"github.com/ollama/ollama/model"
+)
+
+type PushdownSampler struct {
+	// stateful
+	curNode        *PDANode
+	proc           model.TextProcessor
+	stateToNodeMap map[JSONState]*PDANode
+	braceStack     []rune
+}
+
+func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
+	startNode, stateToNodeMap, err := BuildGraph(proc)
+	if err != nil {
+		panic(err)
+	}
+	err = PreComputeValidStates(stateToNodeMap, proc)
+	if err != nil {
+		panic(err)
+	}
+	// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
+	// 	token, err := proc.Decode([]int32{int32(id)})
+	// 	if err != nil {
+	// 		panic(err)
+	// 	}
+	// 	fmt.Println("id", id, "node", node, "token", token)
+	// }
+	// time.Sleep(10 * time.Second)
+	return &PushdownSampler{
+		curNode:        startNode,
+		proc:           proc,
+		stateToNodeMap: stateToNodeMap,
+		braceStack:     []rune{},
+	}
+}
+
+func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
+	fmt.Println("sample:", s.curNode.State)
+
+	switch s.curNode.State {
+	case StateInObjectEnd:
+		// force finish if no braces left
+		if len(s.braceStack) == 0 {
+			s.curNode = NewPDANode(StateTerminate)
+			for i := range logits {
+				if s.proc.Is(uint32(i), model.SpecialEOS) {
+					logits[i] = 1.0
+				} else {
+					logits[i] = math.NaN()
+				}
+			}
+			return logits, nil
+		}
+		valid, err := s.proc.Encode("}")
+		if err != nil {
+			return nil, err
+		}
+		for i := range logits {
+			for _, token := range valid {
+				if i != int(token) {
+					logits[i] = math.NaN()
+				}
+			}
+		}
+		return logits, nil
+	// return logits, nil
+	case StateTerminate:
+		for i := range logits {
+			if s.proc.Is(uint32(i), model.SpecialEOS) {
+				logits[i] = 1.0
+			} else {
+				logits[i] = math.NaN()
+			}
+		}
+		return logits, nil
+
+	// case StateInStringEnd:
+
+	// 	return logits, nil
+	default:
+		fmt.Println("masking logits current state", s.curNode.State)
+		logits, err := s.maskLogits(logits, s.curNode)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+	}
+}
+
+func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
+	fmt.Println("update state", s.curNode.State)
+
+	// TODO: need to handle end states and entering object case
+	if s.curNode.State == StateInObjectEnd {
+		fmt.Println("in object end")
+		if len(s.braceStack) > 0 {
+			s.braceStack = s.braceStack[:len(s.braceStack)-1]
+			return nil
+		}
+		s.curNode = NewPDANode(StateTerminate)
+		// TODO: return here?
+	}
+	// need this cause there could be multiple transitions
+	mappedString, err := s.proc.Decode(tokenSlice)
+	if err != nil {
+		return err
+	}
+	for _, r := range mappedString {
+		if r == rune('{') {
+			s.braceStack = append(s.braceStack, r)
+		}
+		if r == rune('}') {
+			if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
+				return fmt.Errorf("unmatched closing brace")
+			}
+			s.braceStack = s.braceStack[:len(s.braceStack)-1]
+		}
+	}
+	for _, tokenID := range tokenSlice {
+		// transition to the next node
+		nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
+		if !ok {
+			return fmt.Errorf("invalid token: %q", mappedString)
+		}
+		fmt.Println("transitioning to", nextNode)
+		s.curNode = s.stateToNodeMap[nextNode]
+	}
+	return nil
+}
+
+func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
+	for i := range logits {
+		_, exists := node.MaskTokenIDToNode[int32(i)]
+		if !exists {
+			logits[i] = math.NaN()
+		}
+	}
+	return logits, nil
+}
+
+// TODO: add penalties for string \n stuff