소스 검색

refactor modelfile parser

Michael Yang 1년 전
부모
커밋
c0a00f68ae
3개의 변경된 파일, 466개의 추가 및 111개의 삭제
  1. 187 85
      parser/parser.go
  2. 279 25
      parser/parser_test.go
  3. 0 1
      server/routes_test.go

+ 187 - 85
parser/parser.go

@@ -6,8 +6,9 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"log/slog"
 	"slices"
+	"strconv"
+	"strings"
 )
 
 type Command struct {
@@ -15,118 +16,219 @@ type Command struct {
 	Args string
 }
 
-func (c *Command) Reset() {
-	c.Name = ""
-	c.Args = ""
-}
+type state int
 
-func Parse(reader io.Reader) ([]Command, error) {
-	var commands []Command
-	var command, modelCommand Command
+const (
+	stateNil state = iota
+	stateName
+	stateValue
+	stateParameter
+	stateMessage
+	stateComment
+)
 
-	scanner := bufio.NewScanner(reader)
-	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
-	scanner.Split(scanModelfile)
-	for scanner.Scan() {
-		line := scanner.Bytes()
+var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
+
+func Parse(r io.Reader) (cmds []Command, err error) {
+	var cmd Command
+	var curr state
+	var b bytes.Buffer
+	var role string
+
+	br := bufio.NewReader(r)
+	for {
+		r, _, err := br.ReadRune()
+		if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
+			return nil, err
+		}
 
-		fields := bytes.SplitN(line, []byte(" "), 2)
-		if len(fields) == 0 || len(fields[0]) == 0 {
-			continue
+		next, r, err := parseRuneForState(r, curr)
+		if errors.Is(err, io.ErrUnexpectedEOF) {
+			return nil, fmt.Errorf("%w: %s", err, b.String())
+		} else if err != nil {
+			return nil, err
 		}
 
-		switch string(bytes.ToUpper(fields[0])) {
-		case "FROM":
-			command.Name = "model"
-			command.Args = string(bytes.TrimSpace(fields[1]))
-			// copy command for validation
-			modelCommand = command
-		case "ADAPTER":
-			command.Name = string(bytes.ToLower(fields[0]))
-			command.Args = string(bytes.TrimSpace(fields[1]))
-		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
-			command.Name = string(bytes.ToLower(fields[0]))
-			command.Args = string(fields[1])
-		case "PARAMETER":
-			fields = bytes.SplitN(fields[1], []byte(" "), 2)
-			if len(fields) < 2 {
-				return nil, fmt.Errorf("missing value for %s", fields)
+		if next != curr {
+			switch curr {
+			case stateName, stateParameter:
+				switch s := strings.ToLower(b.String()); s {
+				case "from":
+					cmd.Name = "model"
+				case "parameter":
+					next = stateParameter
+				case "message":
+					next = stateMessage
+					fallthrough
+				default:
+					cmd.Name = s
+				}
+			case stateMessage:
+				if !slices.Contains([]string{"system", "user", "assistant"}, b.String()) {
+					return nil, errInvalidRole
+				}
+
+				role = b.String()
+			case stateComment, stateNil:
+				// pass
+			case stateValue:
+				s := b.String()
+
+				s, ok := unquote(b.String())
+				if !ok || isSpace(r) {
+					if _, err := b.WriteRune(r); err != nil {
+						return nil, err
+					}
+
+					continue
+				}
+
+				if role != "" {
+					s = role + ": " + s
+					role = ""
+				}
+
+				cmd.Args = s
+				cmds = append(cmds, cmd)
 			}
 
-			command.Name = string(fields[0])
-			command.Args = string(bytes.TrimSpace(fields[1]))
-		case "EMBED":
-			return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
-		case "MESSAGE":
-			command.Name = string(bytes.ToLower(fields[0]))
-			fields = bytes.SplitN(fields[1], []byte(" "), 2)
-			if len(fields) < 2 {
-				return nil, fmt.Errorf("should be in the format <role> <message>")
-			}
-			if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
-				return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
-			}
-			command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
-		default:
-			if !bytes.HasPrefix(fields[0], []byte("#")) {
-				// log a warning for unknown commands
-				slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
+			b.Reset()
+			curr = next
+		}
+
+		if strconv.IsPrint(r) {
+			if _, err := b.WriteRune(r); err != nil {
+				return nil, err
 			}
-			continue
+		}
+	}
+
+	// flush the buffer
+	switch curr {
+	case stateComment, stateNil:
+		// pass; nothing to flush
+	case stateValue:
+		if _, ok := unquote(b.String()); !ok {
+			return nil, io.ErrUnexpectedEOF
 		}
 
-		commands = append(commands, command)
-		command.Reset()
+		cmd.Args = b.String()
+		cmds = append(cmds, cmd)
+	default:
+		return nil, io.ErrUnexpectedEOF
 	}
 
-	if modelCommand.Args == "" {
-		return nil, errors.New("no FROM line for the model was specified")
+	for _, cmd := range cmds {
+		if cmd.Name == "model" {
+			return cmds, nil
+		}
 	}
 
-	return commands, scanner.Err()
+	return nil, errors.New("no FROM line")
 }
 
-func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
-	advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
-	if err != nil {
-		return 0, nil, err
+func parseRuneForState(r rune, cs state) (state, rune, error) {
+	switch cs {
+	case stateNil:
+		switch {
+		case r == '#':
+			return stateComment, 0, nil
+		case isSpace(r), isNewline(r):
+			return stateNil, 0, nil
+		default:
+			return stateName, r, nil
+		}
+	case stateName:
+		switch {
+		case isAlpha(r):
+			return stateName, r, nil
+		case isSpace(r):
+			return stateValue, 0, nil
+		default:
+			return stateNil, 0, errors.New("invalid")
+		}
+	case stateValue:
+		switch {
+		case isNewline(r):
+			return stateNil, r, nil
+		case isSpace(r):
+			return stateNil, r, nil
+		default:
+			return stateValue, r, nil
+		}
+	case stateParameter:
+		switch {
+		case isAlpha(r), isNumber(r), r == '_':
+			return stateParameter, r, nil
+		case isSpace(r):
+			return stateValue, 0, nil
+		default:
+			return stateNil, 0, io.ErrUnexpectedEOF
+		}
+	case stateMessage:
+		switch {
+		case isAlpha(r):
+			return stateMessage, r, nil
+		case isSpace(r):
+			return stateValue, 0, nil
+		default:
+			return stateNil, 0, io.ErrUnexpectedEOF
+		}
+	case stateComment:
+		switch {
+		case isNewline(r):
+			return stateNil, 0, nil
+		default:
+			return stateComment, 0, nil
+		}
+	default:
+		return stateNil, 0, errors.New("")
 	}
+}
 
-	if advance > 0 && token != nil {
-		return advance, token, nil
+func unquote(s string) (string, bool) {
+	if len(s) == 0 {
+		return "", false
 	}
 
-	advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
-	if err != nil {
-		return 0, nil, err
+	// TODO: single quotes
+	if len(s) >= 3 && s[:3] == `"""` {
+		if len(s) >= 6 && s[len(s)-3:] == `"""` {
+			return s[3 : len(s)-3], true
+		}
+
+		return "", false
 	}
 
-	if advance > 0 && token != nil {
-		return advance, token, nil
+	if len(s) >= 1 && s[0] == '"' {
+		if len(s) >= 2 && s[len(s)-1] == '"' {
+			return s[1 : len(s)-1], true
+		}
+
+		return "", false
 	}
 
-	return bufio.ScanLines(data, atEOF)
+	return s, true
 }
 
-func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
-	newline := bytes.IndexByte(data, '\n')
+func isAlpha(r rune) bool {
+	return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z'
+}
 
-	if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
-		end := bytes.Index(data[start+len(openBytes):], closeBytes)
-		if end < 0 {
-			if atEOF {
-				return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
-			} else {
-				return 0, nil, nil
-			}
-		}
+func isNumber(r rune) bool {
+	return r >= '0' && r <= '9'
+}
 
-		n := start + len(openBytes) + end + len(closeBytes)
+func isSpace(r rune) bool {
+	return r == ' ' || r == '\t'
+}
 
-		newData := data[:start]
-		newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
-		return n, newData, nil
-	}
+func isNewline(r rune) bool {
+	return r == '\r' || r == '\n'
+}
 
-	return 0, nil, nil
+func isValidRole(role string) bool {
+	return role == "system" || role == "user" || role == "assistant"
 }

+ 279 - 25
parser/parser_test.go

@@ -1,13 +1,16 @@
 package parser
 
 import (
+	"bytes"
+	"fmt"
+	"io"
 	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 )
 
-func Test_Parser(t *testing.T) {
+func TestParser(t *testing.T) {
 
 	input := `
 FROM model1
@@ -35,7 +38,7 @@ TEMPLATE template1
 	assert.Equal(t, expectedCommands, commands)
 }
 
-func Test_Parser_NoFromLine(t *testing.T) {
+func TestParserNoFromLine(t *testing.T) {
 
 	input := `
 PARAMETER param1 value1
@@ -48,7 +51,7 @@ PARAMETER param2 value2
 	assert.ErrorContains(t, err, "no FROM line")
 }
 
-func Test_Parser_MissingValue(t *testing.T) {
+func TestParserParametersMissingValue(t *testing.T) {
 
 	input := `
 FROM foo
@@ -58,41 +61,292 @@ PARAMETER param1
 	reader := strings.NewReader(input)
 
 	_, err := Parse(reader)
-	assert.ErrorContains(t, err, "missing value for [param1]")
-
+	assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
 }
 
-func Test_Parser_Messages(t *testing.T) {
-
-	input := `
+func TestParserMessages(t *testing.T) {
+	var cases = []struct {
+		input    string
+		expected []Command
+		err      error
+	}{
+		{
+			`
+FROM foo
+MESSAGE system You are a Parser. Always Parse things.
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "message", Args: "system: You are a Parser. Always Parse things."},
+			},
+			nil,
+		},
+		{
+			`
 FROM foo
 MESSAGE system You are a Parser. Always Parse things.
 MESSAGE user Hey there!
 MESSAGE assistant Hello, I want to parse all the things!
-`
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "message", Args: "system: You are a Parser. Always Parse things."},
+				{Name: "message", Args: "user: Hey there!"},
+				{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+MESSAGE system """
+You are a multiline Parser. Always Parse things.
+"""
+			`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+MESSAGE badguy I'm a bad guy!
+`,
+			nil,
+			errInvalidRole,
+		},
+		{
+			`
+FROM foo
+MESSAGE system
+`,
+			nil,
+			io.ErrUnexpectedEOF,
+		},
+		{
+			`
+FROM foo
+MESSAGE system`,
+			nil,
+			io.ErrUnexpectedEOF,
+		},
+	}
 
-	reader := strings.NewReader(input)
-	commands, err := Parse(reader)
-	assert.Nil(t, err)
+	for _, c := range cases {
+		t.Run("", func(t *testing.T) {
+			commands, err := Parse(strings.NewReader(c.input))
+			assert.ErrorIs(t, err, c.err)
+			assert.Equal(t, c.expected, commands)
+		})
+	}
+}
 
-	expectedCommands := []Command{
-		{Name: "model", Args: "foo"},
-		{Name: "message", Args: "system: You are a Parser. Always Parse things."},
-		{Name: "message", Args: "user: Hey there!"},
-		{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
+func TestParserQuoted(t *testing.T) {
+	var cases = []struct {
+		multiline string
+		expected  []Command
+		err       error
+	}{
+		{
+			`
+FROM foo
+TEMPLATE """
+This is a
+multiline template.
+"""
+			`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "template", Args: "\nThis is a\nmultiline template.\n"},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+TEMPLATE """
+This is a
+multiline template."""
+			`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "template", Args: "\nThis is a\nmultiline template."},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+TEMPLATE """This is a
+multiline template."""
+			`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "template", Args: "This is a\nmultiline template."},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+TEMPLATE """This is a multiline template."""
+			`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "template", Args: "This is a multiline template."},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+TEMPLATE """This is a multiline template.""
+			`,
+			nil,
+			io.ErrUnexpectedEOF,
+		},
+		{
+			`
+FROM foo
+TEMPLATE "
+			`,
+			nil,
+			io.ErrUnexpectedEOF,
+		},
+		{
+			`
+FROM foo
+TEMPLATE """
+This is a multiline template with "quotes".
+"""
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "template", Args: "\nThis is a multiline template with \"quotes\".\n"},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+TEMPLATE """"""
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "template", Args: ""},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+TEMPLATE ""
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "template", Args: ""},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+TEMPLATE "'"
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "template", Args: "'"},
+			},
+			nil,
+		},
 	}
 
-	assert.Equal(t, expectedCommands, commands)
+	for _, c := range cases {
+		t.Run("", func(t *testing.T) {
+			commands, err := Parse(strings.NewReader(c.multiline))
+			assert.ErrorIs(t, err, c.err)
+			assert.Equal(t, c.expected, commands)
+		})
+	}
 }
 
-func Test_Parser_Messages_BadRole(t *testing.T) {
+func TestParserParameters(t *testing.T) {
+	var cases = []string{
+		"numa true",
+		"num_ctx 1",
+		"num_batch 1",
+		"num_gqa 1",
+		"num_gpu 1",
+		"main_gpu 1",
+		"low_vram true",
+		"f16_kv true",
+		"logits_all true",
+		"vocab_only true",
+		"use_mmap true",
+		"use_mlock true",
+		"num_thread 1",
+		"num_keep 1",
+		"seed 1",
+		"num_predict 1",
+		"top_k 1",
+		"top_p 1.0",
+		"tfs_z 1.0",
+		"typical_p 1.0",
+		"repeat_last_n 1",
+		"temperature 1.0",
+		"repeat_penalty 1.0",
+		"presence_penalty 1.0",
+		"frequency_penalty 1.0",
+		"mirostat 1",
+		"mirostat_tau 1.0",
+		"mirostat_eta 1.0",
+		"penalize_newline true",
+		"stop foo",
+	}
 
-	input := `
+	for _, c := range cases {
+		t.Run(c, func(t *testing.T) {
+			var b bytes.Buffer
+			fmt.Fprintln(&b, "FROM foo")
+			fmt.Fprintln(&b, "PARAMETER", c)
+			t.Logf("input: %s", b.String())
+			_, err := Parse(&b)
+			assert.Nil(t, err)
+		})
+	}
+}
+
+func TestParserOnlyFrom(t *testing.T) {
+	commands, err := Parse(strings.NewReader("FROM foo"))
+	assert.Nil(t, err)
+
+	expected := []Command{{Name: "model", Args: "foo"}}
+	assert.Equal(t, expected, commands)
+}
+
+func TestParserComments(t *testing.T) {
+	var cases = []struct {
+		input    string
+		expected []Command
+	}{
+		{
+			`
+# comment
 FROM foo
-MESSAGE badguy I'm a bad guy!
-`
+	`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+			},
+		},
+	}
 
-	reader := strings.NewReader(input)
-	_, err := Parse(reader)
-	assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
+	for _, c := range cases {
+		t.Run("", func(t *testing.T) {
+			commands, err := Parse(strings.NewReader(c.input))
+			assert.Nil(t, err)
+			assert.Equal(t, c.expected, commands)
+		})
+	}
 }

+ 0 - 1
server/routes_test.go

@@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) {
 		if tc.Expected != nil {
 			tc.Expected(t, resp)
 		}
-
 	}
 }