123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299 |
- package sample
- import (
- "fmt"
- "log/slog"
- "runtime"
- "time"
- "github.com/ollama/ollama/grammar/jsonschema"
- "github.com/ollama/ollama/model"
- )
- type JSONSampler struct {
- schema *jsonschema.Schema
- propIdx int
- propToNodeMap map[string]*PDA
- pdaSampler *PushdownSampler
- decodedToks []string
- }
- func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
- slog.Info("NewJSONSampler", "schema", schema)
- if proc == nil {
- return nil, fmt.Errorf("TextProcessor cannot be nil")
- }
- pdaSampler, err := NewPushdownSampler(proc)
- if err != nil {
- return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
- }
- if schema == nil {
- return &JSONSampler{
- schema: nil,
- propIdx: -1,
- propToNodeMap: nil,
- pdaSampler: pdaSampler,
- }, nil
- }
- // fmt.Println("schema not nil")
- so := &JSONSampler{
- schema: schema,
- propIdx: -1,
- propToNodeMap: make(map[string]*PDA),
- pdaSampler: pdaSampler,
- }
- so.schemaToGraph()
- // Benchmark token decoding
- start := time.Now()
- var m runtime.MemStats
- runtime.ReadMemStats(&m)
- before := m.Alloc
- vocab := proc.Vocab()
- decodedToks := make([]string, len(vocab.Values))
- for i := range vocab.Values {
- token, err := proc.Decode([]int32{int32(i)})
- if err != nil {
- return nil, err
- }
- decodedToks[i] = token
- }
- so.decodedToks = decodedToks
- runtime.ReadMemStats(&m)
- after := m.Alloc
- fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
- fmt.Printf("Token decode time = %v\n", time.Since(start))
- fmt.Println("--------------------------------")
- fmt.Println("SOSampler")
- fmt.Println("--------------------------------")
- // Benchmark this section
- start = time.Now()
- runtime.ReadMemStats(&m)
- before = m.Alloc
- // TODO: still messed up
- // TODO: recursion use case
- // key masks
- for _, prop := range so.schema.Properties {
- node := so.propToNodeMap[prop.Name]
- // propName -> node
- curState := node.State
- fromNode := node
- so.pdaSampler.CreateMask(fromNode)
- for curState == StateInStructuredKey {
- // there is only one edge
- for r, toNode := range fromNode.TransitionEdges {
- fmt.Println("rune", r, "edge", toNode.State)
- so.pdaSampler.CreateMask(toNode)
- fmt.Printf("created mask for %c\n", r)
- curState = toNode.State
- fmt.Println("next state", curState)
- // TODO: theres an extra gen for " right now
- fromNode = toNode
- }
- }
- if curState != StateInColon {
- return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
- }
- // so.pdaSampler.CreateMask(fromNode)
- fromNode = fromNode.TransitionEdges[' ']
- so.pdaSampler.CreateMask(fromNode)
- curState = fromNode.State
- for _, toNode := range fromNode.TransitionEdges {
- fmt.Println("toNode", toNode.State)
- }
- }
- // runtime.ReadMemStats(&m)
- // after = m.Alloc
- // fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
- // fmt.Printf("Mask creation time = %v\n", time.Since(start))
- // fmt.Println("--------------------------------")
- return so, nil
- }
- func (s *JSONSampler) schemaToGraph() {
- schemaType := s.schema.EffectiveType()
- switch schemaType {
- case "object":
- // TODO: see if we need to connect these to the JSON graph
- // each prop is a key
- for _, prop := range s.schema.Properties {
- // name of key
- name := prop.Name
- keyNode := &PDA{
- State: StateInStructuredKey, // this is unchanging, will impact sampling
- TransitionEdges: make(map[rune]*PDA),
- MaskTokenIDToNode: make(map[int32]*PDA),
- }
- prevNode := keyNode
- for _, r := range name {
- runeNode := &PDA{
- State: StateInStructuredKey, // this is unchanging, will impact sampling
- TransitionEdges: make(map[rune]*PDA),
- MaskTokenIDToNode: make(map[int32]*PDA),
- }
- // fmt.Println("runeNode created", runeNode.State)
- // fmt.Printf("runeNode created %c\n", r)
- // since alloc on heap connections wil still map
- prevNode.TransitionEdges[r] = runeNode
- prevNode = runeNode
- }
- // point to end of object key node after all chars are done
- // prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
- // link to value node
- // Create a node for the end of the key (after the closing quote)
- stringEndNode := &PDA{
- State: StateInStructuredKey,
- TransitionEdges: make(map[rune]*PDA),
- MaskTokenIDToNode: make(map[int32]*PDA),
- }
- prevNode.TransitionEdges['"'] = stringEndNode
- prevNode = stringEndNode
- // Add transition for colon after key
- colonNode := &PDA{
- State: StateInColon,
- TransitionEdges: make(map[rune]*PDA),
- MaskTokenIDToNode: make(map[int32]*PDA),
- }
- prevNode.TransitionEdges[':'] = colonNode
- prevNode = colonNode
- // Add transition for space after colon
- spaceNode := &PDA{
- State: StateInSpaceToValue,
- TransitionEdges: make(map[rune]*PDA),
- MaskTokenIDToNode: make(map[int32]*PDA),
- }
- prevNode.TransitionEdges[' '] = spaceNode
- prevNode = spaceNode
- value := prop.Type
- switch value {
- case "object":
- fmt.Println("object under key: ", name)
- case "array":
- fmt.Println("array under key: ", name)
- case "string":
- fmt.Println("string under key: ", name)
- prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
- case "number":
- fmt.Println("number under key: ", name)
- for _, r := range validNumberRunes {
- prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
- }
- case "boolean":
- fmt.Println("boolean under key: ", name)
- prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
- prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
- prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
- }
- // points to start of the key
- s.propToNodeMap[name] = keyNode
- fmt.Println("name", name, "keyNode", keyNode.State)
- }
- }
- // TODO: do values + recursion
- }
- func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
- if s.schema == nil {
- return s.pdaSampler.Apply(logits)
- }
- switch s.pdaSampler.curNode.State {
- // TODO: doesnt account for multi rune case
- case StateInObjectKey:
- if s.propIdx > len(s.schema.Properties)-1 {
- return nil, fmt.Errorf("propIdx out of bounds")
- }
- // fmt.Println("in object key - structured outputs")
- // TODO: this tracking should probably be coming from a stack to track nested objects
- // simple case
- s.propIdx++
- fmt.Println("propIdx", s.propIdx)
- prop := s.schema.Properties[s.propIdx]
- fmt.Println("prop", prop.Name)
- s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
- fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
- logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
- if err != nil {
- return nil, err
- }
- return logits, nil
- default:
- // Will only happen for the last prop - can also be precomputed.
- if s.propIdx == len(s.schema.Properties)-1 {
- // todo: if i incremenet propidx then i know im in last value as well
- switch s.pdaSampler.curNode.State {
- case StateInObjectEnd:
- fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
- s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
- s.pdaSampler.curNode = NewPDANode(StateTerminate)
- s.propIdx++
- // TODO: this needs to be optimized in some way, computing mask on the fly is expensive
- case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
- fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
- delete(s.pdaSampler.curNode.TransitionEdges, ',')
- s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
- s.pdaSampler.CreateMask(s.pdaSampler.curNode)
- s.propIdx++
- }
- }
- return s.pdaSampler.Apply(logits)
- }
- }
- func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
- tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
- if err != nil {
- return nil, err
- }
- if s.schema == nil {
- // Don't need to update state for unconstrained JSON sampling
- return tokenSlice, nil
- }
- switch s.pdaSampler.curNode.State {
- case StateInObjectKey:
- s.propIdx++
- fmt.Println("propIdx", s.propIdx)
- prop := s.schema.Properties[s.propIdx]
- fmt.Println("prop", prop.Name)
- s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
- // TODO: this does not work - mike
- // str, err := s.pdaSampler.proc.Decode(tokenSlice)
- // if err != nil {
- // return nil, err
- // }
- // fmt.Println("str", str)
- return tokenSlice, nil
- default:
- return tokenSlice, nil
- }
- }
|