Browse Source

json checkpoint

ParthSareen 3 months ago
parent
commit
a7c8cc06da
2 changed files with 365 additions and 0 deletions
  1. 207 0
      sample/fast_json.go
  2. 158 0
      sample/state_machine.go

+ 207 - 0
sample/fast_json.go

@@ -0,0 +1,207 @@
+package sample
+
+import (
+	"errors"
+	"fmt"
+	"math"
+	"slices"
+
+	"github.com/ollama/ollama/model"
+)
+
+// JSONState identifies which part of a JSON document a state-machine node
+// represents.
+type JSONState int
+
+// The JSONState values. They are numbered by iota, so their order in this
+// block fixes their integer values.
+// NOTE(review): buildStateMachine only creates nodes for some of these states;
+// the newline, tab, space, float, bool, null and array states have no node yet.
+const (
+	StateStart JSONState = iota
+	StateInObject
+	StateInObjectKey
+	StateNewline
+	StateTab
+	StateSpace
+	StateInString
+	StateInInt
+	StateInFloat
+	StateInBool
+	StateInNull
+	StateInArray
+	StateInColon
+	StateInComma
+	StateInStringEnd
+	StateInObjectKeyEnd
+	StateTerminate
+	StateEnd
+)
+
+func (s JSONState) String() string {
+	switch s {
+	case StateStart:
+		return "StateStart"
+	case StateInObject:
+		return "StateInObject"
+	case StateInObjectKey:
+		return "StateInObjectKey"
+	case StateInString:
+		return "StateInString"
+	case StateNewline:
+		return "StateNewline"
+	case StateTab:
+		return "StateTab"
+	case StateSpace:
+		return "StateSpace"
+	case StateInInt:
+		return "StateInInt"
+	case StateInFloat:
+		return "StateInFloat"
+	case StateInColon:
+		return "StateInColon"
+	case StateInBool:
+		return "StateInBool"
+	case StateInNull:
+		return "StateInNull"
+	case StateInArray:
+		return "StateInArray"
+	case StateEnd:
+		return "StateEnd"
+	case StateInComma:
+		return "StateInComma"
+	case StateInObjectKeyEnd:
+		return "StateInObjectKeyEnd"
+	case StateTerminate:
+		return "StateTerminate"
+	case StateInStringEnd:
+		return "StateInStringEnd"
+	default:
+		return fmt.Sprintf("Unknown state: %d", s)
+	}
+}
+
+// JSONSampler tracks a position in the state machine returned by
+// buildStateMachine and uses it to decide which tokens may be sampled next.
+type JSONSampler struct {
+	// curNode is the node the sampler is currently in.
+	curNode *Node
+	// proc encodes text to token ids and identifies special tokens.
+	proc    model.TextProcessor
+	// NOTE(review): stack is not referenced by the methods shown alongside this
+	// type — confirm whether it is still needed.
+	stack   []*Node
+}
+
+// NewJSONSampler builds the state machine for proc and returns a sampler
+// positioned on its start node. Any error from buildStateMachine is returned
+// unchanged.
+func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
+	root, err := buildStateMachine(proc)
+	if err != nil {
+		return nil, err
+	}
+	return &JSONSampler{curNode: root, proc: proc}, nil
+}
+
+func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
+	// fmt.Printf("Updating state with token: %v\n", tokenSlice)
+	// fmt.Printf("Current state: %s\n", s.curNode.State)
+
+	// fmt.Println("tokenSlice", tokenSlice)
+	// todo: account for strings here
+	for node, edge := range s.curNode.TransitionEdges {
+		for _, validToken := range edge {
+			if slices.Equal(tokenSlice, validToken) {
+				s.curNode = node
+				// fmt.Printf("Transitioned to state: %s\n", node.State)
+				return nil
+			}
+		}
+	}
+	for node, edge := range s.curNode.TransitionEdges {
+		for _, validToken := range edge {
+			if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 {
+				s.curNode = node
+				// fmt.Printf("Accepting any token, staying in state: %s\n", node.State)
+				return nil
+			}
+		}
+	}
+	fmt.Println("invalid token ", tokenSlice)
+	return errors.New("invalid token")
+}
+
+func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
+	fmt.Printf("Sampling in state: %s\n", s.curNode.State)
+	var err error
+
+	switch s.curNode.State {
+	case StateTerminate:
+		for i := range logits {
+			if s.proc.Is(uint32(i), model.SpecialEOS) {
+				logits[i] = 1.0
+			} else {
+				logits[i] = math.NaN()
+			}
+		}
+		return logits, nil
+
+	case StateInInt:
+		validStates := []int32{}
+		minus, err := s.proc.Encode("-")
+		if err != nil {
+			return nil, err
+		}
+		digits := make([][]int32, 10)
+		for i := 0; i < 10; i++ {
+			digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i))
+			if err != nil {
+				return nil, err
+			}
+		}
+		// Allow "-" and digits 0-9 at start
+		for i := range logits {
+			for _, d := range digits {
+				if len(d) == 1 && int32(i) == d[0] {
+					validStates = append(validStates, int32(i))
+				}
+			}
+			if len(minus) == 1 && int32(i) == minus[0] {
+				validStates = append(validStates, int32(i))
+			}
+		}
+		return logits, nil
+
+	default:
+		validStates := getValidStates(s.curNode)
+		logits, err = s.maskLogits(logits, validStates)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+	}
+}
+
+// getValidStates flattens every token sequence on node's outgoing edges into a
+// single list of token ids. Edges are visited in map order, so the order of
+// the result is unspecified. The result is never nil.
+func getValidStates(node *Node) []int32 {
+	ids := []int32{}
+	for _, sequences := range node.TransitionEdges {
+		for i := range sequences {
+			for _, id := range sequences[i] {
+				ids = append(ids, id)
+			}
+		}
+	}
+	return ids
+}
+
+// maskLogits sets logits[i] to NaN for every index i that is not listed in
+// validStates, modifying logits in place and returning it.
+//
+// If validStates contains the sentinel -1 the logits are returned untouched.
+// Ids outside the range of logits (including other negative sentinels) are
+// ignored. The returned error is always nil.
+func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
+	if len(logits) == 0 || slices.Contains(validStates, -1) {
+		return logits, nil
+	}
+
+	// One pass to mark the allowed ids and one pass to mask, instead of
+	// rescanning validStates for every logit.
+	allowed := make([]bool, len(logits))
+	for _, id := range validStates {
+		if id >= 0 && int(id) < len(logits) {
+			allowed[id] = true
+		}
+	}
+	for i := range logits {
+		if !allowed[i] {
+			logits[i] = math.NaN()
+		}
+	}
+	return logits, nil
+}

