Browse Source

runner.go: Check for incomplete UTF-8 character

Generated text can contain a partial multi-byte Unicode character at
the end. Check for this and hold it over until the next token is
produced.
Jesse Gross 8 months ago
parent
commit
90d25d3b0a
1 changed files with 35 additions and 0 deletions
  1. 35 0
      llama/runner/runner.go

+ 35 - 0
llama/runner/runner.go

@@ -167,6 +167,36 @@ func (s *Server) shiftContext(seqIndex int) {
 	seq.nPast -= numDiscard
 	seq.nPast -= numDiscard
 }
 }
 
 
+func incompleteUnicode(token string) bool {
+	incomplete := false
+
+	// check if there is incomplete UTF-8 character at the end
+	for i := 1; i < 5 && i <= len(token); i++ {
+		c := token[len(token)-i]
+
+		if (c & 0xc0) == 0x80 {
+			// continuation byte: 10xxxxxx
+			continue
+		}
+
+		if (c & 0xe0) == 0xc0 {
+			// 2-byte character: 110xxxxx ...
+			incomplete = i < 2
+		} else if (c & 0xf0) == 0xe0 {
+			// 3-byte character: 1110xxxx ...
+			incomplete = i < 3
+		} else if (c & 0xf8) == 0xf0 {
+			// 4-byte character: 11110xxx ...
+			incomplete = i < 4
+		}
+
+		// else 1-byte character or invalid byte
+		break
+	}
+
+	return incomplete
+}
+
 func (s *Server) run(ctx context.Context) {
 func (s *Server) run(ctx context.Context) {
 	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
@@ -296,6 +326,11 @@ func (s *Server) run(ctx context.Context) {
 
 
 				pieces[i] = append(pieces[i], piece)
 				pieces[i] = append(pieces[i], piece)
 				sequence := strings.Join(pieces[i], "")
 				sequence := strings.Join(pieces[i], "")
+
+				if incompleteUnicode(sequence) {
+					continue
+				}
+
 				if ok, stop := findStop(sequence, seq.stop); ok {
 				if ok, stop := findStop(sequence, seq.stop); ok {
 					slog.Info("hit stop token", "stop", seq.stop)
 					slog.Info("hit stop token", "stop", seq.stop)