Browse Source

Merge pull request #4059 from ollama/mxyng/parser-2

rename parser to model/file
Michael Yang 1 year ago
parent
commit
b7a87a22b6
6 changed files with 151 additions and 124 deletions
  1. 9 10
      cmd/cmd.go
  2. 42 15
      server/images.go
  3. 13 24
      server/routes.go
  4. 4 4
      server/routes_test.go
  5. 38 32
      types/model/file.go
  6. 45 39
      types/model/file_test.go

+ 9 - 10
cmd/cmd.go

@@ -34,7 +34,6 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/format"
-	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/server"
 	"github.com/ollama/ollama/types/errtypes"
@@ -57,13 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	p := progress.NewProgress(os.Stderr)
 	defer p.Stop()
 
-	modelfile, err := os.Open(filename)
+	f, err := os.Open(filename)
 	if err != nil {
 		return err
 	}
-	defer modelfile.Close()
+	defer f.Close()
 
-	commands, err := parser.Parse(modelfile)
+	modelfile, err := model.ParseFile(f)
 	if err != nil {
 		return err
 	}
@@ -77,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	spinner := progress.NewSpinner(status)
 	p.Add(status, spinner)
 
-	for i := range commands {
-		switch commands[i].Name {
+	for i := range modelfile.Commands {
+		switch modelfile.Commands[i].Name {
 		case "model", "adapter":
-			path := commands[i].Args
+			path := modelfile.Commands[i].Args
 			if path == "~" {
 				path = home
 			} else if strings.HasPrefix(path, "~/") {
@@ -92,7 +91,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			}
 
 			fi, err := os.Stat(path)
-			if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" {
+			if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
 				continue
 			} else if err != nil {
 				return err
@@ -115,7 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				return err
 			}
 
-			commands[i].Args = "@"+digest
+			modelfile.Commands[i].Args = "@" + digest
 		}
 	}
 
@@ -145,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 
 	quantization, _ := cmd.Flags().GetString("quantization")
 
-	request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization}
+	request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantization: quantization}
 	if err := client.Create(cmd.Context(), &request, fn); err != nil {
 		return err
 	}

+ 42 - 15
server/images.go

@@ -29,7 +29,6 @@ import (
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
-	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -63,46 +62,74 @@ func (m *Model) IsEmbedding() bool {
 	return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
 }
 
-func (m *Model) Commands() (cmds []parser.Command) {
-	cmds = append(cmds, parser.Command{Name: "model", Args: m.ModelPath})
+func (m *Model) String() string {
+	var modelfile model.File
+
+	modelfile.Commands = append(modelfile.Commands, model.Command{
+		Name: "model",
+		Args: m.ModelPath,
+	})
 
 	if m.Template != "" {
-		cmds = append(cmds, parser.Command{Name: "template", Args: m.Template})
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "template",
+			Args: m.Template,
+		})
 	}
 
 	if m.System != "" {
-		cmds = append(cmds, parser.Command{Name: "system", Args: m.System})
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "system",
+			Args: m.System,
+		})
 	}
 
 	for _, adapter := range m.AdapterPaths {
-		cmds = append(cmds, parser.Command{Name: "adapter", Args: adapter})
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "adapter",
+			Args: adapter,
+		})
 	}
 
 	for _, projector := range m.ProjectorPaths {
-		cmds = append(cmds, parser.Command{Name: "projector", Args: projector})
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "projector",
+			Args: projector,
+		})
 	}
 
 	for k, v := range m.Options {
 		switch v := v.(type) {
 		case []any:
 			for _, s := range v {
-				cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", s)})
+				modelfile.Commands = append(modelfile.Commands, model.Command{
+					Name: k,
+					Args: fmt.Sprintf("%v", s),
+				})
 			}
 		default:
-			cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", v)})
+			modelfile.Commands = append(modelfile.Commands, model.Command{
+				Name: k,
+				Args: fmt.Sprintf("%v", v),
+			})
 		}
 	}
 
 	for _, license := range m.License {
-		cmds = append(cmds, parser.Command{Name: "license", Args: license})
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "license",
+			Args: license,
+		})
 	}
 
 	for _, msg := range m.Messages {
-		cmds = append(cmds, parser.Command{Name: "message", Args: fmt.Sprintf("%s %s", msg.Role, msg.Content)})
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "message",
+			Args: fmt.Sprintf("%s %s", msg.Role, msg.Content),
+		})
 	}
 
