ParthSareen il y a 3 mois
Parent
commit
e93db4d20e
3 fichiers modifiés avec 214 ajouts et 29 suppressions
  1. 34 15
      sample/fast_json.go
  2. 127 6
      sample/pushdown_automata.go
  3. 53 8
      sample/pushdown_runner.go

+ 34 - 15
sample/fast_json.go

@@ -22,12 +22,18 @@ const (
 	StateInFloat
 	StateInBool
 	StateInNull
-	StateInArray
 	StateInColon
 	StateInComma
 	StateInTab
 	StateInSpace
+	StateInObjSpace
+	StateInList
+	StateInListComma
+	StateListEnd
+	StateInListEnd
 	StateInNewline
+	StateInNumber
+	StateInNumberEnd
 	StateInStringEnd
 	StateInObjectKeyEnd
 	StateTerminate
@@ -42,42 +48,54 @@ func (s JSONState) String() string {
 		return "StateInObject"
 	case StateInObjectKey:
 		return "StateInObjectKey"
-	case StateInString:
-		return "StateInString"
 	case StateNewline:
 		return "StateNewline"
 	case StateTab:
 		return "StateTab"
 	case StateSpace:
 		return "StateSpace"
+	case StateInString:
+		return "StateInString"
 	case StateInInt:
 		return "StateInInt"
 	case StateInFloat:
 		return "StateInFloat"
-	case StateInColon:
-		return "StateInColon"
 	case StateInBool:
 		return "StateInBool"
 	case StateInNull:
 		return "StateInNull"
-	case StateInArray:
-		return "StateInArray"
-	case StateInObjectEnd:
-		return "StateInObjectEnd"
+	case StateInColon:
+		return "StateInColon"
 	case StateInComma:
 		return "StateInComma"
 	case StateInTab:
 		return "StateInTab"
-	case StateInObjectKeyEnd:
-		return "StateInObjectKeyEnd"
-	case StateInNewline:
-		return "StateInNewline"
 	case StateInSpace:
 		return "StateInSpace"
-	case StateTerminate:
-		return "StateTerminate"
+	case StateInObjSpace:
+		return "StateInObjSpace"
+	case StateInList:
+		return "StateInList"
+	case StateInListComma:
+		return "StateInListComma"
+	case StateListEnd:
+		return "StateListEnd"
+	case StateInListEnd:
+		return "StateInListEnd"
+	case StateInNewline:
+		return "StateInNewline"
+	case StateInNumber:
+		return "StateInNumber"
+	case StateInNumberEnd:
+		return "StateInNumberEnd"
 	case StateInStringEnd:
 		return "StateInStringEnd"
+	case StateInObjectKeyEnd:
+		return "StateInObjectKeyEnd"
+	case StateTerminate:
+		return "StateTerminate"
+	case StateInObjectEnd:
+		return "StateInObjectEnd"
 	default:
 		return fmt.Sprintf("Unknown state: %d", s)
 	}
@@ -264,6 +282,7 @@ func getValidStates(node *Node) []int32 {
 
 func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
 	// fmt.Printf("Masking logits with valid states: %v\n", validStates)
+	// todo: this can prob be more efficient
 	for i := range logits {
 		isValid := false
 		for _, token := range validStates {

+ 127 - 6
sample/pushdown_automata.go

@@ -8,6 +8,15 @@ import (
 
 var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
 
+var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
+var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
+
+var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
+
+var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
+
+var validNullRunes = []rune{'n', 'u', 'l', 'l'}
+
 type PDANode struct {
 	State             JSONState
 	TransitionEdges   map[rune]*PDANode
@@ -52,6 +61,9 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	spaceNode := NewPDANode(StateInSpace)
 	stateToNodeMap[StateInSpace] = spaceNode
 
+	spaceObjNode := NewPDANode(StateInObjSpace)
+	stateToNodeMap[StateInObjSpace] = spaceObjNode
+
 	tabNode := NewPDANode(StateInTab)
 	stateToNodeMap[StateInTab] = tabNode
 
@@ -61,7 +73,31 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	stringEndNode := NewPDANode(StateInStringEnd)
 	stateToNodeMap[StateInStringEnd] = stringEndNode
 
-	// terminateNode := NewNode(StateTerminate)
+	listNode := NewPDANode(StateInList)
+	stateToNodeMap[StateInList] = listNode
+
+	listCommaNode := NewPDANode(StateInListComma)
+	stateToNodeMap[StateInListComma] = listCommaNode
+
+	listEndNode := NewPDANode(StateListEnd)
+	stateToNodeMap[StateListEnd] = listEndNode
+
+	numberNode := NewPDANode(StateInNumber)
+	stateToNodeMap[StateInNumber] = numberNode
+
+	boolNode := NewPDANode(StateInBool)
+	stateToNodeMap[StateInBool] = boolNode
+
+	nullNode := NewPDANode(StateInNull)
+	stateToNodeMap[StateInNull] = nullNode
+
+	// Defined with structured outputs only
+	intNode := NewPDANode(StateInInt)
+	stateToNodeMap[StateInInt] = intNode
+
+	// TODO:
+	// consider adding a node to just point to values, could be good to compute that
+	// mask rather than many different nodes
 
 	// Connect nodes
 	// TODO: if all are single tokens then this can just be connected instead of defining the token
@@ -69,34 +105,119 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 
 	objNode.TransitionEdges['"'] = objKeyNode
 	objNode.TransitionEdges['\n'] = newlineNode
+	// objNode.TransitionEdges['\t'] = tabNode
 
 	newlineNode.TransitionEdges['"'] = objKeyNode
 	newlineNode.TransitionEdges['\t'] = tabNode
 
 	tabNode.TransitionEdges['"'] = objKeyNode
-
-	spaceNode.TransitionEdges['"'] = stringNode
+	// tabNode.TransitionEdges['\t'] = tabNode
 
 	objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
 	objKeyNode.TransitionEdges['"'] = objKeyEndNode
-	objKeyNode.TransitionEdges[' '] = spaceNode
-	// objKeyNode.TransitionEdges['\t'] = tabNode
 
 	objKeyEndNode.TransitionEdges[':'] = colonNode
+	objEndNode.TransitionEdges[' '] = spaceNode
 
-	colonNode.TransitionEdges['"'] = stringNode
+	// where values should be
+	// this could be combined but the probs might change, we're alr doing a skip ahead
 	colonNode.TransitionEdges[' '] = spaceNode
 
+	// Leads to a value
+	spaceNode.TransitionEdges['"'] = stringNode
+	spaceNode.TransitionEdges['['] = listNode
+	spaceNode.TransitionEdges['{'] = objNode
+
+	for _, r := range validNumberRunes {
+		spaceNode.TransitionEdges[r] = numberNode
+	}
+	for _, r := range validBoolRunes {
+		spaceNode.TransitionEdges[r] = boolNode
+	}
+
+	for _, r := range validNullRunes {
+		spaceNode.TransitionEdges[r] = nullNode
+	}
+
+	// Values
+	// string node
 	stringNode.TransitionEdges[rune(-1)] = stringNode
 	stringNode.TransitionEdges['"'] = stringEndNode
 
 	stringEndNode.TransitionEdges[','] = commaNode
 	stringEndNode.TransitionEdges['}'] = objEndNode
+	stringEndNode.TransitionEdges[']'] = listEndNode
+
+	// TODO: add counters for allowable number of decimals, e, E, etc
+	// number node
+	for _, r := range validNumberRunes {
+		numberNode.TransitionEdges[r] = numberNode
+	}
+	numberNode.TransitionEdges[','] = commaNode
+	numberNode.TransitionEdges['}'] = objEndNode
+	numberNode.TransitionEdges[']'] = listEndNode
+
+	for _, r := range validBoolRunes {
+		boolNode.TransitionEdges[r] = boolNode
+	}
+
+	// list node
+	listNode.TransitionEdges[','] = commaNode
+	listNode.TransitionEdges['"'] = stringNode
+	// squash states to a value
+	for _, r := range validNumberRunes {
+		listNode.TransitionEdges[r] = numberNode
+	}
+	for _, r := range validBoolRunes {
+		listNode.TransitionEdges[r] = boolNode
+	}
+	for _, r := range validNullRunes {
+		listNode.TransitionEdges[r] = nullNode
+	}
+
+	// null node
+	for _, r := range validNullRunes {
+		nullNode.TransitionEdges[r] = nullNode
+	}
+	nullNode.TransitionEdges[','] = commaNode
+	nullNode.TransitionEdges['}'] = objEndNode
+	nullNode.TransitionEdges[']'] = listEndNode
+
+	// list comma
+	// should point to values
+	listCommaNode.TransitionEdges['"'] = stringNode
+	listCommaNode.TransitionEdges[' '] = listCommaNode
+	listCommaNode.TransitionEdges['{'] = objNode
+	listCommaNode.TransitionEdges['\n'] = newlineNode
+
+	for _, r := range validNumberRunes {
+		listCommaNode.TransitionEdges[r] = numberNode
+	}
+	for _, r := range validBoolRunes {
+		listCommaNode.TransitionEdges[r] = boolNode
+	}
+	for _, r := range validNullRunes {
+		listCommaNode.TransitionEdges[r] = nullNode
+	}
+
+	// bool node
+	for _, r := range validBoolRunes {
+		boolNode.TransitionEdges[r] = boolNode
+	}
+	boolNode.TransitionEdges['}'] = objEndNode
+	boolNode.TransitionEdges[']'] = listEndNode
+	boolNode.TransitionEdges[','] = commaNode
+
+	listEndNode.TransitionEdges['}'] = objEndNode
+	listEndNode.TransitionEdges[','] = commaNode
 
 	commaNode.TransitionEdges['{'] = objNode
 	commaNode.TransitionEdges['\n'] = newlineNode
 	commaNode.TransitionEdges['\t'] = tabNode
 	commaNode.TransitionEdges['"'] = objKeyNode
+	commaNode.TransitionEdges[' '] = spaceObjNode
+
+	spaceObjNode.TransitionEdges['"'] = objKeyNode
 
 	return startNode, stateToNodeMap, nil
 }

+ 53 - 8
sample/pushdown_runner.go

@@ -3,6 +3,8 @@ package sample
 import (
 	"fmt"
 	"math"
+	"runtime"
+	"time"
 
 	"github.com/ollama/ollama/model"
 )
@@ -13,9 +15,17 @@ type PushdownSampler struct {
 	proc           model.TextProcessor
 	stateToNodeMap map[JSONState]*PDANode
 	braceStack     []rune
+	stateCounter   uint32
 }
 
 func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
+	start := time.Now()
+
+	var m runtime.MemStats
+	runtime.ReadMemStats(&m)
+	before := m.Alloc
+	fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
+
 	startNode, stateToNodeMap, err := BuildGraph(proc)
 	if err != nil {
 		panic(err)
@@ -24,6 +34,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
 	if err != nil {
 		panic(err)
 	}
+	runtime.ReadMemStats(&m)
+	after := m.Alloc
+	fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
+	fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
+	fmt.Printf("Graph build time = %v\n", time.Since(start))
 	// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
 	// 	token, err := proc.Decode([]int32{int32(id)})
 	// 	if err != nil {
@@ -37,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
 		proc:           proc,
 		stateToNodeMap: stateToNodeMap,
 		braceStack:     []rune{},
+		stateCounter:   0,
 	}
 }
 
@@ -69,7 +85,19 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 			}
 		}
 		return logits, nil
-	// return logits, nil
+
+	case StateInComma:
+		peek := s.braceStack[len(s.braceStack)-1]
+		if peek == rune('[') {
+			s.curNode = s.stateToNodeMap[StateInListComma]
+			fmt.Println("switching to list comma", s.curNode.State)
+		}
+		logits, err := s.maskLogits(logits, s.curNode)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+
 	case StateTerminate:
 		for i := range logits {
 			if s.proc.Is(uint32(i), model.SpecialEOS) {
@@ -80,9 +108,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 		}
 		return logits, nil
 
-	// case StateInStringEnd:
-
-	// 	return logits, nil
 	default:
 		fmt.Println("masking logits current state", s.curNode.State)
 		logits, err := s.maskLogits(logits, s.curNode)
@@ -96,7 +121,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 	fmt.Println("update state", s.curNode.State)
 
-	// TODO: need to handle end states and entering object case
+	// TODO: need to handle end states and entering object case, and list case
 	if s.curNode.State == StateInObjectEnd {
 		fmt.Println("in object end")
 		if len(s.braceStack) > 0 {
@@ -111,25 +136,45 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 	if err != nil {
 		return err
 	}
+	// TODO: should force closing for all braces
 	for _, r := range mappedString {
 		if r == rune('{') {
 			s.braceStack = append(s.braceStack, r)
 		}
+		if r == rune('[') {
+			s.braceStack = append(s.braceStack, r)
+		}
 		if r == rune('}') {
 			if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
 				return fmt.Errorf("unmatched closing brace")
 			}
 			s.braceStack = s.braceStack[:len(s.braceStack)-1]
+			fmt.Println("popping brace stack", s.braceStack)
+		}
+
+		if r == rune(']') {
+			if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') {
+				return fmt.Errorf("unmatched closing brace")
+			}
+			s.braceStack = s.braceStack[:len(s.braceStack)-1]
+			fmt.Println("popping brace stack", s.braceStack)
 		}
 	}
 	for _, tokenID := range tokenSlice {
 		// transition to the next node
-		nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
+		nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
 		if !ok {
 			return fmt.Errorf("invalid token: %q", mappedString)
 		}
-		fmt.Println("transitioning to", nextNode)
-		s.curNode = s.stateToNodeMap[nextNode]
+		fmt.Println("transitioning to", nextNodeState)
+
+		// TODO: add a penalty for staying in the same state too long
+		if nextNodeState == s.curNode.State {
+			s.stateCounter++
+		} else {
+			s.stateCounter = 0
+		}
+		s.curNode = s.stateToNodeMap[nextNodeState]
 	}
 	return nil
 }