浏览代码

Merge pull request #5031 from ollama/mxyng/fix-multibyte-utf16

fix: multibyte utf16
Michael Yang 10 月之前
父节点
当前提交
15a687ae4b
共有 2 个文件被更改,包括 54 次插入45 次删除
  1. 6 29
      parser/parser.go
  2. 48 16
      parser/parser_test.go

+ 6 - 29
parser/parser.go

@@ -8,7 +8,9 @@ import (
 	"io"
 	"strconv"
 	"strings"
-	"unicode"
+
+	"golang.org/x/text/encoding/unicode"
+	"golang.org/x/text/transform"
 )
 
 type File struct {
@@ -69,14 +71,11 @@ func ParseFile(r io.Reader) (*File, error) {
 	var b bytes.Buffer
 	var role string
 
-	var lineCount int
-	var linePos int
-
-	var utf16 bool
-
 	var f File
 
-	br := bufio.NewReader(r)
+	tr := unicode.BOMOverride(unicode.UTF8.NewDecoder())
+	br := bufio.NewReader(transform.NewReader(r, tr))
+
 	for {
 		r, _, err := br.ReadRune()
 		if errors.Is(err, io.EOF) {
@@ -85,17 +84,6 @@ func ParseFile(r io.Reader) (*File, error) {
 			return nil, err
 		}
 
-		// the utf16 byte order mark will be read as "unreadable" by ReadRune()
-		if isUnreadable(r) && lineCount == 0 && linePos == 0 {
-			utf16 = true
-			continue
-		}
-
-		// skip the second byte if we're reading utf16
-		if utf16 && r == 0 {
-			continue
-		}
-
 		next, r, err := parseRuneForState(r, curr)
 		if errors.Is(err, io.ErrUnexpectedEOF) {
 			return nil, fmt.Errorf("%w: %s", err, b.String())
@@ -103,13 +91,6 @@ func ParseFile(r io.Reader) (*File, error) {
 			return nil, err
 		}
 
-		if isNewline(r) {
-			lineCount++
-			linePos = 0
-		} else {
-			linePos++
-		}
-
 		// process the state transition, some transitions need to be intercepted and redirected
 		if next != curr {
 			switch curr {
@@ -309,10 +290,6 @@ func isNewline(r rune) bool {
 	return r == '\r' || r == '\n'
 }
 
-func isUnreadable(r rune) bool {
-	return r == unicode.ReplacementChar
-}
-
 func isValidMessageRole(role string) bool {
 	return role == "system" || role == "user" || role == "assistant"
 }

+ 48 - 16
parser/parser_test.go

@@ -11,6 +11,8 @@ import (
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"golang.org/x/text/encoding"
+	"golang.org/x/text/encoding/unicode"
 )
 
 func TestParseFileFile(t *testing.T) {
@@ -517,14 +519,6 @@ PARAMETER param1 1
 PARAMETER param2 4096
 SYSTEM You are a utf16 file.
 `
-	// simulate a utf16 le file
-	utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
-	buf := new(bytes.Buffer)
-	err := binary.Write(buf, binary.LittleEndian, utf16File)
-	require.NoError(t, err)
-
-	actual, err := ParseFile(buf)
-	require.NoError(t, err)
 
 	expected := []Command{
 		{Name: "model", Args: "bob"},
@@ -533,14 +527,52 @@ SYSTEM You are a utf16 file.
 		{Name: "system", Args: "You are a utf16 file."},
 	}
 
-	assert.Equal(t, expected, actual.Commands)
+	t.Run("le", func(t *testing.T) {
+		var b bytes.Buffer
+		require.NoError(t, binary.Write(&b, binary.LittleEndian, []byte{0xff, 0xfe}))
+		require.NoError(t, binary.Write(&b, binary.LittleEndian, utf16.Encode([]rune(data))))
 
-	// simulate a utf16 be file
-	buf = new(bytes.Buffer)
-	err = binary.Write(buf, binary.BigEndian, utf16File)
-	require.NoError(t, err)
+		actual, err := ParseFile(&b)
+		require.NoError(t, err)
 
-	actual, err = ParseFile(buf)
-	require.NoError(t, err)
-	assert.Equal(t, expected, actual.Commands)
+		assert.Equal(t, expected, actual.Commands)
+	})
+
+	t.Run("be", func(t *testing.T) {
+		var b bytes.Buffer
+		require.NoError(t, binary.Write(&b, binary.BigEndian, []byte{0xfe, 0xff}))
+		require.NoError(t, binary.Write(&b, binary.BigEndian, utf16.Encode([]rune(data))))
+
+		actual, err := ParseFile(&b)
+		require.NoError(t, err)
+		assert.Equal(t, expected, actual.Commands)
+	})
+}
+
+func TestParseMultiByte(t *testing.T) {
+	input := `FROM test
+	SYSTEM 你好👋`
+
+	expect := []Command{
+		{Name: "model", Args: "test"},
+		{Name: "system", Args: "你好👋"},
+	}
+
+	encodings := []encoding.Encoding{
+		unicode.UTF8,
+		unicode.UTF16(unicode.LittleEndian, unicode.UseBOM),
+		unicode.UTF16(unicode.BigEndian, unicode.UseBOM),
+	}
+
+	for _, encoding := range encodings {
+		t.Run(fmt.Sprintf("%s", encoding), func(t *testing.T) {
+			s, err := encoding.NewEncoder().String(input)
+			require.NoError(t, err)
+
+			actual, err := ParseFile(strings.NewReader(s))
+			require.NoError(t, err)
+
+			assert.Equal(t, expect, actual.Commands)
+		})
+	}
 }