ParthSareen 1 month ago
parent
commit
4450f871db
2 changed files with 25 additions and 12 deletions
  1. 11 11
      runner/ollamarunner/runner.go
  2. 14 1
      sample/structured_python.go

+ 11 - 11
runner/ollamarunner/runner.go

@@ -582,16 +582,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	// 	return
 	// }
 	// jsonSampler = nil
-	// pythonSampler := sample.NewPythonSampler(s.model.(model.TextProcessor), nil)
-	// pythonSampler := &sample.PythonSampler{}
-	// functions := []sample.PythonFunction{
-	// 	{
-	// 		Name:  "add_two_strings",
-	// 		Args:  []string{"s1", "s2"},
-	// 		Types: []string{"string", "string"},
-	// 	},
-	// }
-	// pythonSampler.Init(functions, s.model.(model.TextProcessor))
+	pythonSampler := &sample.PythonSampler{}
+	functions := []sample.PythonFunction{
+		{
+			Name:  "add_two_strings",
+			Args:  []string{"s1", "s2"},
+			Types: []string{"string", "string"},
+		},
+	}
+	pythonSampler.Init(functions, s.model.(model.TextProcessor))
 	sampler := sample.NewSampler(
 		req.Options.Temperature,
 		req.Options.TopK,
@@ -600,7 +599,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		req.Options.Seed,
 		grammar,
 		nil,
-		nil,
+		pythonSampler,
+		// nil,
 	)
 
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{

+ 14 - 1
sample/structured_python.go

@@ -108,10 +108,13 @@ type PythonSampler struct {
 	proc         model.TextProcessor
 	decodedToks  []string
 	curNode      *Node
+	completed    int
+	functions    []PythonFunction
 }
 
 func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
 	s.proc = proc
+	s.functions = functions
 	decodedToks := make([]string, len(proc.Vocab().Values))
 	for i := range proc.Vocab().Values {
 		token, err := proc.Decode([]int32{int32(i)})
@@ -194,7 +197,7 @@ func (s *PythonSampler) BuildGraph() error {
 
 	// String end
 	s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
-	s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
+	// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
 	// Number
 	for _, r := range validNumberRunes {
 		s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
@@ -237,6 +240,16 @@ func (s *PythonSampler) UpdateState(token int32) error {
 	if !ok {
 		return fmt.Errorf("invalid token: %q", mappedString)
 	}
+
+	if mappedString == "\"" {
+		if s.curNode.State == PStateInStringEnd {
+			s.completed++
+		}
+		if s.completed == len(s.functions) {
+			s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
+			s.CreateMask(s.curNode)
+		}
+	}
 	s.curNode = nextNode
 	fmt.Println("curNode", s.curNode.State)
 	for r, node := range s.curNode.TransitionEdges {