浏览代码

Merge pull request #3892 from ollama/mxyng/parser

refactor modelfile parser
Michael Yang 1 年之前
父节点
当前提交
e9ae607ece
共有 6 个文件被更改,包括 748 次插入206 次删除
  1. 9 14
      cmd/cmd.go
  2. 254 89
      parser/parser.go
  3. 437 34
      parser/parser_test.go
  4. 42 62
      server/images.go
  5. 6 6
      server/routes.go
  6. 0 1
      server/routes_test.go

+ 9 - 14
cmd/cmd.go

@@ -57,12 +57,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	p := progress.NewProgress(os.Stderr)
 	defer p.Stop()
 
-	modelfile, err := os.ReadFile(filename)
+	modelfile, err := os.Open(filename)
 	if err != nil {
 		return err
 	}
+	defer modelfile.Close()
 
-	commands, err := parser.Parse(bytes.NewReader(modelfile))
+	commands, err := parser.Parse(modelfile)
 	if err != nil {
 		return err
 	}
@@ -76,10 +77,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	spinner := progress.NewSpinner(status)
 	p.Add(status, spinner)
 
-	for _, c := range commands {
-		switch c.Name {
+	for i := range commands {
+		switch commands[i].Name {
 		case "model", "adapter":
-			path := c.Args
+			path := commands[i].Args
 			if path == "~" {
 				path = home
 			} else if strings.HasPrefix(path, "~/") {
@@ -91,7 +92,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			}
 
 			fi, err := os.Stat(path)
-			if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
+			if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" {
 				continue
 			} else if err != nil {
 				return err
@@ -114,13 +115,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				return err
 			}
 
-			name := c.Name
-			if c.Name == "model" {
-				name = "from"
-			}
-
-			re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
-			modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
+			commands[i].Args = "@"+digest
 		}
 	}
 
@@ -150,7 +145,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 
 	quantization, _ := cmd.Flags().GetString("quantization")
 
-	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
+	request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization}
 	if err := client.Create(cmd.Context(), &request, fn); err != nil {
 		return err
 	}

+ 254 - 89
parser/parser.go

