Michael Yang před 1 rokem
rodič
revize
abe614c705
2 změnil soubory, kde provedl 118 přidání a 58 odebrání
  1. 5 2
      parser/parser.go
  2. 113 56
      parser/parser_test.go

+ 5 - 2
parser/parser.go

@@ -26,7 +26,10 @@ const (
 	stateComment
 )
 
-var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
+var (
+	errMissingFrom = errors.New("no FROM line")
+	errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
+)
 
 func Parse(r io.Reader) (cmds []Command, err error) {
 	var cmd Command
@@ -123,7 +126,7 @@ func Parse(r io.Reader) (cmds []Command, err error) {
 		}
 	}
 
-	return nil, errors.New("no FROM line")
+	return nil, errMissingFrom
 }
 
 func parseRuneForState(r rune, cs state) (state, rune, error) {

+ 113 - 56
parser/parser_test.go

@@ -11,7 +11,6 @@ import (
 )
 
 func TestParser(t *testing.T) {
-
 	input := `
 FROM model1
 ADAPTER adapter1
@@ -38,21 +37,62 @@ TEMPLATE template1
 	assert.Equal(t, expectedCommands, commands)
 }
 
-func TestParserNoFromLine(t *testing.T) {
-
-	input := `
-PARAMETER param1 value1
-PARAMETER param2 value2
-`
-
-	reader := strings.NewReader(input)
+func TestParserFrom(t *testing.T) {
+	var cases = []struct {
+		input    string
+		expected []Command
+		err      error
+	}{
+		{
+			"FROM foo",
+			[]Command{{Name: "model", Args: "foo"}},
+			nil,
+		},
+		{
+			"FROM /path/to/model",
+			[]Command{{Name: "model", Args: "/path/to/model"}},
+			nil,
+		},
+		{
+			"FROM /path/to/model/fp16.bin",
+			[]Command{{Name: "model", Args: "/path/to/model/fp16.bin"}},
+			nil,
+		},
+		{
+			"FROM llama3:latest",
+			[]Command{{Name: "model", Args: "llama3:latest"}},
+			nil,
+		},
+		{
+			"FROM llama3:7b-instruct-q4_K_M",
+			[]Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}},
+			nil,
+		},
+		{
+			"", nil, errMissingFrom,
+		},
+		{
+			"PARAMETER param1 value1",
+			nil,
+			errMissingFrom,
+		},
+		{
+			"PARAMETER param1 value1\nFROM foo",
+			[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
+			nil,
+		},
+	}
 
-	_, err := Parse(reader)
-	assert.ErrorContains(t, err, "no FROM line")
+	for _, c := range cases {
+		t.Run("", func(t *testing.T) {
+			commands, err := Parse(strings.NewReader(c.input))
+			assert.ErrorIs(t, err, c.err)
+			assert.Equal(t, c.expected, commands)
+		})
+	}
 }
 
 func TestParserParametersMissingValue(t *testing.T) {
-
 	input := `
 FROM foo
 PARAMETER param1
@@ -261,6 +301,17 @@ TEMPLATE "'"
 			},
 			nil,
 		},
