ParthSareen 3 mēneši atpakaļ
vecāks
revīzija
524029cd6d

+ 29 - 5
model/cmd/main.go

@@ -106,11 +106,31 @@ func temp() error {
 		}
 	}
 
-	pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
+	// pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
+
+	// simple schema
+	// This schema maps to JSON like:
+	// {
+	//   "name": "some string value"
+	// }
+	schema := &sample.Schema{
+		Name: "root",
+		Type: "object",
+		Properties: []*sample.Schema{
+			{Name: "name", Type: "string"},
+		},
+	}
+
+	pushdownSampler, err := sample.NewSOSampler(schema, m.(model.TextProcessor))
+	if err != nil {
+		return err
+	}
 
 	var offset int
 	var stringBuffer string
 	var firstTokenTime time.Duration
+	var totalSamplingTime time.Duration
+	count := 0
 	for range args.n {
 		logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
 		if err != nil {
@@ -122,7 +142,6 @@ func temp() error {
 		for i, f32 := range f32s {
 			f64s[i] = float64(f32)
 		}
-		sampleTime := time.Now()
 		samplers := []sample.Sampler{
 			pushdownSampler,
 			// sample.Weighed(),
@@ -131,12 +150,16 @@ func temp() error {
 			sample.Greedy(),
 		}
 
+		samplingStart := time.Now()
 		f64s, err = sample.Sample(f64s, samplers...)
 		if err != nil {
 			return err
 		}
-		finishTime := time.Now()
-		fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
+		samplingTime := time.Since(samplingStart)
+		totalSamplingTime += samplingTime
+
+		fmt.Println("sampling time", samplingTime)
+		// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
 
 		var outputIDs []int32
 		for _, f64 := range f64s {
@@ -164,6 +187,7 @@ func temp() error {
 		// fmt.Printf("--- token: %q\n", s)
 		// fmt.Printf("--- outputIDs: %v\n", outputIDs)
 		stringBuffer += s
+		count++
 		fmt.Println("--- stringBuffer", stringBuffer)
 
 		err = pushdownSampler.UpdateState(outputIDs)
@@ -179,7 +203,7 @@ func temp() error {
 	fmt.Println("\n------ Output: ------")
 	fmt.Println(stringBuffer)
 	fmt.Println("--------------------")
-
+	fmt.Println("sample average time", totalSamplingTime/time.Duration(count))
 	return nil
 }
 

+ 5 - 0
sample/fast_json.go

@@ -10,6 +10,8 @@ const (
 	StateStart JSONState = iota
 	StateInObject
 	StateInObjectKey
+	StateInStructuredKey
+	StateInStructuredValue
 	StateNewline
 	StateTab
 	StateSpace
@@ -43,6 +45,7 @@ var JSONStates = []JSONState{
 	StateStart,
 	StateInObject,
 	StateInObjectKey,
+	StateInStructuredKey,
 	StateNewline,
 	StateTab,
 	StateSpace,
@@ -80,6 +83,8 @@ func (s JSONState) String() string {
 		return "StateInObject"
 	case StateInObjectKey:
 		return "StateInObjectKey"
+	case StateInStructuredKey:
+		return "StateInStructuredKey"
 	case StateNewline:
 		return "StateNewline"
 	case StateTab:

+ 9 - 5
sample/pushdown_automata.go

@@ -21,14 +21,14 @@ var validNullRunes = []rune{'n', 'u', 'l', 'l'}
 type PDANode struct {
 	State             JSONState
 	TransitionEdges   map[rune]*PDANode
-	MaskTokenIDToNode map[int32]JSONState
+	MaskTokenIDToNode map[int32]*PDANode
 }
 
 func NewPDANode(state JSONState) *PDANode {
 	return &PDANode{
 		State:             state,
 		TransitionEdges:   make(map[rune]*PDANode),
-		MaskTokenIDToNode: make(map[int32]JSONState),
+		MaskTokenIDToNode: make(map[int32]*PDANode),
 	}
 }
 
@@ -103,6 +103,8 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
+	// empty list
+	stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
 	addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
 
 	// null node
@@ -162,6 +164,7 @@ func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
 // TODO: tough life fr. plz fix.
 func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
 
+	// TODO; should come from top level
 	vocab := proc.GetVocabulary()
 
 	decodedToks := make([]string, len(vocab.Values))
@@ -175,7 +178,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
 
 	var err error
 	for _, node := range stateToNodeMap {
-		err = createMask(node, proc, decodedToks, vocab)
+		err = CreateMask(node, proc, decodedToks, vocab)
 		if err != nil {
 			return err
 		}
@@ -183,7 +186,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
 	return nil
 }
 
-func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
+func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
 	for i := range vocab.Values {
 		token := decodedToks[i]
 		// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
@@ -204,7 +207,8 @@ func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, v
 			}
 		}
 		if valid {
-			node.MaskTokenIDToNode[int32(i)] = curNode.State
+			// cur node allows skipping
+			node.MaskTokenIDToNode[int32(i)] = curNode
 		}
 	}
 	return nil

+ 19 - 14
sample/pushdown_runner.go

@@ -4,7 +4,6 @@ import (
 	"fmt"
 	"math"
 	"runtime"
-	"time"
 
 	"github.com/ollama/ollama/model"
 )
@@ -22,12 +21,15 @@ type PushdownSampler struct {
 
 // graph should be built once and reused per tokenizer
 func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
-	start := time.Now()
+	// start := time.Now()
 
+	// fmt.Println("--------------------------------")
+	// fmt.Println("PDA sampler")
+	// fmt.Println("--------------------------------")
 	var m runtime.MemStats
 	runtime.ReadMemStats(&m)
-	before := m.Alloc
-	fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
+	// before := m.Alloc
+	// fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
 
 	startNode, stateToNodeMap, err := BuildGraph(proc)
 	if err != nil {
@@ -38,10 +40,10 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
 		panic(err)
 	}
 	runtime.ReadMemStats(&m)
-	after := m.Alloc
-	fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
-	fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
-	fmt.Printf("Graph build time = %v\n", time.Since(start))
+	// after := m.Alloc
+	// fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
+	// fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
+	// fmt.Printf("Graph build time = %v\n", time.Since(start))
 
 	return &PushdownSampler{
 		curNode:        startNode,
@@ -53,6 +55,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
 }
 
 // TODO: need to add resampling logic if the first sample was not good
+// greedy sample + backtrack?
 func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 	// fmt.Println(">>> sample:", s.curNode.State)
 	switch s.curNode.State {
@@ -60,7 +63,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 		return s.maskLogits(logits, s.curNode)
 
 	case StateInListEnd:
-		fmt.Println("in list end", s.braceStack)
+		// fmt.Println("in list end", s.braceStack)
 		// force finish if no braces left
 		if len(s.braceStack) == 0 {
 			s.curNode = NewPDANode(StateTerminate)
@@ -139,12 +142,12 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 }
 
 func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
-	fmt.Println("update state", s.curNode.State)
+	// fmt.Println("current state - updating", s.curNode.State)
 	mappedString, err := s.proc.Decode(tokenSlice)
 	if err != nil {
 		return err
 	}
-	fmt.Println("mappedString", mappedString)
+	// fmt.Println("mappedString", mappedString)
 
 	// TODO: should force closing for all braces - not doing square yet
 	for _, r := range mappedString {
@@ -183,23 +186,25 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 
 	for _, tokenID := range tokenSlice {
 		// transition to the next node
-		nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
+		nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
 		if !ok {
 			return fmt.Errorf("invalid token: %q", mappedString)
 		}
 		// fmt.Println("transitioning to", nextNodeState)
 
 		// TODO: add a penalty for staying in the same state too long
-		if nextNodeState == s.curNode.State {
+		if nextNode.State == s.curNode.State {
 			s.stateCounter++
 		} else {
 			s.stateCounter = 0
 		}
-		s.curNode = s.stateToNodeMap[nextNodeState]
+		s.curNode = nextNode
+		// fmt.Println("updated curNode state", s.curNode.State)
 	}
 	return nil
 }
 
+// greedy sample + backtrack?
 func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
 	// TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
 	// Should be possible through bitwise ops as well

+ 125 - 34
sample/structured_outputs.go

@@ -1,54 +1,145 @@
 package sample
 
-import "github.com/ollama/ollama/model"
+import (
+	"fmt"
+	"runtime"
+	"time"
 
-type StructuredOutput struct {
-	schema         *Schema
-	stateToNodeMap map[JSONState]*PDANode
+	"github.com/ollama/ollama/model"
+)
+
+type SOSampler struct {
+	schema       *Schema
+	propIdx      int
+	propStateMap map[string]*PDANode
+	pdaSampler   *PushdownSampler
 }
 
-func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *StructuredOutput {
-	_, stateToNodeMap, err := BuildGraph(proc)
-	if err != nil {
-		panic(err)
+func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
+	pdaSampler := NewPushdownSampler(proc)
+
+	so := &SOSampler{
+		schema:       schema,
+		propIdx:      -1,
+		propStateMap: make(map[string]*PDANode),
+		pdaSampler:   pdaSampler,
 	}
 
-	return &StructuredOutput{
-		schema:         schema,
-		stateToNodeMap: stateToNodeMap,
+	so.schemaToGraph()
+
+	vocab := proc.GetVocabulary()
+	decodedToks := make([]string, len(vocab.Values))
+	for i := range vocab.Values {
+		token, err := proc.Decode([]int32{int32(i)})
+		if err != nil {
+			return nil, err
+		}
+		decodedToks[i] = token
 	}
-}
 
-func (so *StructuredOutput) schemaToGraph(proc model.TextProcessor) *PDANode {
+	fmt.Println("--------------------------------")
+	fmt.Println("SOSampler")
+	fmt.Println("--------------------------------")
+	// Benchmark this section
+	start := time.Now()
+	var m runtime.MemStats
+	runtime.ReadMemStats(&m)
+	before := m.Alloc
 
-	schemaType := so.schema.EffectiveType()
+	// TODO: still messed up
+	for _, node := range so.propStateMap {
+		// propName -> node
+		curState := node.State
+		fromNode := node
+		CreateMask(fromNode, proc, decodedToks, vocab)
+		for curState == StateInStructuredKey {
+			// there is only one edge
+			for r, toNode := range fromNode.TransitionEdges {
+				// fmt.Println("rune", r, "edge", toNode.State)
+				CreateMask(toNode, proc, decodedToks, vocab)
+				fmt.Printf("created mask for %c\n", r)
+				curState = toNode.State
+				fmt.Println("next state", curState)
+				// TODO: theres an extra gen for " right now
+				fromNode = toNode
+			}
+		}
+	}
+
+	runtime.ReadMemStats(&m)
+	after := m.Alloc
+	fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
+	fmt.Printf("Mask creation time = %v\n", time.Since(start))
+	fmt.Println("--------------------------------")
+
+	return so, nil
+}
+
+func (s *SOSampler) schemaToGraph() {
+	schemaType := s.schema.EffectiveType()
 	switch schemaType {
 	case "object":
-		// each prop is a key
+		// TODO: see if we need to connect these to the JSON graph
 		// prevState := StateInObjectKey
-		for _, prop := range so.schema.Properties {
+		// prevNode := so.stateToNodeMap[prevState]
+
+		// each prop is a key
+		for _, prop := range s.schema.Properties {
 			// name of key
 			name := prop.Name
-			prevState := StateInObjectKey
-			for i, r := range name {
-				newState := JSONState(int(StateInObjectKey) + i + 1) // Create new unique state for each rune
-
-				// Create new node for this state if it doesn't exist
-				if _, exists := so.stateToNodeMap[newState]; !exists {
-					so.stateToNodeMap[newState] = &PDANode{
-						State:             newState,
-						TransitionEdges:   make(map[rune]*PDANode),
-						MaskTokenIDToNode: make(map[int32]JSONState),
-					}
-				}
+			// prevState := StateInObjectKey
+			keyNode := &PDANode{
+				State:             StateInStructuredKey, // this is unchanging, will impact sampling
+				TransitionEdges:   make(map[rune]*PDANode),
+				MaskTokenIDToNode: make(map[int32]*PDANode),
+			}
 
-				// Connect previous state to this state via the rune
-				so.stateToNodeMap[prevState].TransitionEdges[r] = so.stateToNodeMap[newState]
-				prevState = newState
+			prevNode := keyNode
+			for _, r := range name {
+				runeNode := &PDANode{
+					State:             StateInStructuredKey, // this is unchanging, will impact sampling
+					TransitionEdges:   make(map[rune]*PDANode),
+					MaskTokenIDToNode: make(map[int32]*PDANode),
+				}
+				fmt.Println("runeNode created", runeNode.State)
+				fmt.Printf("runeNode created %c\n", r)
+				// since alloc on heap connections wil still map
+				prevNode.TransitionEdges[r] = runeNode
+				prevNode = runeNode
 			}
-			// type of value
-			// propType := prop.Type
+			// point to end of object key node after all chars are done
+			prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
+			// points to start of the key
+			s.propStateMap[name] = keyNode
+			fmt.Println("name", name, "keyNode", keyNode.State)
 		}
 	}
-	return nil
+}
+
+func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
+	switch s.pdaSampler.curNode.State {
+	// doesnt account for multi rune case
+	case StateInObjectKey:
+		// fmt.Println("in object key - structured outputs")
+		// TODO: this tracking should probably be coming from a stack to track nested objects
+		// simple case
+		s.propIdx++
+		prop := s.schema.Properties[s.propIdx]
+		// fmt.Println("prop", prop.Name)
+		s.pdaSampler.curNode = s.propStateMap[prop.Name]
+		// fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
+		logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
+		if err != nil {
+			return nil, err
+		}
+		return logits, nil
+
+	default:
+		return s.pdaSampler.Sample(logits)
+	}
+
+}
+
+func (s *SOSampler) UpdateState(tokenSlice []int32) error {
+	return s.pdaSampler.UpdateState(tokenSlice)
 }