-	return cmds
-
+	return modelfile.String()
 }
 
 type Message struct {
@@ -329,7 +356,7 @@ func realpath(mfDir, from string) string {
 	return abspath
 }
 
-func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
+func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) error {
 	deleteMap := make(map[string]struct{})
 	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
 		for _, layer := range append(manifest.Layers, manifest.Config) {
@@ -351,7 +378,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
 	params := make(map[string][]string)
 	fromParams := make(map[string]any)
 
-	for _, c := range commands {
+	for _, c := range modelfile.Commands {
 		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
 
 		switch c.Name {

+ 13 - 24
server/routes.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"cmp"
 	"context"
 	"encoding/json"
 	"errors"
@@ -28,7 +29,6 @@ import (
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
-	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
@@ -522,28 +522,17 @@ func (s *Server) PushModelHandler(c *gin.Context) {
 
 func (s *Server) CreateModelHandler(c *gin.Context) {
 	var req api.CreateRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	var model string
-	if req.Model != "" {
-		model = req.Model
-	} else if req.Name != "" {
-		model = req.Name
-	} else {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	}
-
-	if err := ParseModelPath(model).Validate(); err != nil {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	name := model.ParseName(cmp.Or(req.Model, req.Name))
+	if !name.IsValid() {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
 		return
 	}
 
@@ -552,19 +541,19 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 		return
 	}
 
-	var modelfile io.Reader = strings.NewReader(req.Modelfile)
+	var r io.Reader = strings.NewReader(req.Modelfile)
 	if req.Path != "" && req.Modelfile == "" {
-		mf, err := os.Open(req.Path)
+		f, err := os.Open(req.Path)
 		if err != nil {
 			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
 			return
 		}
-		defer mf.Close()
+		defer f.Close()
 
-		modelfile = mf
+		r = f
 	}
 
-	commands, err := parser.Parse(modelfile)
+	modelfile, err := model.ParseFile(r)
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
@@ -580,7 +569,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil {
+		if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), req.Quantization, modelfile, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -732,7 +721,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
 	fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
 	fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName)
-	fmt.Fprint(&sb, parser.Format(model.Commands()))
+	fmt.Fprint(&sb, model.String())
 	resp.Modelfile = sb.String()
 
 	return resp, nil

+ 4 - 4
server/routes_test.go

@@ -17,7 +17,7 @@ import (
 	"github.com/stretchr/testify/assert"
 
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
 
@@ -55,13 +55,13 @@ func Test_Routes(t *testing.T) {
 	createTestModel := func(t *testing.T, name string) {
 		fname := createTestFile(t, "ollama-model")
 
-		modelfile := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
-		commands, err := parser.Parse(modelfile)
+		r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
+		modelfile, err := model.ParseFile(r)
 		assert.Nil(t, err)
 		fn := func(resp api.ProgressResponse) {
 			t.Logf("Status: %s", resp.Status)
 		}
-		err = CreateModel(context.TODO(), name, "", "", commands, fn)
+		err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
 		assert.Nil(t, err)
 	}
 

+ 38 - 32
parser/parser.go → types/model/file.go

@@ -1,4 +1,4 @@
-package parser
+package model
 
 import (
 	"bufio"
@@ -10,11 +10,41 @@ import (
 	"strings"
 )
 
+type File struct {
+	Commands []Command
+}
+
+func (f File) String() string {
+	var sb strings.Builder
+	for _, cmd := range f.Commands {
+		fmt.Fprintln(&sb, cmd.String())
+	}
+
+	return sb.String()
+}
+
 type Command struct {
 	Name string
 	Args string
 }
 
+func (c Command) String() string {
+	var sb strings.Builder
+	switch c.Name {
+	case "model":
+		fmt.Fprintf(&sb, "FROM %s", c.Args)
+	case "license", "template", "system", "adapter":
+		fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
+	case "message":
+		role, message, _ := strings.Cut(c.Args, ": ")
+		fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message))
+	default:
+		fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args))
+	}
+
+	return sb.String()
+}
+
 type state int
 
 const (
@@ -32,38 +62,14 @@ var (
 	errInvalidCommand     = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
 )
 
-func Format(cmds []Command) string {
-	var sb strings.Builder
-	for _, cmd := range cmds {
-		name := cmd.Name
-		args := cmd.Args
-
-		switch cmd.Name {
-		case "model":
-			name = "from"
-			args = cmd.Args
-		case "license", "template", "system", "adapter":
-			args = quote(args)
-		case "message":
-			role, message, _ := strings.Cut(cmd.Args, ": ")
-			args = role + " " + quote(message)
-		default:
-			name = "parameter"
-			args = cmd.Name + " " + quote(cmd.Args)
-		}
-
-		fmt.Fprintln(&sb, strings.ToUpper(name), args)
-	}
-
-	return sb.String()
-}
-
-func Parse(r io.Reader) (cmds []Command, err error) {
+func ParseFile(r io.Reader) (*File, error) {
 	var cmd Command
 	var curr state
 	var b bytes.Buffer
 	var role string
 
+	var f File
+
 	br := bufio.NewReader(r)
 	for {
 		r, _, err := br.ReadRune()
@@ -128,7 +134,7 @@ func Parse(r io.Reader) (cmds []Command, err error) {
 				}
 
 				cmd.Args = s
-				cmds = append(cmds, cmd)
+				f.Commands = append(f.Commands, cmd)
 			}
 
 			b.Reset()
@@ -157,14 +163,14 @@ func Parse(r io.Reader) (cmds []Command, err error) {
 		}
 
 		cmd.Args = s
-		cmds = append(cmds, cmd)
+		f.Commands = append(f.Commands, cmd)
 	default:
 		return nil, io.ErrUnexpectedEOF
 	}
 
-	for _, cmd := range cmds {
+	for _, cmd := range f.Commands {
 		if cmd.Name == "model" {
-			return cmds, nil
+			return &f, nil
 		}
 	}
 

+ 45 - 39
parser/parser_test.go → types/model/file_test.go

@@ -1,4 +1,4 @@
-package parser
+package model
 
 import (
 	"bytes"
@@ -10,7 +10,7 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-func TestParser(t *testing.T) {
+func TestParseFileFile(t *testing.T) {
 	input := `
 FROM model1
 ADAPTER adapter1
@@ -22,8 +22,8 @@ TEMPLATE template1
 
 	reader := strings.NewReader(input)
 
-	commands, err := Parse(reader)
-	assert.Nil(t, err)
+	modelfile, err := ParseFile(reader)
+	assert.NoError(t, err)
 
 	expectedCommands := []Command{
 		{Name: "model", Args: "model1"},
@@ -34,10 +34,10 @@ TEMPLATE template1
 		{Name: "template", Args: "template1"},
 	}
 
-	assert.Equal(t, expectedCommands, commands)
+	assert.Equal(t, expectedCommands, modelfile.Commands)
 }
 
-func TestParserFrom(t *testing.T) {
+func TestParseFileFrom(t *testing.T) {
 	var cases = []struct {
 		input    string
 		expected []Command
@@ -85,14 +85,16 @@ func TestParserFrom(t *testing.T) {
 
 	for _, c := range cases {
 		t.Run("", func(t *testing.T) {
-			commands, err := Parse(strings.NewReader(c.input))
+			modelfile, err := ParseFile(strings.NewReader(c.input))
 			assert.ErrorIs(t, err, c.err)
-			assert.Equal(t, c.expected, commands)
+			if modelfile != nil {
+				assert.Equal(t, c.expected, modelfile.Commands)
+			}
 		})
 	}
 }
 
-func TestParserParametersMissingValue(t *testing.T) {
+func TestParseFileParametersMissingValue(t *testing.T) {
 	input := `
 FROM foo
 PARAMETER param1
@@ -100,21 +102,21 @@ PARAMETER param1
 
 	reader := strings.NewReader(input)
 
-	_, err := Parse(reader)
+	_, err := ParseFile(reader)
 	assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
 }
 
-func TestParserBadCommand(t *testing.T) {
+func TestParseFileBadCommand(t *testing.T) {
 	input := `
 FROM foo
 BADCOMMAND param1 value1
 `
-	_, err := Parse(strings.NewReader(input))
+	_, err := ParseFile(strings.NewReader(input))
 	assert.ErrorIs(t, err, errInvalidCommand)
 
 }
 
-func TestParserMessages(t *testing.T) {
+func TestParseFileMessages(t *testing.T) {
 	var cases = []struct {
 		input    string
 		expected []Command
@@ -123,34 +125,34 @@ func TestParserMessages(t *testing.T) {
 		{
 			`
 FROM foo
-MESSAGE system You are a Parser. Always Parse things.
+MESSAGE system You are a file parser. Always parse things.
 `,
 			[]Command{
 				{Name: "model", Args: "foo"},
-				{Name: "message", Args: "system: You are a Parser. Always Parse things."},
+				{Name: "message", Args: "system: You are a file parser. Always parse things."},
 			},
 			nil,
 		},
 		{
 			`
 FROM foo
-MESSAGE system You are a Parser. Always Parse things.`,
+MESSAGE system You are a file parser. Always parse things.`,
 			[]Command{
 				{Name: "model", Args: "foo"},
-				{Name: "message", Args: "system: You are a Parser. Always Parse things."},
+				{Name: "message", Args: "system: You are a file parser. Always parse things."},
 			},
 			nil,
 		},
 		{
 			`
 FROM foo
-MESSAGE system You are a Parser. Always Parse things.
+MESSAGE system You are a file parser. Always parse things.
 MESSAGE user Hey there!
 MESSAGE assistant Hello, I want to parse all the things!
 `,
 			[]Command{
 				{Name: "model", Args: "foo"},
-				{Name: "message", Args: "system: You are a Parser. Always Parse things."},
+				{Name: "message", Args: "system: You are a file parser. Always parse things."},
 				{Name: "message", Args: "user: Hey there!"},
 				{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
 			},
@@ -160,12 +162,12 @@ MESSAGE assistant Hello, I want to parse all the things!
 			`
 FROM foo
 MESSAGE system """
-You are a multiline Parser. Always Parse things.
+You are a multiline file parser. Always parse things.
 """
 			`,
 			[]Command{
 				{Name: "model", Args: "foo"},
-				{Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"},
+				{Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"},
 			},
 			nil,
 		},
@@ -196,14 +198,16 @@ MESSAGE system`,
 
 	for _, c := range cases {
 		t.Run("", func(t *testing.T) {
-			commands, err := Parse(strings.NewReader(c.input))
+			modelfile, err := ParseFile(strings.NewReader(c.input))
 			assert.ErrorIs(t, err, c.err)
-			assert.Equal(t, c.expected, commands)
+			if modelfile != nil {
+				assert.Equal(t, c.expected, modelfile.Commands)
+			}
 		})
 	}
 }
 
-func TestParserQuoted(t *testing.T) {
+func TestParseFileQuoted(t *testing.T) {
 	var cases = []struct {
 		multiline string
 		expected  []Command
@@ -348,14 +352,16 @@ TEMPLATE """
 
 	for _, c := range cases {
 		t.Run("", func(t *testing.T) {
-			commands, err := Parse(strings.NewReader(c.multiline))
+			modelfile, err := ParseFile(strings.NewReader(c.multiline))
 			assert.ErrorIs(t, err, c.err)
-			assert.Equal(t, c.expected, commands)
+			if modelfile != nil {
+				assert.Equal(t, c.expected, modelfile.Commands)
+			}
 		})
 	}
 }
 
-func TestParserParameters(t *testing.T) {
+func TestParseFileParameters(t *testing.T) {
 	var cases = map[string]struct {
 		name, value string
 	}{
@@ -404,18 +410,18 @@ func TestParserParameters(t *testing.T) {
 			var b bytes.Buffer
 			fmt.Fprintln(&b, "FROM foo")
 			fmt.Fprintln(&b, "PARAMETER", k)
-			commands, err := Parse(&b)
-			assert.Nil(t, err)
+			modelfile, err := ParseFile(&b)
+			assert.NoError(t, err)
 
 			assert.Equal(t, []Command{
 				{Name: "model", Args: "foo"},
 				{Name: v.name, Args: v.value},
-			}, commands)
+			}, modelfile.Commands)
 		})
 	}
 }
 
-func TestParserComments(t *testing.T) {
+func TestParseFileComments(t *testing.T) {
 	var cases = []struct {
 		input    string
 		expected []Command
@@ -433,14 +439,14 @@ FROM foo
 
 	for _, c := range cases {
 		t.Run("", func(t *testing.T) {
-			commands, err := Parse(strings.NewReader(c.input))
-			assert.Nil(t, err)
-			assert.Equal(t, c.expected, commands)
+			modelfile, err := ParseFile(strings.NewReader(c.input))
+			assert.NoError(t, err)
+			assert.Equal(t, c.expected, modelfile.Commands)
 		})
 	}
 }
 
-func TestParseFormatParse(t *testing.T) {
+func TestParseFileFormatParseFile(t *testing.T) {
 	var cases = []string{
 		`
 FROM foo
@@ -449,7 +455,7 @@ LICENSE MIT
 PARAMETER param1 value1
 PARAMETER param2 value2
 TEMPLATE template1
-MESSAGE system You are a Parser. Always Parse things.
+MESSAGE system You are a file parser. Always parse things.
 MESSAGE user Hey there!
 MESSAGE assistant Hello, I want to parse all the things!
 `,
@@ -488,13 +494,13 @@ MESSAGE assistant Hello, I want to parse all the things!
 
 	for _, c := range cases {
 		t.Run("", func(t *testing.T) {
-			commands, err := Parse(strings.NewReader(c))
+			modelfile, err := ParseFile(strings.NewReader(c))
 			assert.NoError(t, err)
 
-			commands2, err := Parse(strings.NewReader(Format(commands)))
+			modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
 			assert.NoError(t, err)
 
-			assert.Equal(t, commands, commands2)
+			assert.Equal(t, modelfile, modelfile2)
 		})
 	}