@@ -6,8 +6,8 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"log/slog"
-	"slices"
+	"strconv"
+	"strings"
 )
 
 type Command struct {
@@ -15,118 +15,283 @@ type Command struct {
 	Args string
 }
 
-func (c *Command) Reset() {
-	c.Name = ""
-	c.Args = ""
+type state int
+
+const (
+	stateNil state = iota
+	stateName
+	stateValue
+	stateParameter
+	stateMessage
+	stateComment
+)
+
+var (
+	errMissingFrom        = errors.New("no FROM line")
+	errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
+	errInvalidCommand     = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
+)
+
+func Format(cmds []Command) string {
+	var sb strings.Builder
+	for _, cmd := range cmds {
+		name := cmd.Name
+		args := cmd.Args
+
+		switch cmd.Name {
+		case "model":
+			name = "from"
+			args = cmd.Args
+		case "license", "template", "system", "adapter":
+			args = quote(args)
+		case "message":
+			role, message, _ := strings.Cut(cmd.Args, ": ")
+			args = role + " " + quote(message)
+		default:
+			name = "parameter"
+			args = cmd.Name + " " + quote(cmd.Args)
+		}
+
+		fmt.Fprintln(&sb, strings.ToUpper(name), args)
+	}
+
+	return sb.String()
 }
 
-func Parse(reader io.Reader) ([]Command, error) {
-	var commands []Command
-	var command, modelCommand Command
-
-	scanner := bufio.NewScanner(reader)
-	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
-	scanner.Split(scanModelfile)
-	for scanner.Scan() {
-		line := scanner.Bytes()
-
-		fields := bytes.SplitN(line, []byte(" "), 2)
-		if len(fields) == 0 || len(fields[0]) == 0 {
-			continue
-		}
-
-		switch string(bytes.ToUpper(fields[0])) {
-		case "FROM":
-			command.Name = "model"
-			command.Args = string(bytes.TrimSpace(fields[1]))
-			// copy command for validation
-			modelCommand = command
-		case "ADAPTER":
-			command.Name = string(bytes.ToLower(fields[0]))
-			command.Args = string(bytes.TrimSpace(fields[1]))
-		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
-			command.Name = string(bytes.ToLower(fields[0]))
-			command.Args = string(fields[1])
-		case "PARAMETER":
-			fields = bytes.SplitN(fields[1], []byte(" "), 2)
-			if len(fields) < 2 {
-				return nil, fmt.Errorf("missing value for %s", fields)
-			}
+func Parse(r io.Reader) (cmds []Command, err error) {
+	var cmd Command
+	var curr state
+	var b bytes.Buffer
+	var role string
 
-			command.Name = string(fields[0])
-			command.Args = string(bytes.TrimSpace(fields[1]))
-		case "EMBED":
-			return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
-		case "MESSAGE":
-			command.Name = string(bytes.ToLower(fields[0]))
-			fields = bytes.SplitN(fields[1], []byte(" "), 2)
-			if len(fields) < 2 {
-				return nil, fmt.Errorf("should be in the format <role> <message>")
-			}
-			if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
-				return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
+	br := bufio.NewReader(r)
+	for {
+		r, _, err := br.ReadRune()
+		if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
+			return nil, err
+		}
+
+		next, r, err := parseRuneForState(r, curr)
+		if errors.Is(err, io.ErrUnexpectedEOF) {
+			return nil, fmt.Errorf("%w: %s", err, b.String())
+		} else if err != nil {
+			return nil, err
+		}
+
+		// process the state transition, some transitions need to be intercepted and redirected
+		if next != curr {
+			switch curr {
+			case stateName:
+				if !isValidCommand(b.String()) {
+					return nil, errInvalidCommand
+				}
+
+				// next state sometimes depends on the current buffer value
+				switch s := strings.ToLower(b.String()); s {
+				case "from":
+					cmd.Name = "model"
+				case "parameter":
+					// transition to stateParameter which sets command name
+					next = stateParameter
+				case "message":
+					// transition to stateMessage which validates the message role
+					next = stateMessage
+					fallthrough
+				default:
+					cmd.Name = s
+				}
+			case stateParameter:
+				cmd.Name = b.String()
+			case stateMessage:
+				if !isValidMessageRole(b.String()) {
+					return nil, errInvalidMessageRole
+				}
+
+				role = b.String()
+			case stateComment, stateNil:
+				// pass
+			case stateValue:
+				s, ok := unquote(b.String())
+				if !ok || isSpace(r) {
+					if _, err := b.WriteRune(r); err != nil {
+						return nil, err
+					}
+
+					continue
+				}
+
+				if role != "" {
+					s = role + ": " + s
+					role = ""
+				}
+
+				cmd.Args = s
+				cmds = append(cmds, cmd)
 			}
-			command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
-		default:
-			if !bytes.HasPrefix(fields[0], []byte("#")) {
-				// log a warning for unknown commands
-				slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
+
+			b.Reset()
+			curr = next
+		}
+
+		if strconv.IsPrint(r) {
+			if _, err := b.WriteRune(r); err != nil {
+				return nil, err
 			}
-			continue
+		}
+	}
+
+	// flush the buffer
+	switch curr {
+	case stateComment, stateNil:
+		// pass; nothing to flush
+	case stateValue:
+		s, ok := unquote(b.String())
+		if !ok {
+			return nil, io.ErrUnexpectedEOF
 		}
 
-		commands = append(commands, command)
-		command.Reset()
+		if role != "" {
+			s = role + ": " + s
+		}
+
+		cmd.Args = s
+		cmds = append(cmds, cmd)
+	default:
+		return nil, io.ErrUnexpectedEOF
 	}
 
-	if modelCommand.Args == "" {
-		return nil, errors.New("no FROM line for the model was specified")
+	for _, cmd := range cmds {
+		if cmd.Name == "model" {
+			return cmds, nil
+		}
 	}
 
-	return commands, scanner.Err()
+	return nil, errMissingFrom
 }
 
-func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
-	advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
-	if err != nil {
-		return 0, nil, err
+func parseRuneForState(r rune, cs state) (state, rune, error) {
+	switch cs {
+	case stateNil:
+		switch {
+		case r == '#':
+			return stateComment, 0, nil
+		case isSpace(r), isNewline(r):
+			return stateNil, 0, nil
+		default:
+			return stateName, r, nil
+		}
+	case stateName:
+		switch {
+		case isAlpha(r):
+			return stateName, r, nil
+		case isSpace(r):
+			return stateValue, 0, nil
+		default:
+			return stateNil, 0, errInvalidCommand
+		}
+	case stateValue:
+		switch {
+		case isNewline(r):
+			return stateNil, r, nil
+		case isSpace(r):
+			return stateNil, r, nil
+		default:
+			return stateValue, r, nil
+		}
+	case stateParameter:
+		switch {
+		case isAlpha(r), isNumber(r), r == '_':
+			return stateParameter, r, nil
+		case isSpace(r):
+			return stateValue, 0, nil
+		default:
+			return stateNil, 0, io.ErrUnexpectedEOF
+		}
+	case stateMessage:
+		switch {
+		case isAlpha(r):
+			return stateMessage, r, nil
+		case isSpace(r):
+			return stateValue, 0, nil
+		default:
+			return stateNil, 0, io.ErrUnexpectedEOF
+		}
+	case stateComment:
+		switch {
+		case isNewline(r):
+			return stateNil, 0, nil
+		default:
+			return stateComment, 0, nil
+		}
+	default:
+		return stateNil, 0, errors.New("")
+	}
+}
+
+func quote(s string) string {
+	if strings.Contains(s, "\n") || strings.HasPrefix(s, " ") || strings.HasSuffix(s, " ") {
+		if strings.Contains(s, "\"") {
+			return `"""` + s + `"""`
+		}
+
+		return `"` + s + `"`
 	}
 
-	if advance > 0 && token != nil {
-		return advance, token, nil
+	return s
+}
+
+func unquote(s string) (string, bool) {
+	if len(s) == 0 {
+		return "", false
 	}
 
-	advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
-	if err != nil {
-		return 0, nil, err
+	// TODO: single quotes
+	if len(s) >= 3 && s[:3] == `"""` {
+		if len(s) >= 6 && s[len(s)-3:] == `"""` {
+			return s[3 : len(s)-3], true
+		}
+
+		return "", false
 	}
 
-	if advance > 0 && token != nil {
-		return advance, token, nil
+	if len(s) >= 1 && s[0] == '"' {
+		if len(s) >= 2 && s[len(s)-1] == '"' {
+			return s[1 : len(s)-1], true
+		}
+
+		return "", false
 	}
 
-	return bufio.ScanLines(data, atEOF)
+	return s, true
 }
 
-func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
-	newline := bytes.IndexByte(data, '\n')
+func isAlpha(r rune) bool {
+	return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z'
+}
 
-	if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
-		end := bytes.Index(data[start+len(openBytes):], closeBytes)
-		if end < 0 {
-			if atEOF {
-				return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
-			} else {
-				return 0, nil, nil
-			}
-		}
+func isNumber(r rune) bool {
+	return r >= '0' && r <= '9'
+}
 
-		n := start + len(openBytes) + end + len(closeBytes)
+func isSpace(r rune) bool {
+	return r == ' ' || r == '\t'
+}
 
-		newData := data[:start]
-		newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
-		return n, newData, nil
-	}
+func isNewline(r rune) bool {
+	return r == '\r' || r == '\n'
+}
+
+func isValidMessageRole(role string) bool {
+	return role == "system" || role == "user" || role == "assistant"
+}
 
-	return 0, nil, nil
+func isValidCommand(cmd string) bool {
+	switch strings.ToLower(cmd) {
+	case "from", "license", "template", "system", "adapter", "parameter", "message":
+		return true
+	default:
+		return false
+	}
 }

+ 437 - 34
parser/parser_test.go

@@ -1,14 +1,16 @@
 package parser
 
 import (
+	"bytes"
+	"fmt"
+	"io"
 	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 )
 
-func Test_Parser(t *testing.T) {
-
+func TestParser(t *testing.T) {
 	input := `
 FROM model1
 ADAPTER adapter1
@@ -35,64 +37,465 @@ TEMPLATE template1
 	assert.Equal(t, expectedCommands, commands)
 }
 
-func Test_Parser_NoFromLine(t *testing.T) {
+func TestParserFrom(t *testing.T) {
+	var cases = []struct {
+		input    string
+		expected []Command
+		err      error
+	}{
+		{
+			"FROM foo",
+			[]Command{{Name: "model", Args: "foo"}},
+			nil,
+		},
+		{
+			"FROM /path/to/model",
+			[]Command{{Name: "model", Args: "/path/to/model"}},
+			nil,
+		},
+		{
+			"FROM /path/to/model/fp16.bin",
+			[]Command{{Name: "model", Args: "/path/to/model/fp16.bin"}},
+			nil,
+		},
+		{
+			"FROM llama3:latest",
+			[]Command{{Name: "model", Args: "llama3:latest"}},
+			nil,
+		},
+		{
+			"FROM llama3:7b-instruct-q4_K_M",
+			[]Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}},
+			nil,
+		},
+		{
+			"", nil, errMissingFrom,
+		},
+		{
+			"PARAMETER param1 value1",
+			nil,
+			errMissingFrom,
+		},
+		{
+			"PARAMETER param1 value1\nFROM foo",
+			[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
+			nil,
+		},
+	}
+
+	for _, c := range cases {
+		t.Run("", func(t *testing.T) {
+			commands, err := Parse(strings.NewReader(c.input))
+			assert.ErrorIs(t, err, c.err)
+			assert.Equal(t, c.expected, commands)
+		})
+	}
+}
 
+func TestParserParametersMissingValue(t *testing.T) {
 	input := `
-PARAMETER param1 value1
-PARAMETER param2 value2
+FROM foo
+PARAMETER param1
 `
 
 	reader := strings.NewReader(input)
 
 	_, err := Parse(reader)
-	assert.ErrorContains(t, err, "no FROM line")
+	assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
 }
 
-func Test_Parser_MissingValue(t *testing.T) {
-
+func TestParserBadCommand(t *testing.T) {
 	input := `
 FROM foo
-PARAMETER param1
+BADCOMMAND param1 value1
 `
-
-	reader := strings.NewReader(input)
-
-	_, err := Parse(reader)
-	assert.ErrorContains(t, err, "missing value for [param1]")
+	_, err := Parse(strings.NewReader(input))
+	assert.ErrorIs(t, err, errInvalidCommand)
 
 }
 
-func Test_Parser_Messages(t *testing.T) {
-
-	input := `
+func TestParserMessages(t *testing.T) {
+	var cases = []struct {
+		input    string
+		expected []Command
+		err      error
+	}{
+		{
+			`
+FROM foo
+MESSAGE system You are a Parser. Always Parse things.
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "message", Args: "system: You are a Parser. Always Parse things."},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+MESSAGE system You are a Parser. Always Parse things.`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "message", Args: "system: You are a Parser. Always Parse things."},
+			},
+			nil,
+		},
+		{
+			`
 FROM foo
 MESSAGE system You are a Parser. Always Parse things.
 MESSAGE user Hey there!
 MESSAGE assistant Hello, I want to parse all the things!
-`
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "message", Args: "system: You are a Parser. Always Parse things."},
+				{Name: "message", Args: "user: Hey there!"},
+				{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+MESSAGE system """
+You are a multiline Parser. Always Parse things.
+"""
+			`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+MESSAGE badguy I'm a bad guy!
+`,
+			nil,
+			errInvalidMessageRole,
+		},
+		{
+			`
+FROM foo
+MESSAGE system
+`,
+			nil,
+			io.ErrUnexpectedEOF,
+		},
+		{
+			`
+FROM foo
+MESSAGE system`,
+			nil,
+			io.ErrUnexpectedEOF,
+		},
+	}
 
-	reader := strings.NewReader(input)
-	commands, err := Parse(reader)
-	assert.Nil(t, err)
+	for _, c := range cases {
+		t.Run("", func(t *testing.T) {
+			commands, err := Parse(strings.NewReader(c.input))
+			assert.ErrorIs(t, err, c.err)
+			assert.Equal(t, c.expected, commands)
+		})
+	}
+}
 
-	expectedCommands := []Command{
-		{Name: "model", Args: "foo"},
-		{Name: "message", Args: "system: You are a Parser. Always Parse things."},
-		{Name: "message", Args: "user: Hey there!"},
-		{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
+func TestParserQuoted(t *testing.T) {
+	var cases = []struct {
+		multiline string
+		expected  []Command
+		err       error
+	}{
+		{
+			`
+FROM foo
+SYSTEM """
+This is a
+multiline system.
+"""
+			`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "system", Args: "\nThis is a\nmultiline system.\n"},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+SYSTEM """
+This is a
+multiline system."""
+			`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "system", Args: "\nThis is a\nmultiline system."},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+SYSTEM """This is a
+multiline system."""
+			`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "system", Args: "This is a\nmultiline system."},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+SYSTEM """This is a multiline system."""
+			`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "system", Args: "This is a multiline system."},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+SYSTEM """This is a multiline system.""
+			`,
+			nil,
+			io.ErrUnexpectedEOF,
+		},
+		{
+			`
+FROM foo
+SYSTEM "
+			`,
+			nil,
+			io.ErrUnexpectedEOF,
+		},
+		{
+			`
+FROM foo
+SYSTEM """
+This is a multiline system with "quotes".
+"""
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "system", Args: "\nThis is a multiline system with \"quotes\".\n"},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+SYSTEM """"""
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "system", Args: ""},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+SYSTEM ""
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "system", Args: ""},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+SYSTEM "'"
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "system", Args: "'"},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+SYSTEM """''"'""'""'"'''''""'""'"""
+`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "system", Args: `''"'""'""'"'''''""'""'`},
+			},
+			nil,
+		},
+		{
+			`
+FROM foo
+TEMPLATE """
+{{ .Prompt }}
+"""`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+				{Name: "template", Args: "\n{{ .Prompt }}\n"},
+			},
+			nil,
+		},
 	}
 
-	assert.Equal(t, expectedCommands, commands)
+	for _, c := range cases {
+		t.Run("", func(t *testing.T) {
+			commands, err := Parse(strings.NewReader(c.multiline))
+			assert.ErrorIs(t, err, c.err)
+			assert.Equal(t, c.expected, commands)
+		})
+	}
 }
 
-func Test_Parser_Messages_BadRole(t *testing.T) {
+func TestParserParameters(t *testing.T) {
+	var cases = map[string]struct {
+		name, value string
+	}{
+		"numa true":                    {"numa", "true"},
+		"num_ctx 1":                    {"num_ctx", "1"},
+		"num_batch 1":                  {"num_batch", "1"},
+		"num_gqa 1":                    {"num_gqa", "1"},
+		"num_gpu 1":                    {"num_gpu", "1"},
+		"main_gpu 1":                   {"main_gpu", "1"},
+		"low_vram true":                {"low_vram", "true"},
+		"f16_kv true":                  {"f16_kv", "true"},
+		"logits_all true":              {"logits_all", "true"},
+		"vocab_only true":              {"vocab_only", "true"},
+		"use_mmap true":                {"use_mmap", "true"},
+		"use_mlock true":               {"use_mlock", "true"},
+		"num_thread 1":                 {"num_thread", "1"},
+		"num_keep 1":                   {"num_keep", "1"},
+		"seed 1":                       {"seed", "1"},
+		"num_predict 1":                {"num_predict", "1"},
+		"top_k 1":                      {"top_k", "1"},
+		"top_p 1.0":                    {"top_p", "1.0"},
+		"tfs_z 1.0":                    {"tfs_z", "1.0"},
+		"typical_p 1.0":                {"typical_p", "1.0"},
+		"repeat_last_n 1":              {"repeat_last_n", "1"},
+		"temperature 1.0":              {"temperature", "1.0"},
+		"repeat_penalty 1.0":           {"repeat_penalty", "1.0"},
+		"presence_penalty 1.0":         {"presence_penalty", "1.0"},
+		"frequency_penalty 1.0":        {"frequency_penalty", "1.0"},
+		"mirostat 1":                   {"mirostat", "1"},
+		"mirostat_tau 1.0":             {"mirostat_tau", "1.0"},
+		"mirostat_eta 1.0":             {"mirostat_eta", "1.0"},
+		"penalize_newline true":        {"penalize_newline", "true"},
+		"stop ### User:":               {"stop", "### User:"},
+		"stop ### User: ":              {"stop", "### User: "},
+		"stop \"### User:\"":           {"stop", "### User:"},
+		"stop \"### User: \"":          {"stop", "### User: "},
+		"stop \"\"\"### User:\"\"\"":   {"stop", "### User:"},
+		"stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"},
+		"stop <|endoftext|>":           {"stop", "<|endoftext|>"},
+		"stop <|eot_id|>":              {"stop", "<|eot_id|>"},
+		"stop </s>":                    {"stop", "</s>"},
+	}
 
-	input := `
+	for k, v := range cases {
+		t.Run(k, func(t *testing.T) {
+			var b bytes.Buffer
+			fmt.Fprintln(&b, "FROM foo")
+			fmt.Fprintln(&b, "PARAMETER", k)
+			commands, err := Parse(&b)
+			assert.Nil(t, err)
+
+			assert.Equal(t, []Command{
+				{Name: "model", Args: "foo"},
+				{Name: v.name, Args: v.value},
+			}, commands)
+		})
+	}
+}
+
+func TestParserComments(t *testing.T) {
+	var cases = []struct {
+		input    string
+		expected []Command
+	}{
+		{
+			`
+# comment
 FROM foo
-MESSAGE badguy I'm a bad guy!
-`
+	`,
+			[]Command{
+				{Name: "model", Args: "foo"},
+			},
+		},
+	}
+
+	for _, c := range cases {
+		t.Run("", func(t *testing.T) {
+			commands, err := Parse(strings.NewReader(c.input))
+			assert.Nil(t, err)
+			assert.Equal(t, c.expected, commands)
+		})
+	}
+}
+
+func TestParseFormatParse(t *testing.T) {
+	var cases = []string{
+		`
+FROM foo
+ADAPTER adapter1
+LICENSE MIT
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE template1
+MESSAGE system You are a Parser. Always Parse things.
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`,
+		`
+FROM foo
+ADAPTER adapter1
+LICENSE MIT
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE template1
+MESSAGE system """
+You are a store greeter. Always responsed with "Hello!".
+"""
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`,
+		`
+FROM foo
+ADAPTER adapter1
+LICENSE """
+Very long and boring legal text.
+Blah blah blah.
+"Oh look, a quote!"
+"""
+
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE template1
+MESSAGE system """
+You are a store greeter. Always responsed with "Hello!".
+"""
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`,
+	}
+
+	for _, c := range cases {
+		t.Run("", func(t *testing.T) {
+			commands, err := Parse(strings.NewReader(c))
+			assert.NoError(t, err)
+
+			commands2, err := Parse(strings.NewReader(Format(commands)))
+			assert.NoError(t, err)
+
+			assert.Equal(t, commands, commands2)
+		})
+	}
 
-	reader := strings.NewReader(input)
-	_, err := Parse(reader)
-	assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
 }

+ 42 - 62
server/images.go

@@ -21,7 +21,6 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
-	"text/template"
 
 	"golang.org/x/exp/slices"
 
@@ -64,6 +63,48 @@ func (m *Model) IsEmbedding() bool {
 	return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
 }
 
+func (m *Model) Commands() (cmds []parser.Command) {
+	cmds = append(cmds, parser.Command{Name: "model", Args: m.ModelPath})
+
+	if m.Template != "" {
+		cmds = append(cmds, parser.Command{Name: "template", Args: m.Template})
+	}
+
+	if m.System != "" {
+		cmds = append(cmds, parser.Command{Name: "system", Args: m.System})
+	}
+
+	for _, adapter := range m.AdapterPaths {
+		cmds = append(cmds, parser.Command{Name: "adapter", Args: adapter})
+	}
+
+	for _, projector := range m.ProjectorPaths {
+		cmds = append(cmds, parser.Command{Name: "projector", Args: projector})
+	}
+
+	for k, v := range m.Options {
+		switch v := v.(type) {
+		case []any:
+			for _, s := range v {
+				cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", s)})
+			}
+		default:
+			cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", v)})
+		}
+	}
+
+	for _, license := range m.License {
+		cmds = append(cmds, parser.Command{Name: "license", Args: license})
+	}
+
+	for _, msg := range m.Messages {
+		cmds = append(cmds, parser.Command{Name: "message", Args: fmt.Sprintf("%s %s", msg.Role, msg.Content)})
+	}
+
+	return cmds
+
+}
+
 type Message struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
@@ -901,67 +942,6 @@ func DeleteModel(name string) error {
 	return nil
 }
 
-func ShowModelfile(model *Model) (string, error) {
-	var mt struct {
-		*Model
-		From       string
-		Parameters map[string][]any
-	}
-
-	mt.Parameters = make(map[string][]any)
-	for k, v := range model.Options {
-		if s, ok := v.([]any); ok {
-			mt.Parameters[k] = s
-			continue
-		}
-
-		mt.Parameters[k] = []any{v}
-	}
-
-	mt.Model = model
-	mt.From = model.ModelPath
-
-	if model.ParentModel != "" {
-		mt.From = model.ParentModel
-	}
-
-	modelFile := `# Modelfile generated by "ollama show"
-# To build a new Modelfile based on this one, replace the FROM line with:
-# FROM {{ .ShortName }}
-
-FROM {{ .From }}
-TEMPLATE """{{ .Template }}"""
-
-{{- if .System }}
-SYSTEM """{{ .System }}"""
-{{- end }}
-
-{{- range $adapter := .AdapterPaths }}
-ADAPTER {{ $adapter }}
-{{- end }}
-
-{{- range $k, $v := .Parameters }}
-{{- range $parameter := $v }}
-PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
-{{- end }}
-{{- end }}`
-
-	tmpl, err := template.New("").Parse(modelFile)
-	if err != nil {
-		slog.Info(fmt.Sprintf("error parsing template: %q", err))
-		return "", err
-	}
-
-	var buf bytes.Buffer
-
-	if err = tmpl.Execute(&buf, mt); err != nil {
-		slog.Info(fmt.Sprintf("error executing template: %q", err))
-		return "", err
-	}
-
-	return buf.String(), nil
-}
-
 func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 	fn(api.ProgressResponse{Status: "retrieving manifest"})

+ 6 - 6
server/routes.go

@@ -728,12 +728,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 		}
 	}
 
-	mf, err := ShowModelfile(model)
-	if err != nil {
-		return nil, err
-	}
-
-	resp.Modelfile = mf
+	var sb strings.Builder
+	fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
+	fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
+	fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName)
+	fmt.Fprint(&sb, parser.Format(model.Commands()))
+	resp.Modelfile = sb.String()
 
 	return resp, nil
 }

+ 0 - 1
server/routes_test.go

@@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) {
 		if tc.Expected != nil {
 			tc.Expected(t, resp)
 		}
-
 	}
 }