瀏覽代碼

fix parser name

Michael Yang 1 年之前
父節點
當前提交
bd8eed57fc
共有 2 個文件被更改,包括 32 次插入6 次删除
  1. 21 5
      parser/parser.go
  2. 11 1
      parser/parser_test.go

+ 21 - 5
parser/parser.go

@@ -27,8 +27,9 @@ const (
 )
 
 var (
-	errMissingFrom = errors.New("no FROM line")
-	errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
+	errMissingFrom        = errors.New("no FROM line")
+	errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
+	errInvalidCommand     = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
 )
 
 func Format(cmds []Command) string {
@@ -82,7 +83,11 @@ func Parse(r io.Reader) (cmds []Command, err error) {
 		// process the state transition, some transitions need to be intercepted and redirected
 		if next != curr {
 			switch curr {
-			case stateName, stateParameter:
+			case stateName:
+				if !isValidCommand(b.String()) {
+					return nil, errInvalidCommand
+				}
+
 				// next state sometimes depends on the current buffer value
 				switch s := strings.ToLower(b.String()); s {
 				case "from":
@@ -97,9 +102,11 @@ func Parse(r io.Reader) (cmds []Command, err error) {
 				default:
 					cmd.Name = s
 				}
+			case stateParameter:
+				cmd.Name = b.String()
 			case stateMessage:
 				if !isValidMessageRole(b.String()) {
-					return nil, errInvalidRole
+					return nil, errInvalidMessageRole
 				}
 
 				role = b.String()
@@ -182,7 +189,7 @@ func parseRuneForState(r rune, cs state) (state, rune, error) {
 		case isSpace(r):
 			return stateValue, 0, nil
 		default:
-			return stateNil, 0, errors.New("invalid")
+			return stateNil, 0, errInvalidCommand
 		}
 	case stateValue:
 		switch {
@@ -279,3 +286,12 @@ func isNewline(r rune) bool {
 func isValidMessageRole(role string) bool {
 	return role == "system" || role == "user" || role == "assistant"
 }
+
+func isValidCommand(cmd string) bool {
+	switch strings.ToLower(cmd) {
+	case "from", "license", "template", "system", "adapter", "parameter", "message":
+		return true
+	default:
+		return false
+	}
+}

+ 11 - 1
parser/parser_test.go

@@ -104,6 +104,16 @@ PARAMETER param1
 	assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
 }
 
+func TestParserBadCommand(t *testing.T) {
+	input := `
+FROM foo
+BADCOMMAND param1 value1
+`
+	_, err := Parse(strings.NewReader(input))
+	assert.ErrorIs(t, err, errInvalidCommand)
+
+}
+
 func TestParserMessages(t *testing.T) {
 	var cases = []struct {
 		input    string
@@ -165,7 +175,7 @@ FROM foo
 MESSAGE badguy I'm a bad guy!
 `,
 			nil,
-			errInvalidRole,
+			errInvalidMessageRole,
 		},
 		{
 			`