浏览代码

Improve command parsing and multiline string handling

Mohit Gaur 1 年之前
父节点
当前提交
ed89da92b4
共有 1 个文件被更改,包括 12 次插入11 次删除
  1. 12 11
      parser/parser.go

+ 12 - 11
parser/parser.go

@@ -2,12 +2,14 @@ package parser
 
 import (
 	"bufio"
-	"bytes"
 	"errors"
 	"fmt"
 	"io"
+	"strings"
 )
 
+const multilineString = `"""`
+
 type Command struct {
 	Name string
 	Args string
@@ -20,7 +22,6 @@ func (c *Command) Reset() {
 
 func Parse(reader io.Reader) ([]Command, error) {
 	var commands []Command
-
 	var command, modelCommand Command
 
 	scanner := bufio.NewScanner(reader)
@@ -33,21 +34,21 @@ func Parse(reader io.Reader) ([]Command, error) {
 			continue
 		}
 
-		switch string(bytes.ToUpper(fields[0])) {
+		switch strings.ToUpper(string(fields[0])) {
 		case "FROM":
 			command.Name = "model"
 			command.Args = string(fields[1])
 			// copy command for validation
 			modelCommand = command
 		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
-			command.Name = string(bytes.ToLower(fields[0]))
+			command.Name = strings.ToLower(string(fields[0]))
 			command.Args = string(fields[1])
 		case "PARAMETER":
 			fields = bytes.SplitN(fields[1], []byte(" "), 2)
 			command.Name = string(fields[0])
 			command.Args = string(fields[1])
 		default:
-			continue
+			return nil, fmt.Errorf("unknown command: %s", fields[0])
 		}
 
 		commands = append(commands, command)
@@ -55,7 +56,7 @@ func Parse(reader io.Reader) ([]Command, error) {
 	}
 
 	if modelCommand.Args == "" {
-		return nil, fmt.Errorf("no FROM line for the model was specified")
+		return nil, errors.New("no FROM line for the model was specified")
 	}
 
 	return commands, scanner.Err()
@@ -64,18 +65,18 @@ func Parse(reader io.Reader) ([]Command, error) {
 func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
 	newline := bytes.IndexByte(data, '\n')
 
-	if start := bytes.Index(data, []byte(`"""`)); start >= 0 && start < newline {
-		end := bytes.Index(data[start+3:], []byte(`"""`))
+	if start := bytes.Index(data, []byte(multilineString)); start >= 0 && start < newline {
+		end := bytes.Index(data[start+len(multilineString):], []byte(multilineString))
 		if end < 0 {
 			if atEOF {
-				return 0, nil, errors.New(`unterminated multiline string: """`)
+				return 0, nil, errors.New("unterminated multiline string: " + multilineString)
 			} else {
 				return 0, nil, nil
 			}
 		}
 
-		n := start + 3 + end + 3
-		return n, bytes.Replace(data[:n], []byte(`"""`), []byte(""), 2), nil
+		n := start + len(multilineString) + end + len(multilineString)
+		return n, data[:n], nil
 	}
 
 	return bufio.ScanLines(data, atEOF)