+		{
+			`
+FROM foo
+TEMPLATE """''"'""'""'"'''''""'""'"""
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "template", Args: `''"'""'""'"'''''""'""'`},
+			},
+			nil,
+		},
 	}
 
 	for _, c := range cases {
@@ -273,59 +324,65 @@ TEMPLATE "'"
 }
 
 func TestParserParameters(t *testing.T) {
-	var cases = []string{
-		"numa true",
-		"num_ctx 1",
-		"num_batch 1",
-		"num_gqa 1",
-		"num_gpu 1",
-		"main_gpu 1",
-		"low_vram true",
-		"f16_kv true",
-		"logits_all true",
-		"vocab_only true",
-		"use_mmap true",
-		"use_mlock true",
-		"num_thread 1",
-		"num_keep 1",
-		"seed 1",
-		"num_predict 1",
-		"top_k 1",
-		"top_p 1.0",
-		"tfs_z 1.0",
-		"typical_p 1.0",
-		"repeat_last_n 1",
-		"temperature 1.0",
-		"repeat_penalty 1.0",
-		"presence_penalty 1.0",
-		"frequency_penalty 1.0",
-		"mirostat 1",
-		"mirostat_tau 1.0",
-		"mirostat_eta 1.0",
-		"penalize_newline true",
-		"stop foo",
+	var cases = map[string]struct {
+		name, value string
+	}{
+		"numa true":                    {"numa", "true"},
+		"num_ctx 1":                    {"num_ctx", "1"},
+		"num_batch 1":                  {"num_batch", "1"},
+		"num_gqa 1":                    {"num_gqa", "1"},
+		"num_gpu 1":                    {"num_gpu", "1"},
+		"main_gpu 1":                   {"main_gpu", "1"},
+		"low_vram true":                {"low_vram", "true"},
+		"f16_kv true":                  {"f16_kv", "true"},
+		"logits_all true":              {"logits_all", "true"},
+		"vocab_only true":              {"vocab_only", "true"},
+		"use_mmap true":                {"use_mmap", "true"},
+		"use_mlock true":               {"use_mlock", "true"},
+		"num_thread 1":                 {"num_thread", "1"},
+		"num_keep 1":                   {"num_keep", "1"},
+		"seed 1":                       {"seed", "1"},
+		"num_predict 1":                {"num_predict", "1"},
+		"top_k 1":                      {"top_k", "1"},
+		"top_p 1.0":                    {"top_p", "1.0"},
+		"tfs_z 1.0":                    {"tfs_z", "1.0"},
+		"typical_p 1.0":                {"typical_p", "1.0"},
+		"repeat_last_n 1":              {"repeat_last_n", "1"},
+		"temperature 1.0":              {"temperature", "1.0"},
+		"repeat_penalty 1.0":           {"repeat_penalty", "1.0"},
+		"presence_penalty 1.0":         {"presence_penalty", "1.0"},
+		"frequency_penalty 1.0":        {"frequency_penalty", "1.0"},
+		"mirostat 1":                   {"mirostat", "1"},
+		"mirostat_tau 1.0":             {"mirostat_tau", "1.0"},
+		"mirostat_eta 1.0":             {"mirostat_eta", "1.0"},
+		"penalize_newline true":        {"penalize_newline", "true"},
+		"stop ### User:":               {"stop", "### User:"},
+		"stop ### User: ":              {"stop", "### User: "},
+		"stop \"### User:\"":           {"stop", "### User:"},
+		"stop \"### User: \"":          {"stop", "### User: "},
+		"stop \"\"\"### User:\"\"\"":   {"stop", "### User:"},
+		"stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"},
+		"stop <|endoftext|>":           {"stop", "<|endoftext|>"},
+		"stop <|eot_id|>":              {"stop", "<|eot_id|>"},
+		"stop </s>":                    {"stop", "</s>"},
 	}
 
-	for _, c := range cases {
-		t.Run(c, func(t *testing.T) {
+	for k, v := range cases {
+		t.Run(k, func(t *testing.T) {
 			var b bytes.Buffer
 			fmt.Fprintln(&b, "FROM foo")
-			fmt.Fprintln(&b, "PARAMETER", c)
-			t.Logf("input: %s", b.String())
-			_, err := Parse(&b)
+			fmt.Fprintln(&b, "PARAMETER", k)
+			commands, err := Parse(&b)
 			assert.Nil(t, err)
+
+			assert.Equal(t, []Command{
+				{Name: "model", Args: "foo"},
+				{Name: v.name, Args: v.value},
+			}, commands)
 		})
 	}
 }
 
-func TestParserOnlyFrom(t *testing.T) {
-	commands, err := Parse(strings.NewReader("FROM foo"))
-	assert.Nil(t, err)
-
-	expected := []Command{{Name: "model", Args: "foo"}}
-	assert.Equal(t, expected, commands)
-}
-
 func TestParserComments(t *testing.T) {
 	var cases = []struct {
 		input    string