|
@@ -1,54 +1,145 @@
|
|
|
package sample
|
|
|
|
|
|
-import "github.com/ollama/ollama/model"
|
|
|
+import (
|
|
|
+ "fmt"
|
|
|
+ "runtime"
|
|
|
+ "time"
|
|
|
|
|
|
-type StructuredOutput struct {
|
|
|
- schema *Schema
|
|
|
- stateToNodeMap map[JSONState]*PDANode
|
|
|
+ "github.com/ollama/ollama/model"
|
|
|
+)
|
|
|
+
|
|
|
+type SOSampler struct {
|
|
|
+ schema *Schema
|
|
|
+ propIdx int
|
|
|
+ propStateMap map[string]*PDANode
|
|
|
+ pdaSampler *PushdownSampler
|
|
|
}
|
|
|
|
|
|
-func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *StructuredOutput {
|
|
|
- _, stateToNodeMap, err := BuildGraph(proc)
|
|
|
- if err != nil {
|
|
|
- panic(err)
|
|
|
+func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
|
|
|
+ pdaSampler := NewPushdownSampler(proc)
|
|
|
+
|
|
|
+ so := &SOSampler{
|
|
|
+ schema: schema,
|
|
|
+ propIdx: -1,
|
|
|
+ propStateMap: make(map[string]*PDANode),
|
|
|
+ pdaSampler: pdaSampler,
|
|
|
}
|
|
|
|
|
|
- return &StructuredOutput{
|
|
|
- schema: schema,
|
|
|
- stateToNodeMap: stateToNodeMap,
|
|
|
+ so.schemaToGraph()
|
|
|
+
|
|
|
+ vocab := proc.GetVocabulary()
|
|
|
+ decodedToks := make([]string, len(vocab.Values))
|
|
|
+ for i := range vocab.Values {
|
|
|
+ token, err := proc.Decode([]int32{int32(i)})
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ decodedToks[i] = token
|
|
|
}
|
|
|
-}
|
|
|
|
|
|
-func (so *StructuredOutput) schemaToGraph(proc model.TextProcessor) *PDANode {
|
|
|
+ fmt.Println("--------------------------------")
|
|
|
+ fmt.Println("SOSampler")
|
|
|
+ fmt.Println("--------------------------------")
|
|
|
+ // Benchmark this section
|
|
|
+ start := time.Now()
|
|
|
+ var m runtime.MemStats
|
|
|
+ runtime.ReadMemStats(&m)
|
|
|
+ before := m.Alloc
|
|
|
|
|
|
- schemaType := so.schema.EffectiveType()
|
|
|
+ // TODO: still messed up
|
|
|
+ for _, node := range so.propStateMap {
|
|
|
+ // propName -> node
|
|
|
+ curState := node.State
|
|
|
+ fromNode := node
|
|
|
+ CreateMask(fromNode, proc, decodedToks, vocab)
|
|
|
+ for curState == StateInStructuredKey {
|
|
|
+ // there is only one edge
|
|
|
+ for r, toNode := range fromNode.TransitionEdges {
|
|
|
+ // fmt.Println("rune", r, "edge", toNode.State)
|
|
|
+ CreateMask(toNode, proc, decodedToks, vocab)
|
|
|
+ fmt.Printf("created mask for %c\n", r)
|
|
|
+ curState = toNode.State
|
|
|
+ fmt.Println("next state", curState)
|
|
|
+ // TODO: theres an extra gen for " right now
|
|
|
+ fromNode = toNode
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ runtime.ReadMemStats(&m)
|
|
|
+ after := m.Alloc
|
|
|
+ fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
|
|
+ fmt.Printf("Mask creation time = %v\n", time.Since(start))
|
|
|
+ fmt.Println("--------------------------------")
|
|
|
+
|
|
|
+ return so, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (s *SOSampler) schemaToGraph() {
|
|
|
+ schemaType := s.schema.EffectiveType()
|
|
|
switch schemaType {
|
|
|
case "object":
|
|
|
- // each prop is a key
|
|
|
+ // TODO: see if we need to connect these to the JSON graph
|
|
|
// prevState := StateInObjectKey
|
|
|
- for _, prop := range so.schema.Properties {
|
|
|
+ // prevNode := so.stateToNodeMap[prevState]
|
|
|
+
|
|
|
+ // each prop is a key
|
|
|
+ for _, prop := range s.schema.Properties {
|
|
|
// name of key
|
|
|
name := prop.Name
|
|
|
- prevState := StateInObjectKey
|
|
|
- for i, r := range name {
|
|
|
- newState := JSONState(int(StateInObjectKey) + i + 1) // Create new unique state for each rune
|
|
|
-
|
|
|
- // Create new node for this state if it doesn't exist
|
|
|
- if _, exists := so.stateToNodeMap[newState]; !exists {
|
|
|
- so.stateToNodeMap[newState] = &PDANode{
|
|
|
- State: newState,
|
|
|
- TransitionEdges: make(map[rune]*PDANode),
|
|
|
- MaskTokenIDToNode: make(map[int32]JSONState),
|
|
|
- }
|
|
|
- }
|
|
|
+ // prevState := StateInObjectKey
|
|
|
+ keyNode := &PDANode{
|
|
|
+ State: StateInStructuredKey, // this is unchanging, will impact sampling
|
|
|
+ TransitionEdges: make(map[rune]*PDANode),
|
|
|
+ MaskTokenIDToNode: make(map[int32]*PDANode),
|
|
|
+ }
|
|
|
|
|
|
- // Connect previous state to this state via the rune
|
|
|
- so.stateToNodeMap[prevState].TransitionEdges[r] = so.stateToNodeMap[newState]
|
|
|
- prevState = newState
|
|
|
+ prevNode := keyNode
|
|
|
+ for _, r := range name {
|
|
|
+ runeNode := &PDANode{
|
|
|
+ State: StateInStructuredKey, // this is unchanging, will impact sampling
|
|
|
+ TransitionEdges: make(map[rune]*PDANode),
|
|
|
+ MaskTokenIDToNode: make(map[int32]*PDANode),
|
|
|
+ }
|
|
|
+ fmt.Println("runeNode created", runeNode.State)
|
|
|
+ fmt.Printf("runeNode created %c\n", r)
|
|
|
+ // since alloc on heap connections wil still map
|
|
|
+ prevNode.TransitionEdges[r] = runeNode
|
|
|
+ prevNode = runeNode
|
|
|
}
|
|
|
- // type of value
|
|
|
- // propType := prop.Type
|
|
|
+ // point to end of object key node after all chars are done
|
|
|
+ prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
|
|
|
+ // points to start of the key
|
|
|
+ s.propStateMap[name] = keyNode
|
|
|
+ fmt.Println("name", name, "keyNode", keyNode.State)
|
|
|
}
|
|
|
}
|
|
|
- return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
|
|
|
+ switch s.pdaSampler.curNode.State {
|
|
|
+ // doesnt account for multi rune case
|
|
|
+ case StateInObjectKey:
|
|
|
+ // fmt.Println("in object key - structured outputs")
|
|
|
+ // TODO: this tracking should probably be coming from a stack to track nested objects
|
|
|
+ // simple case
|
|
|
+ s.propIdx++
|
|
|
+ prop := s.schema.Properties[s.propIdx]
|
|
|
+ // fmt.Println("prop", prop.Name)
|
|
|
+ s.pdaSampler.curNode = s.propStateMap[prop.Name]
|
|
|
+ // fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
|
|
|
+ logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return logits, nil
|
|
|
+
|
|
|
+ default:
|
|
|
+ return s.pdaSampler.Sample(logits)
|
|
|
+ }
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+func (s *SOSampler) UpdateState(tokenSlice []int32) error {
|
|
|
+ return s.pdaSampler.UpdateState(tokenSlice)
|
|
|
}
|