|
@@ -6,8 +6,9 @@ import (
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
- "log/slog"
|
|
|
"slices"
|
|
|
+ "strconv"
|
|
|
+ "strings"
|
|
|
)
|
|
|
|
|
|
type Command struct {
|
|
@@ -15,118 +16,219 @@ type Command struct {
|
|
|
Args string
|
|
|
}
|
|
|
|
|
|
-func (c *Command) Reset() {
|
|
|
- c.Name = ""
|
|
|
- c.Args = ""
|
|
|
-}
|
|
|
// state identifies which part of a Modelfile line the parser is currently
// reading. Its zero value is stateNil.
type state int
|
|
|
|
|
|
-func Parse(reader io.Reader) ([]Command, error) {
|
|
|
- var commands []Command
|
|
|
- var command, modelCommand Command
|
|
|
+const (
|
|
|
+ stateNil state = iota
|
|
|
+ stateName
|
|
|
+ stateValue
|
|
|
+ stateParameter
|
|
|
+ stateMessage
|
|
|
+ stateComment
|
|
|
+)
|
|
|
|
|
|
- scanner := bufio.NewScanner(reader)
|
|
|
- scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
|
|
|
- scanner.Split(scanModelfile)
|
|
|
- for scanner.Scan() {
|
|
|
- line := scanner.Bytes()
|
|
|
// errInvalidRole is returned by Parse when a MESSAGE command names a role
// other than "system", "user", or "assistant".
var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
|
|
|
+
|
|
|
+func Parse(r io.Reader) (cmds []Command, err error) {
|
|
|
+ var cmd Command
|
|
|
+ var curr state
|
|
|
+ var b bytes.Buffer
|
|
|
+ var role string
|
|
|
+
|
|
|
+ br := bufio.NewReader(r)
|
|
|
+ for {
|
|
|
+ r, _, err := br.ReadRune()
|
|
|
+ if errors.Is(err, io.EOF) {
|
|
|
+ break
|
|
|
+ } else if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
|
|
|
- fields := bytes.SplitN(line, []byte(" "), 2)
|
|
|
- if len(fields) == 0 || len(fields[0]) == 0 {
|
|
|
- continue
|
|
|
+ next, r, err := parseRuneForState(r, curr)
|
|
|
+ if errors.Is(err, io.ErrUnexpectedEOF) {
|
|
|
+ return nil, fmt.Errorf("%w: %s", err, b.String())
|
|
|
+ } else if err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- switch string(bytes.ToUpper(fields[0])) {
|
|
|
- case "FROM":
|
|
|
- command.Name = "model"
|
|
|
- command.Args = string(bytes.TrimSpace(fields[1]))
|
|
|
- // copy command for validation
|
|
|
- modelCommand = command
|
|
|
- case "ADAPTER":
|
|
|
- command.Name = string(bytes.ToLower(fields[0]))
|
|
|
- command.Args = string(bytes.TrimSpace(fields[1]))
|
|
|
- case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
|
|
|
- command.Name = string(bytes.ToLower(fields[0]))
|
|
|
- command.Args = string(fields[1])
|
|
|
- case "PARAMETER":
|
|
|
- fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
|
|
- if len(fields) < 2 {
|
|
|
- return nil, fmt.Errorf("missing value for %s", fields)
|
|
|
+ if next != curr {
|
|
|
+ switch curr {
|
|
|
+ case stateName, stateParameter:
|
|
|
+ switch s := strings.ToLower(b.String()); s {
|
|
|
+ case "from":
|
|
|
+ cmd.Name = "model"
|
|
|
+ case "parameter":
|
|
|
+ next = stateParameter
|
|
|
+ case "message":
|
|
|
+ next = stateMessage
|
|
|
+ fallthrough
|
|
|
+ default:
|
|
|
+ cmd.Name = s
|
|
|
+ }
|
|
|
+ case stateMessage:
|
|
|
+ if !slices.Contains([]string{"system", "user", "assistant"}, b.String()) {
|
|
|
+ return nil, errInvalidRole
|
|
|
+ }
|
|
|
+
|
|
|
+ role = b.String()
|
|
|
+ case stateComment, stateNil:
|
|
|
+ // pass
|
|
|
+ case stateValue:
|
|
|
+ s := b.String()
|
|
|
+
|
|
|
+ s, ok := unquote(b.String())
|
|
|
+ if !ok || isSpace(r) {
|
|
|
+ if _, err := b.WriteRune(r); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if role != "" {
|
|
|
+ s = role + ": " + s
|
|
|
+ role = ""
|
|
|
+ }
|
|
|
+
|
|
|
+ cmd.Args = s
|
|
|
+ cmds = append(cmds, cmd)
|
|
|
}
|
|
|
|
|
|
- command.Name = string(fields[0])
|
|
|
- command.Args = string(bytes.TrimSpace(fields[1]))
|
|
|
- case "EMBED":
|
|
|
- return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
|
|
|
- case "MESSAGE":
|
|
|
- command.Name = string(bytes.ToLower(fields[0]))
|
|
|
- fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
|
|
- if len(fields) < 2 {
|
|
|
- return nil, fmt.Errorf("should be in the format <role> <message>")
|
|
|
- }
|
|
|
- if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
|
|
|
- return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
|
|
|
- }
|
|
|
- command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
|
|
|
- default:
|
|
|
- if !bytes.HasPrefix(fields[0], []byte("#")) {
|
|
|
- // log a warning for unknown commands
|
|
|
- slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
|
|
|
+ b.Reset()
|
|
|
+ curr = next
|
|
|
+ }
|
|
|
+
|
|
|
+ if strconv.IsPrint(r) {
|
|
|
+ if _, err := b.WriteRune(r); err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
- continue
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // flush the buffer
|
|
|
+ switch curr {
|
|
|
+ case stateComment, stateNil:
|
|
|
+ // pass; nothing to flush
|
|
|
+ case stateValue:
|
|
|
+ if _, ok := unquote(b.String()); !ok {
|
|
|
+ return nil, io.ErrUnexpectedEOF
|
|
|
}
|
|
|
|
|
|
- commands = append(commands, command)
|
|
|
- command.Reset()
|
|
|
+ cmd.Args = b.String()
|
|
|
+ cmds = append(cmds, cmd)
|
|
|
+ default:
|
|
|
+ return nil, io.ErrUnexpectedEOF
|
|
|
}
|
|
|
|
|
|
- if modelCommand.Args == "" {
|
|
|
- return nil, errors.New("no FROM line for the model was specified")
|
|
|
+ for _, cmd := range cmds {
|
|
|
+ if cmd.Name == "model" {
|
|
|
+ return cmds, nil
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- return commands, scanner.Err()
|
|
|
+ return nil, errors.New("no FROM line")
|
|
|
}
|
|
|
|
|
|
-func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
|
- advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
|
|
|
- if err != nil {
|
|
|
- return 0, nil, err
|
|
|
+func parseRuneForState(r rune, cs state) (state, rune, error) {
|
|
|
+ switch cs {
|
|
|
+ case stateNil:
|
|
|
+ switch {
|
|
|
+ case r == '#':
|
|
|
+ return stateComment, 0, nil
|
|
|
+ case isSpace(r), isNewline(r):
|
|
|
+ return stateNil, 0, nil
|
|
|
+ default:
|
|
|
+ return stateName, r, nil
|
|
|
+ }
|
|
|
+ case stateName:
|
|
|
+ switch {
|
|
|
+ case isAlpha(r):
|
|
|
+ return stateName, r, nil
|
|
|
+ case isSpace(r):
|
|
|
+ return stateValue, 0, nil
|
|
|
+ default:
|
|
|
+ return stateNil, 0, errors.New("invalid")
|
|
|
+ }
|
|
|
+ case stateValue:
|
|
|
+ switch {
|
|
|
+ case isNewline(r):
|
|
|
+ return stateNil, r, nil
|
|
|
+ case isSpace(r):
|
|
|
+ return stateNil, r, nil
|
|
|
+ default:
|
|
|
+ return stateValue, r, nil
|
|
|
+ }
|
|
|
+ case stateParameter:
|
|
|
+ switch {
|
|
|
+ case isAlpha(r), isNumber(r), r == '_':
|
|
|
+ return stateParameter, r, nil
|
|
|
+ case isSpace(r):
|
|
|
+ return stateValue, 0, nil
|
|
|
+ default:
|
|
|
+ return stateNil, 0, io.ErrUnexpectedEOF
|
|
|
+ }
|
|
|
+ case stateMessage:
|
|
|
+ switch {
|
|
|
+ case isAlpha(r):
|
|
|
+ return stateMessage, r, nil
|
|
|
+ case isSpace(r):
|
|
|
+ return stateValue, 0, nil
|
|
|
+ default:
|
|
|
+ return stateNil, 0, io.ErrUnexpectedEOF
|
|
|
+ }
|
|
|
+ case stateComment:
|
|
|
+ switch {
|
|
|
+ case isNewline(r):
|
|
|
+ return stateNil, 0, nil
|
|
|
+ default:
|
|
|
+ return stateComment, 0, nil
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ return stateNil, 0, errors.New("")
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- if advance > 0 && token != nil {
|
|
|
- return advance, token, nil
|
|
|
// unquote strips one layer of double-quote delimiters from s.
//
// It returns the inner text and true when s is
//   - wrapped in triple quotes ("""..."""),
//   - wrapped in single double quotes ("..."), or
//   - not quoted at all (s is returned unchanged).
//
// It returns "", false when s is empty or opens with a quote sequence that is
// not closed by the matching sequence at the end of s.
func unquote(s string) (string, bool) {
	const triple = `"""`

	// TODO: single quotes
	switch {
	case s == "":
		return "", false
	case strings.HasPrefix(s, triple):
		// The length check keeps the opening and closing delimiters from
		// overlapping (e.g. a lone `"""`).
		if len(s) >= 2*len(triple) && strings.HasSuffix(s, triple) {
			return s[len(triple) : len(s)-len(triple)], true
		}

		return "", false
	case s[0] == '"':
		if len(s) >= 2 && s[len(s)-1] == '"' {
			return s[1 : len(s)-1], true
		}

		return "", false
	default:
		return s, true
	}
}
|
|
|
|
|
|
-func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
|
- newline := bytes.IndexByte(data, '\n')
|
|
|
// isAlpha reports whether r is an ASCII letter (a-z or A-Z).
func isAlpha(r rune) bool {
	return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
}
|
|
|
|
|
|
- if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
|
|
|
- end := bytes.Index(data[start+len(openBytes):], closeBytes)
|
|
|
- if end < 0 {
|
|
|
- if atEOF {
|
|
|
- return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
|
|
|
- } else {
|
|
|
- return 0, nil, nil
|
|
|
- }
|
|
|
- }
|
|
|
// isNumber reports whether r is an ASCII decimal digit (0-9).
func isNumber(r rune) bool {
	return r >= '0' && r <= '9'
}
|
|
|
|
|
|
- n := start + len(openBytes) + end + len(closeBytes)
|
|
|
// isSpace reports whether r is horizontal whitespace: a space or a tab.
// Line breaks are not included; see isNewline.
func isSpace(r rune) bool {
	return r == ' ' || r == '\t'
}
|
|
|
|
|
|
- newData := data[:start]
|
|
|
- newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
|
|
|
- return n, newData, nil
|
|
|
- }
|
|
|
// isNewline reports whether r is a line-break rune: carriage return or
// line feed.
func isNewline(r rune) bool {
	return r == '\r' || r == '\n'
}
|
|
|
|
|
|
- return 0, nil, nil
|
|
|
// isValidRole reports whether role is one of the accepted message roles:
// "system", "user", or "assistant". The comparison is case-sensitive.
func isValidRole(role string) bool {
	switch role {
	case "system", "user", "assistant":
		return true
	default:
		return false
	}
}
|