Browse Source

first pass so working

ParthSareen 3 months ago
parent
commit
25edfa6fdb
4 changed files with 107 additions and 46 deletions
  1. 28 23
      model/cmd/main.go
  2. 7 3
      sample/pushdown_automata.go
  3. 16 11
      sample/pushdown_runner.go
  4. 56 9
      sample/structured_outputs.go

+ 28 - 23
model/cmd/main.go

@@ -106,8 +106,6 @@ func temp() error {
 		}
 	}
 
-	// pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
-
 	// simple schema
 	// This schema maps to JSON like:
 	// {
@@ -119,9 +117,12 @@ func temp() error {
 		Properties: []*sample.Schema{
 			{Name: "name", Type: "string"},
 			{Name: "age", Type: "integer"},
+			{Name: "is_student", Type: "boolean"},
+			// {Name: "is_student", Type: "boolean"},
 		},
 	}
 
+	// pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
 	pushdownSampler, err := sample.NewSOSampler(schema, m.(model.TextProcessor))
 	if err != nil {
 		return err
@@ -129,44 +130,47 @@ func temp() error {
 
 	var offset int
 	var stringBuffer string
-	var firstTokenTime time.Duration
+	var ttft time.Duration
 	var totalSamplingTime time.Duration
 	count := 0
 	for range args.n {
-		logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
+		logits, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
 		if err != nil {
 			return err
 		}
 
-		f32s := logit.Floats()
-		f64s := make([]float64, len(f32s))
-		for i, f32 := range f32s {
-			f64s[i] = float64(f32)
-		}
-		samplers := []sample.Sampler{
+		// f64s := make([]float64, len(f32s))
+		// for i, f32 := range f32s {
+		// 	f64s[i] = float64(f32)
+		// }
+		// samplers := []sample.Transform{
+		// pushdownSampler,
+		// sample.Weighed(),
+		// sample.TopP(0.9),
+		// sample.Weighed(),
+		// sample.Greedy(),
+		// }
+		transforms := []sample.Transform{
 			pushdownSampler,
-			// sample.Weighed(),
-			// sample.TopP(0.9),
-			// sample.Weighed(),
-			sample.Greedy(),
 		}
 
 		samplingStart := time.Now()
-		f64s, err = sample.Sample(f64s, samplers...)
+		sampler := sample.NewSampler(transforms, sample.Greedy())
+		sampledIdx, err := sampler.Sample(logits.Floats())
 		if err != nil {
 			return err
 		}
+
 		samplingTime := time.Since(samplingStart)
 		totalSamplingTime += samplingTime
 
-		// fmt.Println("sampling time", samplingTime)
+		fmt.Println("sampling time", samplingTime)
 		// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
 
 		var outputIDs []int32
-		for _, f64 := range f64s {
-			if !m.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) {
-				outputIDs = append(outputIDs, int32(f64))
-			}
+
+		if !m.(model.TextProcessor).Is(uint32(sampledIdx), model.SpecialEOS) {
+			outputIDs = append(outputIDs, int32(sampledIdx))
 		}
 
 		if len(outputIDs) == 0 {
@@ -180,9 +184,9 @@ func temp() error {
 			return err
 		}
 
-		if firstTokenTime == 0 {
-			firstTokenTime = time.Since(start)
-			fmt.Printf("Time to first token: %vms\n", firstTokenTime.Milliseconds())
+		if ttft == 0 {
+			ttft = time.Since(start)
+			fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds())
 		}
 
 		// fmt.Printf("--- token: %q\n", s)
@@ -196,6 +200,7 @@ func temp() error {
 			return err
 		}
 
+		// can do fun shifting stuff here if needed
 		inputIDs = append(inputIDs, outputIDs...)
 		if args.cache {
 			offset = len(inputIDs) - 1

+ 7 - 3
sample/pushdown_automata.go

@@ -54,6 +54,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	//new line
 	stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
 	stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
+	stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
 
 	stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
 
@@ -76,6 +77,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	stateToNodeMap[StateInSpace].TransitionEdges['['] = stateToNodeMap[StateInList]
 	stateToNodeMap[StateInSpace].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	addValueConnections(stateToNodeMap[StateInSpace], stateToNodeMap)
+	stateToNodeMap[StateInSpace].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
 
 	// Values
 	// string node
@@ -97,6 +99,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 		stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
 	}
 	addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
+	stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
 
 	// list node
 	stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
@@ -128,6 +131,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	for _, r := range validBoolRunes {
 		stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
 	}
+	stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
 	addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
 
 	stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
@@ -178,7 +182,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
 
 	var err error
 	for _, node := range stateToNodeMap {
-		err = CreateMask(node, proc, decodedToks, vocab)
+		err = CreateMask(node, proc, decodedToks)
 		if err != nil {
 			return err
 		}
@@ -186,8 +190,8 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
 	return nil
 }
 
-func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
-	for i := range vocab.Values {
+func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string) error {
+	for i := range decodedToks {
 		token := decodedToks[i]
 		// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
 		if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {

+ 16 - 11
sample/pushdown_runner.go

@@ -57,7 +57,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
 
 // TODO: need to add resampling logic if the first sample was not good
 // greedy sample + backtrack?
-func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
+func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
 	switch s.curNode.State {
 	case StateInString:
 		return s.maskLogits(logits, s.curNode)
@@ -70,7 +70,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 				if s.proc.Is(uint32(i), model.SpecialEOS) {
 					logits[i] = 1.0
 				} else {
-					logits[i] = math.NaN()
+					logits[i] = math.Inf(-1)
 				}
 			}
 			return logits, nil
@@ -90,7 +90,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 				if s.proc.Is(uint32(i), model.SpecialEOS) {
 					logits[i] = 1.0
 				} else {
-					logits[i] = math.NaN()
+					logits[i] = math.Inf(-1)
 				}
 			}
 			return logits, nil
@@ -123,7 +123,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 			if s.proc.Is(uint32(i), model.SpecialEOS) {
 				logits[i] = 1.0
 			} else {
-				logits[i] = math.NaN()
+				logits[i] = math.Inf(-1)
 			}
 		}
 		return logits, nil
@@ -199,15 +199,20 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 
 // greedy sample + backtrack?
 func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
-	// TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
-	// Should be possible through bitwise ops as well
-	for i := range logits {
-		_, exists := node.MaskTokenIDToNode[int32(i)]
-		if !exists {
-			logits[i] = math.NaN()
+	// Create a new slice with same length as logits, initialized to -Inf
+	maskedLogits := make([]float64, len(logits))
+	for i := range maskedLogits {
+		maskedLogits[i] = math.Inf(-1)
+	}
+
+	// Only update values for valid token IDs from the mask map
+	for tokenID := range node.MaskTokenIDToNode {
+		if int(tokenID) < len(logits) {
+			maskedLogits[tokenID] = logits[tokenID]
 		}
 	}
-	return logits, nil
+
+	return maskedLogits, nil
 }
 
 // TODO: add penalties for string \n stuff

+ 56 - 9
sample/structured_outputs.go

@@ -13,6 +13,7 @@ type SOSampler struct {
 	propIdx       int
 	propToNodeMap map[string]*PDANode
 	pdaSampler    *PushdownSampler
+	decodedToks   []string
 }
 
 func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
@@ -27,6 +28,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
 
 	so.schemaToGraph()
 
+	// This is probably slow
 	vocab := proc.GetVocabulary()
 	decodedToks := make([]string, len(vocab.Values))
 	for i := range vocab.Values {
@@ -36,6 +38,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
 		}
 		decodedToks[i] = token
 	}
+	so.decodedToks = decodedToks
 
 	fmt.Println("--------------------------------")
 	fmt.Println("SOSampler")
@@ -47,16 +50,19 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
 	before := m.Alloc
 
 	// TODO: still messed up
-	for _, node := range so.propToNodeMap {
+	// TODO: recursion use case
+	// key masks
+	for _, prop := range so.schema.Properties {
+		node := so.propToNodeMap[prop.Name]
 		// propName -> node
 		curState := node.State
 		fromNode := node
-		CreateMask(fromNode, proc, decodedToks, vocab)
+		CreateMask(fromNode, proc, decodedToks)
 		for curState == StateInStructuredKey {
 			// there is only one edge
 			for r, toNode := range fromNode.TransitionEdges {
 				// fmt.Println("rune", r, "edge", toNode.State)
-				CreateMask(toNode, proc, decodedToks, vocab)
+				CreateMask(toNode, proc, decodedToks)
 				fmt.Printf("created mask for %c\n", r)
 				curState = toNode.State
 				fmt.Println("next state", curState)
@@ -80,14 +86,11 @@ func (s *SOSampler) schemaToGraph() {
 	switch schemaType {
 	case "object":
 		// TODO: see if we need to connect these to the JSON graph
-		// prevState := StateInObjectKey
-		// prevNode := so.stateToNodeMap[prevState]
 
 		// each prop is a key
 		for _, prop := range s.schema.Properties {
 			// name of key
 			name := prop.Name
-			// prevState := StateInObjectKey
 			keyNode := &PDANode{
 				State:             StateInStructuredKey, // this is unchanging, will impact sampling
 				TransitionEdges:   make(map[rune]*PDANode),
@@ -116,10 +119,13 @@ func (s *SOSampler) schemaToGraph() {
 	}
 }
 
-func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
+func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
 	switch s.pdaSampler.curNode.State {
 	// doesn't account for the multi-rune case
 	case StateInObjectKey:
+		if s.propIdx > len(s.schema.Properties)-1 {
+			return nil, fmt.Errorf("propIdx out of bounds")
+		}
 		// fmt.Println("in object key - structured outputs")
 		// TODO: this tracking should probably be coming from a stack to track nested objects
 		// simple case
@@ -136,11 +142,52 @@ func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
 		return logits, nil
 
 	default:
-		return s.pdaSampler.Sample(logits)
+
+		// Will only happen for the last prop - can also be precomputed.
+		if s.propIdx == len(s.schema.Properties)-1 {
+			// TODO: if I increment propIdx then I know I'm in the last value as well
+			switch s.pdaSampler.curNode.State {
+			case StateInObjectEnd:
+				fmt.Println("<<<<< in obj end- generating mask for", s.pdaSampler.curNode.State)
+				s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDANode)
+				s.pdaSampler.curNode = NewPDANode(StateTerminate)
+				s.propIdx++
+
+			case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
+				fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
+				delete(s.pdaSampler.curNode.TransitionEdges, ',')
+				s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDANode)
+
+				CreateMask(s.pdaSampler.curNode, s.pdaSampler.proc, s.decodedToks)
+				s.propIdx++
+			}
+		}
+		return s.pdaSampler.Apply(logits)
 	}
 
 }
 
 func (s *SOSampler) UpdateState(tokenSlice []int32) error {
-	return s.pdaSampler.UpdateState(tokenSlice)
+	err := s.pdaSampler.UpdateState(tokenSlice)
+	if err != nil {
+		return err
+	}
+
+	switch s.pdaSampler.curNode.State {
+	case StateInObjectKey:
+		s.propIdx++
+		fmt.Println("propIdx", s.propIdx)
+		prop := s.schema.Properties[s.propIdx]
+		fmt.Println("prop", prop.Name)
+		s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
+		str, err := s.pdaSampler.proc.Decode(tokenSlice)
+		if err != nil {
+			return err
+		}
+		fmt.Println("str", str)
+
+		return nil
+	default:
+		return nil
+	}
 }