Browse Source

proper utf16 support

Michael Yang 11 months ago
parent
commit
66ab48772f
1 changed file with 58 additions and 32 deletions
  1. 58 32
      parser/parser.go

+ 58 - 32
parser/parser.go

@@ -3,12 +3,15 @@ package parser
 import (
 	"bufio"
 	"bytes"
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
+	"log/slog"
 	"strconv"
 	"strings"
-	"unicode"
+	"unicode/utf16"
+	"unicode/utf8"
 )
 
 type File struct {
@@ -69,31 +72,29 @@ func ParseFile(r io.Reader) (*File, error) {
 	var b bytes.Buffer
 	var role string
 
-	var lineCount int
-	var linePos int
-
-	var utf16 bool
-
 	var f File
 
 	br := bufio.NewReader(r)
-	for {
-		r, _, err := br.ReadRune()
-		if errors.Is(err, io.EOF) {
-			break
-		} else if err != nil {
-			return nil, err
-		}
 
-		// the utf16 byte order mark will be read as "unreadable" by ReadRune()
-		if isUnreadable(r) && lineCount == 0 && linePos == 0 {
-			utf16 = true
-			continue
-		}
+	var sc scannerDecoder = utf8ScannerDecoder{}
+	if bom, err := br.Peek(2); err != nil {
+		slog.Warn("error reading byte-order mark", "error", err)
+	} else if bytes.Equal(bom, []byte{0xFE, 0xFF}) {
+		sc = utf16ScannerDecoder{binary.LittleEndian}
+		//nolint:errcheck
+		br.Discard(2)
+	} else if bytes.Equal(bom, []byte{0xFF, 0xFE}) {
+		sc = utf16ScannerDecoder{binary.BigEndian}
+		//nolint:errcheck
+		br.Discard(2)
+	}
 
-		// skip the second byte if we're reading utf16
-		if utf16 && r == 0 {
-			continue
+	scanner := bufio.NewScanner(br)
+	scanner.Split(sc.ScanBytes)
+	for scanner.Scan() {
+		r, err := sc.DecodeRune(scanner.Bytes())
+		if err != nil {
+			return nil, err
 		}
 
 		next, r, err := parseRuneForState(r, curr)
@@ -103,13 +104,6 @@ func ParseFile(r io.Reader) (*File, error) {
 			return nil, err
 		}
 
-		if isNewline(r) {
-			lineCount++
-			linePos = 0
-		} else {
-			linePos++
-		}
-
 		// process the state transition, some transitions need to be intercepted and redirected
 		if next != curr {
 			switch curr {
@@ -309,10 +303,6 @@ func isNewline(r rune) bool {
 	return r == '\r' || r == '\n'
 }
 
-func isUnreadable(r rune) bool {
-	return r == unicode.ReplacementChar
-}
-
 // isValidMessageRole reports whether role is one of the accepted message
 // roles: "system", "user" or "assistant".
 func isValidMessageRole(role string) bool {
 	switch role {
 	case "system", "user", "assistant":
 		return true
 	default:
 		return false
 	}
 }
@@ -325,3 +315,39 @@ func isValidCommand(cmd string) bool {
 		return false
 	}
 }
+
+// scannerDecoder pairs a split function for bufio.Scanner with the decoder
+// that turns each token produced by that split function into a rune.
+type scannerDecoder interface {
+	// DecodeRune converts a single token yielded by ScanBytes into a rune.
+	DecodeRune(token []byte) (rune, error)
+
+	// ScanBytes has the signature of a bufio.SplitFunc and is what gets
+	// passed to Scanner.Split.
+	ScanBytes(data []byte, atEOF bool) (int, []byte, error)
+}
+
+// utf8ScannerDecoder is the scannerDecoder for UTF-8 input.
+type utf8ScannerDecoder struct{}
+
+// ScanBytes is a bufio.SplitFunc that yields one token per UTF-8 encoded
+// rune. Splitting on whole runes rather than single bytes keeps multi-byte
+// characters intact, so DecodeRune sees the full encoding instead of
+// decoding each byte of the character to U+FFFD.
+func (utf8ScannerDecoder) ScanBytes(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	// bufio.ScanRunes indexes data[0] unless atEOF is set, so handle the
+	// empty buffer here: nothing to yield, ask for more (or finish at EOF).
+	if len(data) == 0 {
+		return 0, nil, nil
+	}
+
+	return bufio.ScanRunes(data, atEOF)
+}
+
+// DecodeRune decodes the first UTF-8 encoded rune in data. Invalid or empty
+// input decodes to utf8.RuneError; the returned error is always nil.
+func (utf8ScannerDecoder) DecodeRune(data []byte) (rune, error) {
+	r, _ := utf8.DecodeRune(data)
+	return r, nil
+}
+
+// utf16ScannerDecoder is the scannerDecoder for UTF-16 input. The embedded
+// binary.ByteOrder determines how each 16-bit code unit is read and must be
+// set; the zero value has no byte order to decode with.
+type utf16ScannerDecoder struct {
+	binary.ByteOrder
+}
+
+// ScanBytes is a bufio.SplitFunc that yields one UTF-16 character per token:
+// two bytes for a single code unit, or four bytes when the buffer starts with
+// a valid surrogate pair. It asks for more input instead of splitting a code
+// unit or a possible pair across reads, and reports io.ErrUnexpectedEOF when
+// the input ends in the middle of a code unit.
+func (e utf16ScannerDecoder) ScanBytes(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	if len(data) < 2 {
+		if atEOF && len(data) > 0 {
+			return 0, nil, io.ErrUnexpectedEOF
+		}
+
+		// either more input is needed or the input ended cleanly
+		return 0, nil, nil
+	}
+
+	if first := rune(e.ByteOrder.Uint16(data)); utf16.IsSurrogate(first) {
+		if len(data) < 4 {
+			// the other half of the pair may not have been read yet
+			if !atEOF {
+				return 0, nil, nil
+			}
+		} else if second := rune(e.ByteOrder.Uint16(data[2:])); utf16.DecodeRune(first, second) != utf8.RuneError {
+			return scanBytesN(data, 4, atEOF)
+		}
+	}
+
+	// a single code unit; an unpaired surrogate is passed through on its own
+	return scanBytesN(data, 2, atEOF)
+}
+
+// DecodeRune decodes a token produced by ScanBytes: a 2-byte code unit or a
+// 4-byte surrogate pair. An unpaired surrogate decodes to U+FFFD. Any other
+// token length is reported as an error instead of panicking.
+func (e utf16ScannerDecoder) DecodeRune(data []byte) (rune, error) {
+	switch len(data) {
+	case 2:
+		return utf16.Decode([]uint16{e.ByteOrder.Uint16(data)})[0], nil
+	case 4:
+		hi := rune(e.ByteOrder.Uint16(data))
+		lo := rune(e.ByteOrder.Uint16(data[2:]))
+		return utf16.DecodeRune(hi, lo), nil
+	default:
+		return utf8.RuneError, fmt.Errorf("invalid utf16 token length %d", len(data))
+	}
+}
+
+// scanBytesN is the shared core of the fixed-width split functions: it
+// yields the first n bytes of data as a single token.
+//
+// While fewer than n bytes are buffered it returns (0, nil, nil) so that
+// bufio.Scanner reads more input. At end of input an empty buffer ends the
+// scan cleanly, and a partial token is reported as io.ErrUnexpectedEOF
+// rather than slicing past the end of data. A non-positive n is an error.
+func scanBytesN(data []byte, n int, atEOF bool) (int, []byte, error) {
+	if n <= 0 {
+		return 0, nil, fmt.Errorf("invalid token size %d", n)
+	}
+
+	if len(data) < n {
+		if atEOF && len(data) > 0 {
+			return 0, nil, io.ErrUnexpectedEOF
+		}
+
+		return 0, nil, nil
+	}
+
+	return n, data[:n], nil
+}