+ 158 - 0
sample/state_machine.go

@@ -0,0 +1,158 @@
+package sample
+
+import (
+	"fmt"
+
+	"github.com/ollama/ollama/model"
+)
+
+// token is the sequence of token ids produced by encoding one piece of text.
+// Single-element tokens holding -1 or -2 are used as sentinels in
+// buildStateMachine.
+type token []int32
+
+// Node is one state of the JSON state machine together with its outgoing
+// edges.
+type Node struct {
+	// State is the kind of state this node represents.
+	State           JSONState
+	// TransitionEdges maps each reachable node to the tokens that lead to it.
+	TransitionEdges map[*Node][]token
+}
+
+// NewNode returns a node for state with an empty, non-nil edge map.
+func NewNode(state JSONState) *Node {
+	n := new(Node)
+	n.State = state
+	n.TransitionEdges = map[*Node][]token{}
+	return n
+}
+
+// Encoded forms of the JSON punctuation used by the state machine. They are
+// package-level and are overwritten by every call to initTokens.
+var (
+	startToken     token // "{"
+	endToken       token // "}"
+	stringToken    token // "\""
+	objectKeyToken token // "\"" (same text as stringToken)
+	tabToken       token // "\t"
+	spaceToken     token // " "
+	newlineToken   token // "\n"
+	newlineSpace   token // " \n"
+	commaToken     token // ","
+	commaToken2    token // "\","
+	commaToken3    token // "\",\""
+	colonToken     token // "\":"
+	colonToken2    token // ":"
+)
+
+func initTokens(proc model.TextProcessor) error {
+	var err error
+	startToken, err = proc.Encode("{")
+	if err != nil {
+		return err
+	}
+	endToken, err = proc.Encode("}")
+	if err != nil {
+		return err
+	}
+	stringToken, err = proc.Encode("\"")
+	if err != nil {
+		return err
+	}
+	objectKeyToken, err = proc.Encode("\"")
+	if err != nil {
+		return err
+	}
+	tabToken, err = proc.Encode("\t")
+	if err != nil {
+		return err
+	}
+	spaceToken, err = proc.Encode(" ")
+	if err != nil {
+		return err
+	}
+	newlineToken, err = proc.Encode("\n")
+	if err != nil {
+		return err
+	}
+	newlineSpace, err = proc.Encode(" \n")
+	if err != nil {
+		return err
+	}
+	// TODO: figure out how to encode colon correctly
+	colonToken, err = proc.Encode("\":")
+	if err != nil {
+		return err
+	}
+	fmt.Println("colonToken", colonToken)
+	colonToken2, err = proc.Encode(":")
+	if err != nil {
+		return err
+	}
+	commaToken, err = proc.Encode(",")
+	if err != nil {
+		return err
+	}
+	commaToken2, err = proc.Encode("\",")
+	if err != nil {
+		return err
+	}
+	fmt.Println("commaToken2", commaToken2)
+	commaToken3, err = proc.Encode("\",\"")
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+// buildStateMachine initialises the package-level tokens from proc, wires up
+// the nodes of the JSON state machine and returns its start node. An error
+// from initTokens is returned unchanged.
+func buildStateMachine(proc model.TextProcessor) (*Node, error) {
+	err := initTokens(proc)
+	if err != nil {
+		return nil, err
+	}
+
+	// connect records that any of tokens moves the machine from src to dst.
+	connect := func(src, dst *Node, tokens ...token) {
+		src.TransitionEdges[dst] = tokens
+	}
+
+	// Sentinels: UpdateState follows an edge holding either of them for any
+	// token; maskLogits leaves logits unmasked when it sees -1.
+	anyToken := token{-1}
+	anyIntToken := token{-2}
+
+	startNode := NewNode(StateStart)
+	objectNode := NewNode(StateInObject)
+	objectKeyNode := NewNode(StateInObjectKey)
+	objectKeyEndNode := NewNode(StateInObjectKeyEnd)
+	colonNode := NewNode(StateInColon)
+	stringNode := NewNode(StateInString)
+	stringEndNode := NewNode(StateInStringEnd)
+	intNode := NewNode(StateInInt)
+	commaNode := NewNode(StateInComma)
+	endNode := NewNode(StateEnd)
+	terminateNode := NewNode(StateTerminate)
+
+	// "{" opens the object, a quote opens the first key.
+	connect(startNode, objectNode, startToken)
+	connect(objectNode, objectKeyNode, stringToken)
+
+	// Inside a key: anything stays in the key, a quote ends it, a colon moves on.
+	connect(objectKeyNode, objectKeyNode, anyToken)
+	connect(objectKeyNode, colonNode, colonToken, colonToken2)
+	connect(objectKeyNode, objectKeyEndNode, stringToken)
+	connect(objectKeyEndNode, colonNode, colonToken)
+
+	// After the colon a value starts: a quoted string or an integer.
+	connect(colonNode, stringNode, stringToken)
+	connect(colonNode, intNode, anyIntToken)
+
+	// Integer values.
+	connect(intNode, intNode, anyIntToken)
+	connect(intNode, commaNode, commaToken, commaToken2)
+	connect(intNode, terminateNode, endToken)
+
+	// String values, including the fused `",` and `","` tokens.
+	connect(stringNode, stringNode, anyToken)
+	connect(stringNode, stringEndNode, stringToken)
+	connect(stringNode, commaNode, commaToken2)
+	connect(stringNode, objectKeyNode, commaToken3)
+	connect(stringEndNode, commaNode, commaToken, commaToken2)
+	connect(stringEndNode, terminateNode, endToken)
+
+	connect(commaNode, objectKeyNode, newlineToken)
+
+	connect(endNode, terminateNode, endToken)
+
+	// The terminate node loops to itself with no tokens.
+	terminateNode.TransitionEdges[terminateNode] = []token{}
+
+	return startNode, nil
+}