Forráskód Böngészése

Save and load sessions (#2063)

Patrick Devine 1 éve
szülő
commit
7c40a67841
8 módosított fájl, 312 hozzáadás és 39 törlés
  1. 2 0
      api/types.go
  2. 11 9
      cmd/cmd.go
  3. 127 24
      cmd/interactive.go
  4. 65 0
      cmd/interactive_test.go
  5. 11 0
      parser/parser.go
  6. 35 0
      parser/parser_test.go
  7. 47 5
      server/images.go
  8. 14 1
      server/routes.go

+ 2 - 0
api/types.go

@@ -171,6 +171,7 @@ type ShowResponse struct {
 	Template   string       `json:"template,omitempty"`
 	System     string       `json:"system,omitempty"`
 	Details    ModelDetails `json:"details,omitempty"`
+	Messages   []Message    `json:"messages,omitempty"`
 }
 
 type CopyRequest struct {
@@ -236,6 +237,7 @@ type GenerateResponse struct {
 }
 
 type ModelDetails struct {
+	ParentModel       string   `json:"parent_model"`
 	Format            string   `json:"format"`
 	Family            string   `json:"family"`
 	Families          []string `json:"families"`

+ 11 - 9
cmd/cmd.go

@@ -458,15 +458,17 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 type generateContextKey string
 
 type runOptions struct {
-	Model    string
-	Prompt   string
-	Messages []api.Message
-	WordWrap bool
-	Format   string
-	System   string
-	Template string
-	Images   []api.ImageData
-	Options  map[string]interface{}
+	Model       string
+	ParentModel string
+	Prompt      string
+	Messages    []api.Message
+	WordWrap    bool
+	Format      string
+	System      string
+	Template    string
+	Images      []api.ImageData
+	Options     map[string]interface{}
+	MultiModal  bool
 }
 
 type displayResponseState struct {

+ 127 - 24
cmd/interactive.go

@@ -7,12 +7,14 @@ import (
 	"net/http"
 	"os"
 	"regexp"
+	"sort"
 	"strings"
 
 	"github.com/spf13/cobra"
 	"golang.org/x/exp/slices"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/progress"
 	"github.com/jmorganca/ollama/readline"
 )
 
@@ -25,43 +27,75 @@ const (
 	MultilineTemplate
 )
 
-func modelIsMultiModal(cmd *cobra.Command, name string) bool {
-	// get model details
+func loadModel(cmd *cobra.Command, opts *runOptions) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
-		fmt.Println("error: couldn't connect to ollama server")
-		return false
+		return err
 	}
 
-	req := api.ShowRequest{Name: name}
-	resp, err := client.Show(cmd.Context(), &req)
+	p := progress.NewProgress(os.Stderr)
+	defer p.StopAndClear()
+
+	spinner := progress.NewSpinner("")
+	p.Add("", spinner)
+
+	showReq := api.ShowRequest{Name: opts.Model}
+	showResp, err := client.Show(cmd.Context(), &showReq)
 	if err != nil {
-		return false
+		return err
 	}
+	opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
+	opts.ParentModel = showResp.Details.ParentModel
 
-	return slices.Contains(resp.Details.Families, "clip")
-}
-
-func generateInteractive(cmd *cobra.Command, opts runOptions) error {
-	multiModal := modelIsMultiModal(cmd, opts.Model)
+	if len(showResp.Messages) > 0 {
+		opts.Messages = append(opts.Messages, showResp.Messages...)
+	}
 
-	// load the model
-	loadOpts := runOptions{
+	chatReq := &api.ChatRequest{
 		Model:    opts.Model,
-		Prompt:   "",
 		Messages: []api.Message{},
 	}
-	if _, err := chat(cmd, loadOpts); err != nil {
+	err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
+		p.StopAndClear()
+		if len(opts.Messages) > 0 {
+			for _, msg := range opts.Messages {
+				switch msg.Role {
+				case "user":
+					fmt.Printf(">>> %s\n", msg.Content)
+				case "assistant":
+					state := &displayResponseState{}
+					displayResponse(msg.Content, opts.WordWrap, state)
+					fmt.Println()
+					fmt.Println()
+				}
+			}
+		}
+		return nil
+	})
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func generateInteractive(cmd *cobra.Command, opts runOptions) error {
+	opts.Messages = make([]api.Message, 0)
+
+	err := loadModel(cmd, &opts)
+	if err != nil {
 		return err
 	}
 
 	usage := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
-		fmt.Fprintln(os.Stderr, "  /set          Set session variables")
-		fmt.Fprintln(os.Stderr, "  /show         Show model information")
-		fmt.Fprintln(os.Stderr, "  /bye          Exit")
-		fmt.Fprintln(os.Stderr, "  /?, /help     Help for a command")
-		fmt.Fprintln(os.Stderr, "  /? shortcuts  Help for keyboard shortcuts")
+		fmt.Fprintln(os.Stderr, "  /set            Set session variables")
+		fmt.Fprintln(os.Stderr, "  /show           Show model information")
+		fmt.Fprintln(os.Stderr, "  /load <model>   Load a session or model")
+		fmt.Fprintln(os.Stderr, "  /save <model>   Save your current session")
+		fmt.Fprintln(os.Stderr, "  /bye            Exit")
+		fmt.Fprintln(os.Stderr, "  /?, /help       Help for a command")
+		fmt.Fprintln(os.Stderr, "  /? shortcuts    Help for keyboard shortcuts")
 		fmt.Fprintln(os.Stderr, "")
 		fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
 		fmt.Fprintln(os.Stderr, "")
@@ -140,7 +174,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 
 	var sb strings.Builder
 	var multiline MultilineState
-	opts.Messages = make([]api.Message, 0)
 
 	for {
 		line, err := scanner.Readline()
@@ -203,6 +236,44 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			if err := ListHandler(cmd, args[1:]); err != nil {
 				return err
 			}
+		case strings.HasPrefix(line, "/load"):
+			args := strings.Fields(line)
+			if len(args) != 2 {
+				fmt.Println("Usage:\n  /load <modelname>")
+				continue
+			}
+			opts.Model = args[1]
+			opts.Messages = []api.Message{}
+			fmt.Printf("Loading model '%s'\n", opts.Model)
+			if err := loadModel(cmd, &opts); err != nil {
+				return err
+			}
+			continue
+		case strings.HasPrefix(line, "/save"):
+			args := strings.Fields(line)
+			if len(args) != 2 {
+				fmt.Println("Usage:\n  /save <modelname>")
+				continue
+			}
+
+			client, err := api.ClientFromEnvironment()
+			if err != nil {
+				fmt.Println("error: couldn't connect to ollama server")
+				return err
+			}
+
+			req := &api.CreateRequest{
+				Name:      args[1],
+				Modelfile: buildModelfile(opts),
+			}
+			fn := func(resp api.ProgressResponse) error { return nil }
+			err = client.Create(cmd.Context(), req, fn)
+			if err != nil {
+				fmt.Println("error: couldn't save model")
+				return err
+			}
+			fmt.Printf("Created new model '%s'\n", args[1])
+			continue
 		case strings.HasPrefix(line, "/set"):
 			args := strings.Fields(line)
 			if len(args) > 1 {
@@ -389,7 +460,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			args := strings.Fields(line)
 			isFile := false
 
-			if multiModal {
+			if opts.MultiModal {
 				for _, f := range extractFileNames(line) {
 					if strings.HasPrefix(f, args[0]) {
 						isFile = true
@@ -411,7 +482,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		if sb.Len() > 0 && multiline == MultilineNone {
 			newMessage := api.Message{Role: "user", Content: sb.String()}
 
-			if multiModal {
+			if opts.MultiModal {
 				msg, images, err := extractFileData(sb.String())
 				if err != nil {
 					return err
@@ -454,6 +525,38 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 	}
 }
 
+func buildModelfile(opts runOptions) string {
+	var mf strings.Builder
+	model := opts.ParentModel
+	if model == "" {
+		model = opts.Model
+	}
+	fmt.Fprintf(&mf, "FROM %s\n", model)
+	if opts.System != "" {
+		fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
+	}
+
+	if opts.Template != "" {
+		fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
+	}
+
+	keys := make([]string, 0)
+	for k := range opts.Options {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	for _, k := range keys {
+		fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k])
+	}
+	fmt.Fprintln(&mf)
+
+	for _, msg := range opts.Messages {
+		fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content)
+	}
+
+	return mf.String()
+}
+
 func normalizeFilePath(fp string) string {
 	// Define a map of escaped characters and their replacements
 	replacements := map[string]string{

+ 65 - 0
cmd/interactive_test.go

@@ -1,9 +1,13 @@
 package cmd
 
 import (
+	"bytes"
 	"testing"
+	"text/template"
 
 	"github.com/stretchr/testify/assert"
+
+	"github.com/jmorganca/ollama/api"
 )
 
 func TestExtractFilenames(t *testing.T) {
@@ -49,3 +53,64 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
 	assert.Contains(t, res[9], "ten.svg")
 	assert.Contains(t, res[9], "E:")
 }
+
+func TestModelfileBuilder(t *testing.T) {
+	opts := runOptions{
+		Model:    "hork",
+		System:   "You are part horse and part shark, but all hork. Do horklike things",
+		Template: "This is a template.",
+		Messages: []api.Message{
+			{Role: "user", Content: "Hey there hork!"},
+			{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
+		},
+		Options: map[string]interface{}{},
+	}
+
+	opts.Options["temperature"] = 0.9
+	opts.Options["seed"] = 42
+	opts.Options["penalize_newline"] = false
+	opts.Options["stop"] = []string{"hi", "there"}
+
+	mf := buildModelfile(opts)
+	expectedModelfile := `FROM {{.Model}}
+SYSTEM """{{.System}}"""
+TEMPLATE """{{.Template}}"""
+PARAMETER penalize_newline false
+PARAMETER seed 42
+PARAMETER stop [hi there]
+PARAMETER temperature 0.9
+
+MESSAGE user """Hey there hork!"""
+MESSAGE assistant """Yes it is true, I am half horse, half shark."""
+`
+
+	tmpl, err := template.New("").Parse(expectedModelfile)
+	assert.Nil(t, err)
+
+	var buf bytes.Buffer
+	err = tmpl.Execute(&buf, opts)
+	assert.Nil(t, err)
+	assert.Equal(t, buf.String(), mf)
+
+	opts.ParentModel = "horseshark"
+	mf = buildModelfile(opts)
+	expectedModelfile = `FROM {{.ParentModel}}
+SYSTEM """{{.System}}"""
+TEMPLATE """{{.Template}}"""
+PARAMETER penalize_newline false
+PARAMETER seed 42
+PARAMETER stop [hi there]
+PARAMETER temperature 0.9
+
+MESSAGE user """Hey there hork!"""
+MESSAGE assistant """Yes it is true, I am half horse, half shark."""
+`
+
+	tmpl, err = template.New("").Parse(expectedModelfile)
+	assert.Nil(t, err)
+
+	var parentBuf bytes.Buffer
+	err = tmpl.Execute(&parentBuf, opts)
+	assert.Nil(t, err)
+	assert.Equal(t, parentBuf.String(), mf)
+}

+ 11 - 0
parser/parser.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"slices"
 )
 
 type Command struct {
@@ -56,6 +57,16 @@ func Parse(reader io.Reader) ([]Command, error) {
 			command.Args = string(bytes.TrimSpace(fields[1]))
 		case "EMBED":
 			return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
+		case "MESSAGE":
+			command.Name = string(bytes.ToLower(fields[0]))
+			fields = bytes.SplitN(fields[1], []byte(" "), 2)
+			if len(fields) < 2 {
+				return nil, fmt.Errorf("should be in the format <role> <message>")
+			}
+			if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
+				return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
+			}
+			command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
 		default:
 			if !bytes.HasPrefix(fields[0], []byte("#")) {
 				// log a warning for unknown commands

+ 35 - 0
parser/parser_test.go

@@ -61,3 +61,38 @@ PARAMETER param1
 	assert.ErrorContains(t, err, "missing value for [param1]")
 
 }
+
+func Test_Parser_Messages(t *testing.T) {
+
+	input := `
+FROM foo
+MESSAGE system You are a Parser. Always Parse things.
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`
+
+	reader := strings.NewReader(input)
+	commands, err := Parse(reader)
+	assert.Nil(t, err)
+
+	expectedCommands := []Command{
+		{Name: "model", Args: "foo"},
+		{Name: "message", Args: "system: You are a Parser. Always Parse things."},
+		{Name: "message", Args: "user: Hey there!"},
+		{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
+	}
+
+	assert.Equal(t, expectedCommands, commands)
+}
+
+func Test_Parser_Messages_BadRole(t *testing.T) {
+
+	input := `
+FROM foo
+MESSAGE badguy I'm a bad guy!
+`
+
+	reader := strings.NewReader(input)
+	_, err := Parse(reader)
+	assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
+}

+ 47 - 5
server/images.go

@@ -41,7 +41,7 @@ type Model struct {
 	Config         ConfigV2
 	ShortName      string
 	ModelPath      string
-	OriginalModel  string
+	ParentModel    string
 	AdapterPaths   []string
 	ProjectorPaths []string
 	Template       string
@@ -50,6 +50,12 @@ type Model struct {
 	Digest         string
 	Size           int64
 	Options        map[string]interface{}
+	Messages       []Message
+}
+
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
 }
 
 type PromptVars struct {
@@ -333,7 +339,7 @@ func GetModel(name string) (*Model, error) {
 		switch layer.MediaType {
 		case "application/vnd.ollama.image.model":
 			model.ModelPath = filename
-			model.OriginalModel = layer.From
+			model.ParentModel = layer.From
 		case "application/vnd.ollama.image.embed":
 			// Deprecated in versions  > 0.1.2
 			// TODO: remove this warning in a future version
@@ -374,6 +380,16 @@ func GetModel(name string) (*Model, error) {
 			if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
 				return nil, err
 			}
+		case "application/vnd.ollama.image.messages":
+			msgs, err := os.Open(filename)
+			if err != nil {
+				return nil, err
+			}
+			defer msgs.Close()
+
+			if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
+				return nil, err
+			}
 		case "application/vnd.ollama.image.license":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
@@ -428,12 +444,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 	}
 
 	var layers Layers
+	messages := []string{}
 
 	params := make(map[string][]string)
 	fromParams := make(map[string]any)
 
 	for _, c := range commands {
-		slog.Info(fmt.Sprintf("[%s] - %s", c.Name, c.Args))
 		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
 
 		switch c.Name {
@@ -607,11 +623,37 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 			}
 
 			layers.Replace(layer)
+		case "message":
+			messages = append(messages, c.Args)
 		default:
 			params[c.Name] = append(params[c.Name], c.Args)
 		}
 	}
 
+	if len(messages) > 0 {
+		fn(api.ProgressResponse{Status: "creating parameters layer"})
+
+		msgs := make([]api.Message, 0)
+
+		for _, m := range messages {
+			// todo: handle images
+			msg := strings.SplitN(m, ": ", 2)
+			msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(msgs); err != nil {
+			return err
+		}
+
+		layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
+		if err != nil {
+			return err
+		}
+
+		layers.Replace(layer)
+	}
+
 	if len(params) > 0 {
 		fn(api.ProgressResponse{Status: "creating parameters layer"})
 
@@ -908,8 +950,8 @@ func ShowModelfile(model *Model) (string, error) {
 	mt.Model = model
 	mt.From = model.ModelPath
 
-	if model.OriginalModel != "" {
-		mt.From = model.OriginalModel
+	if model.ParentModel != "" {
+		mt.From = model.ParentModel
 	}
 
 	modelFile := `# Modelfile generated by "ollama show"

+ 14 - 1
server/routes.go

@@ -659,6 +659,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	}
 
 	modelDetails := api.ModelDetails{
+		ParentModel:       model.ParentModel,
 		Format:            model.Config.ModelFormat,
 		Family:            model.Config.ModelFamily,
 		Families:          model.Config.ModelFamilies,
@@ -674,11 +675,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 		model.Template = req.Template
 	}
 
+	msgs := make([]api.Message, 0)
+	for _, msg := range model.Messages {
+		msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
+	}
+
 	resp := &api.ShowResponse{
 		License:  strings.Join(model.License, "\n"),
 		System:   model.System,
 		Template: model.Template,
 		Details:  modelDetails,
+		Messages: msgs,
 	}
 
 	var params []string
@@ -1075,7 +1082,13 @@ func ChatHandler(c *gin.Context) {
 
 	// an empty request loads the model
 	if len(req.Messages) == 0 {
-		c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}})
+		resp := api.ChatResponse{
+			CreatedAt: time.Now().UTC(),
+			Model:     req.Model,
+			Done:      true,
+			Message:   api.Message{Role: "assistant"},
+		}
+		c.JSON(http.StatusOK, resp)
 		return
 	}