Selaa lähdekoodia

Move the parser back + handle utf16 files (#4533)

Patrick Devine 11 kuukautta sitten
vanhempi
commit
ccdf0b2a44
6 muutettua tiedostoa jossa 84 lisäystä ja 17 poistoa
  1. 2 1
      cmd/cmd.go
  2. 29 1
      parser/parser.go
  3. 37 1
      parser/parser_test.go
  4. 12 11
      server/images.go
  5. 2 1
      server/routes.go
  6. 2 2
      server/routes_test.go

+ 2 - 1
cmd/cmd.go

@@ -35,6 +35,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/server"
 	"github.com/ollama/ollama/types/errtypes"
@@ -63,7 +64,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	}
 	defer f.Close()
 
-	modelfile, err := model.ParseFile(f)
+	modelfile, err := parser.ParseFile(f)
 	if err != nil {
 		return err
 	}

+ 29 - 1
types/model/file.go → parser/parser.go

@@ -1,4 +1,4 @@
-package model
+package parser
 
 import (
 	"bufio"
@@ -8,6 +8,7 @@ import (
 	"io"
 	"strconv"
 	"strings"
+	"unicode"
 )
 
 type File struct {
@@ -68,6 +69,11 @@ func ParseFile(r io.Reader) (*File, error) {
 	var b bytes.Buffer
 	var role string
 
+	var lineCount int
+	var linePos int
+
+	var utf16 bool
+
 	var f File
 
 	br := bufio.NewReader(r)
@@ -79,6 +85,17 @@ func ParseFile(r io.Reader) (*File, error) {
 			return nil, err
 		}
 
+		// the utf16 byte order mark will be read as "unreadable" by ReadRune()
+		if isUnreadable(r) && lineCount == 0 && linePos == 0 {
+			utf16 = true
+			continue
+		}
+
+		// skip the second byte if we're reading utf16
+		if utf16 && r == 0 {
+			continue
+		}
+
 		next, r, err := parseRuneForState(r, curr)
 		if errors.Is(err, io.ErrUnexpectedEOF) {
 			return nil, fmt.Errorf("%w: %s", err, b.String())
@@ -86,6 +103,13 @@ func ParseFile(r io.Reader) (*File, error) {
 			return nil, err
 		}
 
+		if isNewline(r) {
+			lineCount++
+			linePos = 0
+		} else {
+			linePos++
+		}
+
 		// process the state transition, some transitions need to be intercepted and redirected
 		if next != curr {
 			switch curr {
@@ -285,6 +309,10 @@ func isNewline(r rune) bool {
 	return r == '\r' || r == '\n'
 }
 
+func isUnreadable(r rune) bool {
+	return r == unicode.ReplacementChar
+}
+
 func isValidMessageRole(role string) bool {
 	return role == "system" || role == "user" || role == "assistant"
 }

+ 37 - 1
types/model/file_test.go → parser/parser_test.go

@@ -1,11 +1,13 @@
-package model
+package parser
 
 import (
 	"bytes"
+	"encoding/binary"
 	"fmt"
 	"io"
 	"strings"
 	"testing"
+	"unicode/utf16"
 
 	"github.com/stretchr/testify/assert"
 )
@@ -509,3 +511,37 @@ SYSTEM ""
 	}
 
 }
+
+func TestParseFileUTF16ParseFile(t *testing.T) {
+	data := `FROM bob
+PARAMETER param1 1
+PARAMETER param2 4096
+SYSTEM You are a utf16 file.
+`
+	// simulate a utf16 le file
+	utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
+	buf := new(bytes.Buffer)
+	err := binary.Write(buf, binary.LittleEndian, utf16File)
+	assert.NoError(t, err)
+
+	actual, err := ParseFile(buf)
+	assert.NoError(t, err)
+
+	expected := []Command{
+		{Name: "model", Args: "bob"},
+		{Name: "param1", Args: "1"},
+		{Name: "param2", Args: "4096"},
+		{Name: "system", Args: "You are a utf16 file."},
+	}
+
+	assert.Equal(t, expected, actual.Commands)
+
+	// simulate a utf16 be file
+	buf = new(bytes.Buffer)
+	err = binary.Write(buf, binary.BigEndian, utf16File)
+	assert.NoError(t, err)
+
+	actual, err = ParseFile(buf)
+	assert.NoError(t, err)
+	assert.Equal(t, expected, actual.Commands)
+}

+ 12 - 11
server/images.go

@@ -27,6 +27,7 @@ import (
 	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/server/envconfig"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
@@ -61,36 +62,36 @@ func (m *Model) IsEmbedding() bool {
 }
 
 func (m *Model) String() string {
-	var modelfile model.File
+	var modelfile parser.File
 
-	modelfile.Commands = append(modelfile.Commands, model.Command{
+	modelfile.Commands = append(modelfile.Commands, parser.Command{
 		Name: "model",
 		Args: m.ModelPath,
 	})
 
 	for _, adapter := range m.AdapterPaths {
-		modelfile.Commands = append(modelfile.Commands, model.Command{
+		modelfile.Commands = append(modelfile.Commands, parser.Command{
 			Name: "adapter",
 			Args: adapter,
 		})
 	}
 
 	for _, projector := range m.ProjectorPaths {
-		modelfile.Commands = append(modelfile.Commands, model.Command{
+		modelfile.Commands = append(modelfile.Commands, parser.Command{
 			Name: "model",
 			Args: projector,
 		})
 	}
 
 	if m.Template != "" {
-		modelfile.Commands = append(modelfile.Commands, model.Command{
+		modelfile.Commands = append(modelfile.Commands, parser.Command{
 			Name: "template",
 			Args: m.Template,
 		})
 	}
 
 	if m.System != "" {
-		modelfile.Commands = append(modelfile.Commands, model.Command{
+		modelfile.Commands = append(modelfile.Commands, parser.Command{
 			Name: "system",
 			Args: m.System,
 		})
@@ -100,13 +101,13 @@ func (m *Model) String() string {
 		switch v := v.(type) {
 		case []any:
 			for _, s := range v {
-				modelfile.Commands = append(modelfile.Commands, model.Command{
+				modelfile.Commands = append(modelfile.Commands, parser.Command{
 					Name: k,
 					Args: fmt.Sprintf("%v", s),
 				})
 			}
 		default:
-			modelfile.Commands = append(modelfile.Commands, model.Command{
+			modelfile.Commands = append(modelfile.Commands, parser.Command{
 				Name: k,
 				Args: fmt.Sprintf("%v", v),
 			})
@@ -114,14 +115,14 @@ func (m *Model) String() string {
 	}
 
 	for _, license := range m.License {
-		modelfile.Commands = append(modelfile.Commands, model.Command{
+		modelfile.Commands = append(modelfile.Commands, parser.Command{
 			Name: "license",
 			Args: license,
 		})
 	}
 
 	for _, msg := range m.Messages {
-		modelfile.Commands = append(modelfile.Commands, model.Command{
+		modelfile.Commands = append(modelfile.Commands, parser.Command{
 			Name: "message",
 			Args: fmt.Sprintf("%s %s", msg.Role, msg.Content),
 		})
@@ -314,7 +315,7 @@ func realpath(rel, from string) string {
 	return abspath
 }
 
-func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
+func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) {
 	config := ConfigV2{
 		OS:           "linux",
 		Architecture: "amd64",

+ 2 - 1
server/routes.go

@@ -29,6 +29,7 @@ import (
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
+	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/server/envconfig"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
@@ -539,7 +540,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 		r = f
 	}
 
-	modelfile, err := model.ParseFile(r)
+	modelfile, err := parser.ParseFile(r)
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return

+ 2 - 2
server/routes_test.go

@@ -17,7 +17,7 @@ import (
 	"github.com/stretchr/testify/assert"
 
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/types/model"
+	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/version"
 )
 
@@ -56,7 +56,7 @@ func Test_Routes(t *testing.T) {
 		fname := createTestFile(t, "ollama-model")
 
 		r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
-		modelfile, err := model.ParseFile(r)
+		modelfile, err := parser.ParseFile(r)
 		assert.Nil(t, err)
 		fn := func(resp api.ProgressResponse) {
 			t.Logf("Status: %s", resp.Status)