123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218 |
- package sample
- import (
- "fmt"
- "github.com/ollama/ollama/model"
- )
- type token []int32
- type Node struct {
- State JSONState
- TransitionEdges map[*Node][]token
- }
- func NewNode(state JSONState) *Node {
- return &Node{
- State: state,
- TransitionEdges: make(map[*Node][]token),
- }
- }
- var (
- // startToken token
- startTokenVariants []token
- // endToken token
- // stringToken token
- // objectKeyToken token
- tabToken token
- spaceToken token
- newlineToken token
- newlineSpace token
- // commaToken token
- // commaToken2 token
- // commaToken3 token
- // colonToken token
- // colonToken2 token
- colonTokenVariants []token
- commaTokenVariants []token
- stringTokenVariants []token
- endTokenVariants []token
- objectKeyTokenVariants []token
- objKeyToColonVariants []token
- stringToObjectKeyVariants []token
- stringToCommaVariants []token
- stringToObjectVariants []token
- stringEndToObjectEndVariants []token
- stringEndToCommaVariants []token
- )
- func ComputeTokenVariants(variants []string, proc model.TextProcessor) ([]token, error) {
- var allTokens token
- for _, variant := range variants {
- if t, err := proc.Encode(variant); err == nil {
- allTokens = append(allTokens, t...)
- }
- }
- if len(allTokens) == 0 {
- return nil, fmt.Errorf("no valid tokens found for variants")
- }
- return []token{allTokens}, nil
- }
- func initTokens(proc model.TextProcessor) error {
- var err error
- s, err := proc.Decode([]int32{761})
- fmt.Printf("761 decoded %q\n", s)
- // Compute start token variants
- startVariants := []string{"{", " {", "{\n", " {\n"}
- startTokenVariants, err = ComputeTokenVariants(startVariants, proc)
- if err != nil {
- return err
- }
- // Compute end token variants
- endVariants := []string{"}", " }", "}\n", " }\n"}
- endTokenVariants, err = ComputeTokenVariants(endVariants, proc)
- if err != nil {
- return err
- }
- // Compute string token variants
- // TODO: removed \n
- stringVariants := []string{"\"", " \""}
- stringTokenVariants, err = ComputeTokenVariants(stringVariants, proc)
- if err != nil {
- return err
- }
- stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\",\n"}, proc)
- if err != nil {
- return err
- }
- // objectKeyTokenVariants = []token{stringTokenVariants[0], stringTokenVariants[1]}
- objectKeyTokenVariants = stringTokenVariants
- // Compute whitespace tokens
- tabToken, err = proc.Encode("\t")
- if err != nil {
- return err
- }
- spaceToken, err = proc.Encode(" ")
- if err != nil {
- return err
- }
- newlineToken, err = proc.Encode("\n")
- if err != nil {
- return err
- }
- newlineSpace, err = proc.Encode(" \n")
- if err != nil {
- return err
- }
- // Compute colon variants
- colonVariants := []string{":"}
- colonTokenVariants, err = ComputeTokenVariants(colonVariants, proc)
- if err != nil {
- return err
- }
- objKeyToColonVariants, err = ComputeTokenVariants([]string{"\":"}, proc)
- if err != nil {
- return err
- }
- // Compute comma variants
- commaVariants := []string{",", " ,", ",\n", "\",", "\", "}
- commaTokenVariants, err = ComputeTokenVariants(commaVariants, proc)
- if err != nil {
- return err
- }
- fmt.Printf("commaTokenVariants: %v\n", commaTokenVariants)
- stringToCommaVariants, err = ComputeTokenVariants([]string{"\",", "\","}, proc)
- if err != nil {
- return err
- }
- stringEndToCommaVariants, err = ComputeTokenVariants([]string{",", ",\n"}, proc)
- stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\","}, proc)
- stringToObjectVariants, err = ComputeTokenVariants([]string{"\",\n"}, proc)
- stringEndToObjectEndVariants, err = ComputeTokenVariants([]string{"\n"}, proc)
- return nil
- }
- func buildStateMachine(proc model.TextProcessor) (*Node, error) {
- if err := initTokens(proc); err != nil {
- return nil, err
- }
- startNode := NewNode(StateStart)
- objectNode := NewNode(StateInObject)
- objectKeyNode := NewNode(StateInObjectKey)
- objectKeyEndNode := NewNode(StateInObjectKeyEnd)
- stringNode := NewNode(StateInString)
- // intNode := NewNode(StateInInt)
- commaNode := NewNode(StateInComma)
- colonNode := NewNode(StateInColon)
- stringEndNode := NewNode(StateInStringEnd)
- endNode := NewNode(StateEnd)
- terminateNode := NewNode(StateTerminate)
- sentinelToken := token([]int32{-1})
- // intSentinelToken := token([]int32{-2})
- // TODO: cleanup connections of rules
- startNode.TransitionEdges[objectNode] = startTokenVariants
- objectNode.TransitionEdges[objectKeyNode] = stringTokenVariants
- objectNode.TransitionEdges[objectNode] = []token{newlineToken}
- objectNode.TransitionEdges[objectNode] = []token{spaceToken}
- // objectNode.TransitionEdges[objectNode] = []token{newlineToken}
- // objectNode.TransitionEdges[objectNode] = []token{spaceToken}
- objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
- // characterize end of object key
- objectKeyNode.TransitionEdges[objectKeyEndNode] = stringTokenVariants
- objectKeyNode.TransitionEdges[colonNode] = objKeyToColonVariants
- // TODO: enable this - key -> object
- // objectKeyNode.TransitionEdges[objectNode] = startTokenVariants
- // objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}
- // intNode.TransitionEdges[intNode] = []token{intSentinelToken}
- // intNode.TransitionEdges[commaNode] = commaTokenVariants
- // TODO: handle
- // intNode.TransitionEdges[terminateNode] = endTokenVariants
- commaNode.TransitionEdges[objectKeyNode] = stringTokenVariants
- // commaNode.TransitionEdges[objectNode] = startTokenVariants
- colonNode.TransitionEdges[stringNode] = stringTokenVariants
- //TODO: enable
- // colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
- colonNode.TransitionEdges[objectNode] = startTokenVariants
- stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
- stringNode.TransitionEdges[stringEndNode] = stringTokenVariants
- // TODO: "\""," Case not accounted for
- stringNode.TransitionEdges[commaNode] = stringToCommaVariants
- // TODO: "\"",\"" Case not accounted for
- stringNode.TransitionEdges[objectNode] = stringToObjectVariants
- stringEndNode.TransitionEdges[commaNode] = stringEndToCommaVariants
- stringEndNode.TransitionEdges[objectNode] = stringToObjectKeyVariants
- stringEndNode.TransitionEdges[endNode] = stringEndToObjectEndVariants
- // stringEndNode.TransitionEdges[terminateNode] = endTokenVariants
- // Should be obj end
- // TODO: handle
- endNode.TransitionEdges[terminateNode] = []token{}
- endNode.TransitionEdges[commaNode] = commaTokenVariants
- terminateNode.TransitionEdges[terminateNode] = []token{}
- return startNode, nil
- }
|