Browse Source

saved state

ParthSareen 3 months ago
parent
commit
b973dedb4b
7 changed files with 128 additions and 43 deletions
  1. 2 2
      go.mod
  2. 33 9
      model/cmd/main.go
  3. 5 4
      model/process_text.go
  4. 30 22
      sample/pushdown_automata.go
  5. 11 2
      sample/pushdown_runner.go
  6. 3 2
      sample/sample.go
  7. 44 2
      sample/structured_outputs.go

+ 2 - 2
go.mod

@@ -24,8 +24,8 @@ require (
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
 	golang.org/x/image v0.22.0
-	gonum.org/v1/gonum v0.15.0
 	golang.org/x/tools v0.28.0
+	gonum.org/v1/gonum v0.15.0
 )
 
 require (
@@ -72,7 +72,7 @@ require (
 	golang.org/x/arch v0.8.0 // indirect
 	golang.org/x/crypto v0.31.0
 	golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
-	golang.org/x/net v0.25.0 // indirect
+	golang.org/x/net v0.32.0 // indirect
 	golang.org/x/sys v0.28.0
 	golang.org/x/term v0.27.0
 	golang.org/x/text v0.21.0

+ 33 - 9
model/cmd/main.go

@@ -10,6 +10,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"time"
 
 	"github.com/ollama/ollama/cache"
 	"github.com/ollama/ollama/ml"
@@ -27,6 +28,7 @@ var args struct {
 }
 
 func temp() error {
+	start := time.Now()
 	flag.IntVar(&args.n, "n", 10, "number of samples")
 	flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
 	flag.StringVar(&args.image, "image", "", "path to image file")
@@ -104,9 +106,11 @@ func temp() error {
 		}
 	}
 
-	pdaSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
-	var stringBuffer string
+	pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
+
 	var offset int
+	var stringBuffer string
+	var firstTokenTime time.Duration
 	for range args.n {
 		logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
 		if err != nil {
@@ -118,15 +122,21 @@ func temp() error {
 		for i, f32 := range f32s {
 			f64s[i] = float64(f32)
 		}
+		sampleTime := time.Now()
+		samplers := []sample.Sampler{
+			pushdownSampler,
+			// sample.Weighed(),
+			// sample.TopP(0.9),
+			// sample.Weighed(),
+			sample.Greedy(),
+		}
 
-		// do sampling
-		// []ints back
-		// ints map to sampled logits
-		f64s, err = sample.Sample(f64s, pdaSampler, sample.Greedy())
-
+		f64s, err = sample.Sample(f64s, samplers...)
 		if err != nil {
 			return err
 		}
+		finishTime := time.Now()
+		fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
 
 		var outputIDs []int32
 		for _, f64 := range f64s {
@@ -134,7 +144,6 @@ func temp() error {
 				outputIDs = append(outputIDs, int32(f64))
 			}
 		}
-		pdaSampler.UpdateState(outputIDs)
 
 		if len(outputIDs) == 0 {
 			break
@@ -147,14 +156,29 @@ func temp() error {
 			return err
 		}
 
-		// fmt.Print(s)
+		if firstTokenTime == 0 {
+			firstTokenTime = time.Since(start)
+			fmt.Printf("Time to first token: %vms\n", firstTokenTime.Milliseconds())
+		}
+
+		// fmt.Printf("--- token: %q\n", s)
+		// fmt.Printf("--- outputIDs: %v\n", outputIDs)
 		stringBuffer += s
 		fmt.Println("--- stringBuffer", stringBuffer)
+
+		err = pushdownSampler.UpdateState(outputIDs)
+		if err != nil {
+			return err
+		}
+
 		inputIDs = append(inputIDs, outputIDs...)
 		if args.cache {
 			offset = len(inputIDs) - 1
 		}
 	}
+	fmt.Println("\n------ Output: ------")
+	fmt.Println(stringBuffer)
+	fmt.Println("--------------------")
 
 	return nil
 }

+ 5 - 4
model/process_text.go

@@ -21,6 +21,7 @@ type TextProcessor interface {
 	Encode(string) ([]int32, error)
 	Decode([]int32) (string, error)
 	Is(uint32, Special) bool
+
 	GetVocabulary() *Vocabulary
 }
 
@@ -99,16 +100,16 @@ func (v *Vocabulary) Merge(left, right string) int {
 	return -1
 }
 
+func (v *Vocabulary) GetVocabulary() *Vocabulary {
+	return v
+}
+
 type BytePairEncoding struct {
 	Pretokenizer string
 
 	*Vocabulary
 }
 
-func (bpe BytePairEncoding) GetVocabulary() *Vocabulary {
-	return bpe.Vocabulary
-}
-
 func (bpe BytePairEncoding) split(s string) ([]string, error) {
 	re, err := regexp2.Compile(bpe.Pretokenizer, regexp2.Unicode|regexp2.RE2)
 	if err != nil {

+ 30 - 22
sample/pushdown_automata.go

@@ -44,8 +44,6 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	// consider adding a node to just point to values, could be good to compute that
 	// mask rather than many different nodes
 
-	// Connect nodes
-	// TODO: if all are single tokens then this can just be connected instead of defining the token
 	stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
 
@@ -161,6 +159,7 @@ func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
 	node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
 }
 
+// TODO: tough life fr. plz fix.
 func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
 
 	vocab := proc.GetVocabulary()
@@ -176,33 +175,42 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
 
 	var err error
 	for _, node := range stateToNodeMap {
-		for i := range vocab.Values {
-			token := decodedToks[i]
-			// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
-			if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
-				continue
-			}
-			valid := true
-			curNode := node
-			consumedSpecialRunes := make(map[rune]bool)
-			for _, r := range token {
-				valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
-				if err != nil {
-					return err
-				}
-				if !valid {
-					break
-				}
+		err = createMask(node, proc, decodedToks, vocab)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
+	for i := range vocab.Values {
+		token := decodedToks[i]
+		// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
+		if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
+			continue
+		}
+		valid := true
+		curNode := node
+		consumedSpecialRunes := make(map[rune]bool)
+		var err error
+		for _, r := range token {
+			valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
+			if err != nil {
+				return err
 			}
-			if valid {
-				node.MaskTokenIDToNode[int32(i)] = curNode.State
+			if !valid {
+				break
 			}
 		}
+		if valid {
+			node.MaskTokenIDToNode[int32(i)] = curNode.State
+		}
 	}
 	return nil
 }
 
-// garbage interface plz fix
+// TODO: garbage interface plz fix
 func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
 	if consumedSpecialRunes[r] {
 		return false, nil, nil

+ 11 - 2
sample/pushdown_runner.go

@@ -52,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
 	}
 }
 
+// TODO: need to add resampling logic if the first sample was not good
 func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 	// fmt.Println(">>> sample:", s.curNode.State)
 	switch s.curNode.State {
@@ -156,8 +157,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 			// fmt.Println("pushing [ brace stack", r)
 		}
 		if r == rune('}') {
+			if len(s.braceStack) == 0 {
+				return fmt.Errorf("stack is empty, extra closing brace %c", r)
+			}
 			top := s.braceStack[len(s.braceStack)-1]
-			if len(s.braceStack) == 0 || top != rune('{') {
+			if top != rune('{') {
 				return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
 			}
 			s.braceStack = s.braceStack[:len(s.braceStack)-1]
@@ -165,8 +169,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 		}
 
 		if r == rune(']') {
+			if len(s.braceStack) == 0 {
+				return fmt.Errorf("stack is empty, extra closing brace %c", r)
+			}
 			top := s.braceStack[len(s.braceStack)-1]
-			if len(s.braceStack) == 0 || top != rune('[') {
+			if top != rune('[') {
 				return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
 			}
 			s.braceStack = s.braceStack[:len(s.braceStack)-1]
@@ -194,6 +201,8 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 }
 
 func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
+	// TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
+	// Should be possible through bitwise ops as well
 	for i := range logits {
 		_, exists := node.MaskTokenIDToNode[int32(i)]
 		if !exists {

+ 3 - 2
sample/sample.go

@@ -165,11 +165,12 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
 	if len(logitsCopy) == 0 {
 		return nil, errors.New("no valid tokens found")
 	}
-	logitsCopy, err := computeSoftmax(logitsCopy)
+
+	softmax, err := computeSoftmax(logitsCopy)
 	if err != nil {
 		return nil, err
 	}
-	w := sampleuv.NewWeighted(logitsCopy, nil)
+	w := sampleuv.NewWeighted(softmax, nil)
 	if v, ok := w.Take(); ok {
 		// returns the token ID
 		return []float64{float64(indices[v])}, nil

+ 44 - 2
sample/structured_outputs.go

@@ -3,10 +3,52 @@ package sample
 import "github.com/ollama/ollama/model"
 
 type StructuredOutput struct {
-	schema *Schema
+	schema         *Schema
+	stateToNodeMap map[JSONState]*PDANode
 }
 
-func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode {
+func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *StructuredOutput {
+	_, stateToNodeMap, err := BuildGraph(proc)
+	if err != nil {
+		panic(err)
+	}
 
+	return &StructuredOutput{
+		schema:         schema,
+		stateToNodeMap: stateToNodeMap,
+	}
+}
+
+func (so *StructuredOutput) schemaToGraph(proc model.TextProcessor) *PDANode {
+
+	schemaType := so.schema.EffectiveType()
+	switch schemaType {
+	case "object":
+		// each prop is a key
+		// prevState := StateInObjectKey
+		for _, prop := range so.schema.Properties {
+			// name of key
+			name := prop.Name
+			prevState := StateInObjectKey
+			for i, r := range name {
+				newState := JSONState(int(StateInObjectKey) + i + 1) // Create new unique state for each rune
+
+				// Create new node for this state if it doesn't exist
+				if _, exists := so.stateToNodeMap[newState]; !exists {
+					so.stateToNodeMap[newState] = &PDANode{
+						State:             newState,
+						TransitionEdges:   make(map[rune]*PDANode),
+						MaskTokenIDToNode: make(map[int32]JSONState),
+					}
+				}
+
+				// Connect previous state to this state via the rune
+				so.stateToNodeMap[prevState].TransitionEdges[r] = so.stateToNodeMap[newState]
+				prevState = newState
+			}
+			// type of value
+			// propType := prop.Type
+		}
+	}
 	return nil
 }