|
@@ -13,6 +13,7 @@ type SOSampler struct {
|
|
|
propIdx int
|
|
|
propToNodeMap map[string]*PDANode
|
|
|
pdaSampler *PushdownSampler
|
|
|
+ decodedToks []string
|
|
|
}
|
|
|
|
|
|
func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
|
|
@@ -27,6 +28,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
|
|
|
|
|
so.schemaToGraph()
|
|
|
|
|
|
+ // This is prob slow
|
|
|
vocab := proc.GetVocabulary()
|
|
|
decodedToks := make([]string, len(vocab.Values))
|
|
|
for i := range vocab.Values {
|
|
@@ -36,6 +38,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
|
|
}
|
|
|
decodedToks[i] = token
|
|
|
}
|
|
|
+ so.decodedToks = decodedToks
|
|
|
|
|
|
fmt.Println("--------------------------------")
|
|
|
fmt.Println("SOSampler")
|
|
@@ -47,16 +50,19 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
|
|
before := m.Alloc
|
|
|
|
|
|
// TODO: still messed up
|
|
|
- for _, node := range so.propToNodeMap {
|
|
|
+ // TODO: recursion use case
|
|
|
+ // key masks
|
|
|
+ for _, prop := range so.schema.Properties {
|
|
|
+ node := so.propToNodeMap[prop.Name]
|
|
|
// propName -> node
|
|
|
curState := node.State
|
|
|
fromNode := node
|
|
|
- CreateMask(fromNode, proc, decodedToks, vocab)
|
|
|
+ CreateMask(fromNode, proc, decodedToks)
|
|
|
for curState == StateInStructuredKey {
|
|
|
// there is only one edge
|
|
|
for r, toNode := range fromNode.TransitionEdges {
|
|
|
// fmt.Println("rune", r, "edge", toNode.State)
|
|
|
- CreateMask(toNode, proc, decodedToks, vocab)
|
|
|
+ CreateMask(toNode, proc, decodedToks)
|
|
|
fmt.Printf("created mask for %c\n", r)
|
|
|
curState = toNode.State
|
|
|
fmt.Println("next state", curState)
|
|
@@ -80,14 +86,11 @@ func (s *SOSampler) schemaToGraph() {
|
|
|
switch schemaType {
|
|
|
case "object":
|
|
|
// TODO: see if we need to connect these to the JSON graph
|
|
|
- // prevState := StateInObjectKey
|
|
|
- // prevNode := so.stateToNodeMap[prevState]
|
|
|
|
|
|
// each prop is a key
|
|
|
for _, prop := range s.schema.Properties {
|
|
|
// name of key
|
|
|
name := prop.Name
|
|
|
- // prevState := StateInObjectKey
|
|
|
keyNode := &PDANode{
|
|
|
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
|
|
TransitionEdges: make(map[rune]*PDANode),
|
|
@@ -116,10 +119,13 @@ func (s *SOSampler) schemaToGraph() {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
|
|
|
+func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
|
|
|
switch s.pdaSampler.curNode.State {
|
|
|
// doesnt account for multi rune case
|
|
|
case StateInObjectKey:
|
|
|
+ if s.propIdx > len(s.schema.Properties)-1 {
|
|
|
+ return nil, fmt.Errorf("propIdx out of bounds")
|
|
|
+ }
|
|
|
// fmt.Println("in object key - structured outputs")
|
|
|
// TODO: this tracking should probably be coming from a stack to track nested objects
|
|
|
// simple case
|
|
@@ -136,11 +142,52 @@ func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
|
|
|
return logits, nil
|
|
|
|
|
|
default:
|
|
|
- return s.pdaSampler.Sample(logits)
|
|
|
+
|
|
|
+ // Will only happen for the last prop - can also be precomputed.
|
|
|
+ if s.propIdx == len(s.schema.Properties)-1 {
|
|
|
+ // TODO: if I increment propIdx then I know I'm in the last value as well
|
|
|
+ switch s.pdaSampler.curNode.State {
|
|
|
+ case StateInObjectEnd:
|
|
|
+ fmt.Println("<<<<< in obj end- generating mask for", s.pdaSampler.curNode.State)
|
|
|
+ s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDANode)
|
|
|
+ s.pdaSampler.curNode = NewPDANode(StateTerminate)
|
|
|
+ s.propIdx++
|
|
|
+
|
|
|
+ case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
|
|
|
+ fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
|
|
|
+ delete(s.pdaSampler.curNode.TransitionEdges, ',')
|
|
|
+ s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDANode)
|
|
|
+
|
|
|
+ CreateMask(s.pdaSampler.curNode, s.pdaSampler.proc, s.decodedToks)
|
|
|
+ s.propIdx++
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return s.pdaSampler.Apply(logits)
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
func (s *SOSampler) UpdateState(tokenSlice []int32) error {
|
|
|
- return s.pdaSampler.UpdateState(tokenSlice)
|
|
|
+ err := s.pdaSampler.UpdateState(tokenSlice)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ switch s.pdaSampler.curNode.State {
|
|
|
+ case StateInObjectKey:
|
|
|
+ s.propIdx++
|
|
|
+ fmt.Println("propIdx", s.propIdx)
|
|
|
+ prop := s.schema.Properties[s.propIdx]
|
|
|
+ fmt.Println("prop", prop.Name)
|
|
|
+ s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
|
|
+ str, err := s.pdaSampler.proc.Decode(tokenSlice)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ fmt.Println("str", str)
|
|
|
+
|
|
|
+ return nil
|
|
|
+ default:
|
|
|
+ return nil
|
|
|
+ }
|
|
|
}
|