浏览代码

parser: add commands format

Michael Yang 1 年之前
父节点
当前提交
176ad3aa6e
共有 3 个文件被更改,包括 108 次插入15 次删除
  1. 9 15
      cmd/cmd.go
  2. 39 0
      parser/parser.go
  3. 60 0
      parser/parser_test.go

+ 9 - 15
cmd/cmd.go

@@ -17,7 +17,6 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
-	"regexp"
 	"runtime"
 	"strings"
 	"syscall"
@@ -57,12 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	p := progress.NewProgress(os.Stderr)
 	defer p.Stop()
 
-	modelfile, err := os.ReadFile(filename)
+	modelfile, err := os.Open(filename)
 	if err != nil {
 		return err
 	}
+	defer modelfile.Close()
 
-	commands, err := parser.Parse(bytes.NewReader(modelfile))
+	commands, err := parser.Parse(modelfile)
 	if err != nil {
 		return err
 	}
@@ -76,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	spinner := progress.NewSpinner(status)
 	p.Add(status, spinner)
 
-	for _, c := range commands {
-		switch c.Name {
+	for i := range commands {
+		switch commands[i].Name {
 		case "model", "adapter":
-			path := c.Args
+			path := commands[i].Args
 			if path == "~" {
 				path = home
 			} else if strings.HasPrefix(path, "~/") {
@@ -91,7 +91,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			}
 
 			fi, err := os.Stat(path)
-			if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
+			if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" {
 				continue
 			} else if err != nil {
 				return err
@@ -114,13 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				return err
 			}
 
-			name := c.Name
-			if c.Name == "model" {
-				name = "from"
-			}
-
-			re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
-			modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
+			commands[i].Args = "@"+digest
 		}
 	}
 
@@ -150,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 
 	quantization, _ := cmd.Flags().GetString("quantization")
 
-	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
+	request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization}
 	if err := client.Create(cmd.Context(), &request, fn); err != nil {
 		return err
 	}

+ 39 - 0
parser/parser.go

@@ -31,6 +31,33 @@ var (
 	errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
 )
 
+func Format(cmds []Command) string {
+	var b bytes.Buffer
+	for _, cmd := range cmds {
+		name := cmd.Name
+		args := cmd.Args
+
+		switch cmd.Name {
+		case "model":
+			name = "from"
+			args = cmd.Args
+		case "license", "template", "system", "adapter":
+			args = quote(args)
+			// pass
+		case "message":
+			role, message, _ := strings.Cut(cmd.Args, ": ")
+			args = role + " " + quote(message)
+		default:
+			name = "parameter"
+			args = cmd.Name + " " + cmd.Args
+		}
+
+		fmt.Fprintln(&b, strings.ToUpper(name), args)
+	}
+
+	return b.String()
+}
+
 func Parse(r io.Reader) (cmds []Command, err error) {
 	var cmd Command
 	var curr state
@@ -197,6 +224,18 @@ func parseRuneForState(r rune, cs state) (state, rune, error) {
 	}
 }
 
+func quote(s string) string {
+	if strings.Contains(s, "\n") || strings.HasSuffix(s, " ") {
+		if strings.Contains(s, "\"") {
+			return `"""` + s + `"""`
+		}
+
+		return strconv.Quote(s)
+	}
+
+	return s
+}
+
 func unquote(s string) (string, bool) {
 	if len(s) == 0 {
 		return "", false

+ 60 - 0
parser/parser_test.go

@@ -429,3 +429,63 @@ FROM foo
 		})
 	}
 }
+
+func TestParseFormatParse(t *testing.T) {
+	var cases = []string{
+		`
+FROM foo
+ADAPTER adapter1
+LICENSE MIT
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE template1
+MESSAGE system You are a Parser. Always Parse things.
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`,
+		`
+FROM foo
+ADAPTER adapter1
+LICENSE MIT
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE template1
+MESSAGE system """
+You are a store greeter. Always responsed with "Hello!".
+"""
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`,
+		`
+FROM foo
+ADAPTER adapter1
+LICENSE """
+Very long and boring legal text.
+Blah blah blah.
+"Oh look, a quote!"
+"""
+
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE template1
+MESSAGE system """
+You are a store greeter. Always responsed with "Hello!".
+"""
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`,
+	}
+
+	for _, c := range cases {
+		t.Run("", func(t *testing.T) {
+			commands, err := Parse(strings.NewReader(c))
+			assert.NoError(t, err)
+
+			commands2, err := Parse(strings.NewReader(Format(commands)))
+			assert.NoError(t, err)
+
+			assert.Equal(t, commands, commands2)
+		})
+	}
+
+}