Pārlūkot izejas kodu

truncate stop properly

jmorganca 11 mēneši atpakaļ
vecāks
revīzija
72ff94efe0
2 mainīti faili ar 102 papildinājumiem un 25 dzēšanām
  1. 53 25
      llama/runner/runner.go
  2. 49 0
      llama/runner/runner_test.go

+ 53 - 25
llama/runner/runner.go

@@ -94,7 +94,7 @@ func (s *Server) allNil() bool {
 	return true
 }
 
-func contains(sequence string, stops []string) (bool, string) {
+func findStop(sequence string, stops []string) (bool, string) {
 	for _, stop := range stops {
 		if strings.Contains(sequence, stop) {
 			return true, stop
@@ -104,9 +104,9 @@ func contains(sequence string, stops []string) (bool, string) {
 	return false, ""
 }
 
-func overlap(sequence string, stops []string) bool {
+func containsStopSuffix(sequence string, stops []string) bool {
 	for _, stop := range stops {
-		for i := 1; i < len(stop); i++ {
+		for i := 1; i <= len(stop); i++ {
 			if strings.HasSuffix(sequence, stop[:i]) {
 				return true
 			}
@@ -116,13 +116,50 @@ func overlap(sequence string, stops []string) bool {
 	return false
 }
 
// truncateStop removes the first occurrence of stop, and everything after
// it, from the concatenation of pieces. The surviving text is returned
// re-split along the original piece boundaries, so the last returned piece
// may be a shortened prefix of its original and trailing pieces may be
// dropped entirely.
//
// pieces is returned unchanged when stop is empty or does not occur in the
// joined text. A nil slice is returned when the match starts at offset 0,
// because no text survives the truncation.
func truncateStop(pieces []string, stop string) []string {
	// An empty stop matches at offset 0 of any string, which would
	// silently discard every piece; treat it as "nothing to remove".
	if stop == "" {
		return pieces
	}

	index := strings.Index(strings.Join(pieces, ""), stop)
	if index == -1 {
		return pieces
	}

	// index is the number of leading bytes to keep. Spend that budget
	// piece by piece so the output preserves the original boundaries.
	var result []string
	remaining := index
	for _, piece := range pieces {
		if remaining <= 0 {
			break
		}

		// The stop begins inside this piece: keep only the part before it.
		if len(piece) > remaining {
			piece = piece[:remaining]
		}

		result = append(result, piece)
		remaining -= len(piece)
	}

	return result
}
+
 func (s *Server) run(ctx context.Context) {
 	batch := llama.NewBatch(512, 0, s.parallel)
 	defer batch.Free()
 
 	// build up stop sequences as we recognize them
 	// TODO (jmorganca): simplify this
-	sofar := make([][]string, s.parallel)
+	pieces := make([][]string, s.parallel)
 
 	for {
 		select {
@@ -214,50 +251,41 @@ func (s *Server) run(ctx context.Context) {
 
 					close(seq.responses)
 					seq.samplingCtx.Free()
-					sofar[i] = []string{}
+					pieces[i] = []string{}
 					s.seqs[i] = nil
 					continue
 				}
 
 				seq.tokens = []int{token}
 
-				// recognize stop sequences
-				// TODO (jmorganca): add tests around this
-				// TODO (jmorganca): send back parital piece
-
-				sequence := strings.Join(append(sofar[i], piece), "")
-				if ok, stop := contains(sequence, seq.stop); ok {
+				pieces[i] = append(pieces[i], piece)
+				sequence := strings.Join(pieces[i], "")
+				if ok, stop := findStop(sequence, seq.stop); ok {
 					slog.Info("hit stop token", "stop", seq.stop)
-					for _, p := range sofar[i] {
+
+					truncated := truncateStop(pieces[i], stop)
+
+					for _, p := range truncated {
 						seq.responses <- p
 					}
 
-					piece, _, _ := strings.Cut(piece, stop)
-					seq.responses <- piece
-
 					s.lc.KvCacheSeqRm(i, 0, -1)
 					close(seq.responses)
 					seq.samplingCtx.Free()
-					sofar[i] = []string{}
+					pieces[i] = []string{}
 					s.seqs[i] = nil
 					continue
 				}
 
-				if overlap(sequence, seq.stop) {
-					slog.Info("overlap", "sequence", sequence)
-					// partial stop, don't send
+				if containsStopSuffix(sequence, seq.stop) {
 					continue
 				}
 
-				slog.Info("sending", "sofar", sofar[i])
-
-				sofar[i] = append(sofar[i], piece)
-
-				for _, p := range sofar[i] {
+				for _, p := range pieces[i] {
 					seq.responses <- p
 				}
 
-				sofar[i] = []string{}
+				pieces[i] = []string{}
 			}
 
 			batch.Clear()

+ 49 - 0
llama/runner/runner_test.go

@@ -0,0 +1,49 @@
+package main
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestTruncateStop(t *testing.T) {
+	tests := []struct {
+		name     string
+		pieces   []string
+		stop     string
+		expected []string
+	}{
+		{
+			name:     "Single word",
+			pieces:   []string{"hello", "world"},
+			stop:     "world",
+			expected: []string{"hello"},
+		},
+		{
+			name:     "Partial",
+			pieces:   []string{"hello", "wor"},
+			stop:     "or",
+			expected: []string{"hello", "w"},
+		},
+		{
+			name:     "Suffix",
+			pieces:   []string{"Hello", " there", "!"},
+			stop:     "!",
+			expected: []string{"Hello", " there"},
+		},
+		{
+			name:     "Middle",
+			pieces:   []string{"hello", " wor"},
+			stop:     "llo w",
+			expected: []string{"he"},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := truncateStop(tt.pieces, tt.stop)
+			if !reflect.DeepEqual(result, tt.expected) {
+				t.Errorf("truncateStop(%v, %s): have %v; want %v", tt.pieces, tt.stop, result, tt.expected)
+			}
+		})
+	}
+}