Browse Source

Convert the REPL to use /api/chat for interactive responses (#1936)

Patrick Devine 1 year ago
parent
commit
565f8a3c44
2 changed files with 155 additions and 72 deletions
  1. 124 54
      cmd/cmd.go
  2. 31 18
      cmd/interactive.go

+ 124 - 54
cmd/cmd.go

@@ -35,8 +35,6 @@ import (
 	"github.com/jmorganca/ollama/version"
 	"github.com/jmorganca/ollama/version"
 )
 )
 
 
-type ImageData []byte
-
 func CreateHandler(cmd *cobra.Command, args []string) error {
 func CreateHandler(cmd *cobra.Command, args []string) error {
 	filename, _ := cmd.Flags().GetString("file")
 	filename, _ := cmd.Flags().GetString("file")
 	filename, err := filepath.Abs(filename)
 	filename, err := filepath.Abs(filename)
@@ -415,11 +413,10 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 func RunGenerate(cmd *cobra.Command, args []string) error {
 func RunGenerate(cmd *cobra.Command, args []string) error {
 	interactive := true
 	interactive := true
 
 
-	opts := generateOptions{
+	opts := runOptions{
 		Model:    args[0],
 		Model:    args[0],
 		WordWrap: os.Getenv("TERM") == "xterm-256color",
 		WordWrap: os.Getenv("TERM") == "xterm-256color",
 		Options:  map[string]interface{}{},
 		Options:  map[string]interface{}{},
-		Images:   []ImageData{},
 	}
 	}
 
 
 	format, err := cmd.Flags().GetString("format")
 	format, err := cmd.Flags().GetString("format")
@@ -460,18 +457,135 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 
 
 type generateContextKey string
 type generateContextKey string
 
 
-type generateOptions struct {
+type runOptions struct {
 	Model    string
 	Model    string
 	Prompt   string
 	Prompt   string
+	Messages []api.Message
 	WordWrap bool
 	WordWrap bool
 	Format   string
 	Format   string
 	System   string
 	System   string
 	Template string
 	Template string
-	Images   []ImageData
+	Images   []api.ImageData
 	Options  map[string]interface{}
 	Options  map[string]interface{}
 }
 }
 
 
-func generate(cmd *cobra.Command, opts generateOptions) error {
+type displayResponseState struct {
+	lineLength int
+	wordBuffer string
+}
+
+func displayResponse(content string, wordWrap bool, state *displayResponseState) {
+	termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
+	if wordWrap && termWidth >= 10 {
+		for _, ch := range content {
+			if state.lineLength+1 > termWidth-5 {
+				if len(state.wordBuffer) > termWidth-10 {
+					fmt.Printf("%s%c", state.wordBuffer, ch)
+					state.wordBuffer = ""
+					state.lineLength = 0
+					continue
+				}
+
+				// backtrack the length of the last word and clear to the end of the line
+				fmt.Printf("\x1b[%dD\x1b[K\n", len(state.wordBuffer))
+				fmt.Printf("%s%c", state.wordBuffer, ch)
+				state.lineLength = len(state.wordBuffer) + 1
+			} else {
+				fmt.Print(string(ch))
+				state.lineLength += 1
+
+				switch ch {
+				case ' ':
+					state.wordBuffer = ""
+				case '\n':
+					state.lineLength = 0
+				default:
+					state.wordBuffer += string(ch)
+				}
+			}
+		}
+	} else {
+		fmt.Printf("%s%s", state.wordBuffer, content)
+		if len(state.wordBuffer) > 0 {
+			state.wordBuffer = ""
+		}
+	}
+}
+
+func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return nil, err
+	}
+
+	p := progress.NewProgress(os.Stderr)
+	defer p.StopAndClear()
+
+	spinner := progress.NewSpinner("")
+	p.Add("", spinner)
+
+	cancelCtx, cancel := context.WithCancel(cmd.Context())
+	defer cancel()
+
+	sigChan := make(chan os.Signal, 1)
+	signal.Notify(sigChan, syscall.SIGINT)
+
+	go func() {
+		<-sigChan
+		cancel()
+	}()
+
+	var state *displayResponseState = &displayResponseState{}
+	var latest api.ChatResponse
+	var fullResponse strings.Builder
+	var role string
+
+	fn := func(response api.ChatResponse) error {
+		p.StopAndClear()
+
+		latest = response
+
+		role = response.Message.Role
+		content := response.Message.Content
+		fullResponse.WriteString(content)
+
+		displayResponse(content, opts.WordWrap, state)
+
+		return nil
+	}
+
+	req := &api.ChatRequest{
+		Model:    opts.Model,
+		Messages: opts.Messages,
+		Format:   opts.Format,
+		Options:  opts.Options,
+	}
+
+	if err := client.Chat(cancelCtx, req, fn); err != nil {
+		if errors.Is(err, context.Canceled) {
+			return nil, nil
+		}
+		return nil, err
+	}
+
+	if len(opts.Messages) > 0 {
+		fmt.Println()
+		fmt.Println()
+	}
+
+	verbose, err := cmd.Flags().GetBool("verbose")
+	if err != nil {
+		return nil, err
+	}
+
+	if verbose {
+		latest.Summary()
+	}
+
+	return &api.Message{Role: role, Content: fullResponse.String()}, nil
+}
+
+func generate(cmd *cobra.Command, opts runOptions) error {
 	client, err := api.ClientFromEnvironment()
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -490,11 +604,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 		generateContext = []int{}
 		generateContext = []int{}
 	}
 	}
 
 
-	termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
-	if err != nil {
-		opts.WordWrap = false
-	}
-
 	ctx, cancel := context.WithCancel(cmd.Context())
 	ctx, cancel := context.WithCancel(cmd.Context())
 	defer cancel()
 	defer cancel()
 
 
@@ -506,57 +615,19 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 		cancel()
 		cancel()
 	}()
 	}()
 
 
-	var currentLineLength int
-	var wordBuffer string
+	var state *displayResponseState = &displayResponseState{}
 
 
 	fn := func(response api.GenerateResponse) error {
 	fn := func(response api.GenerateResponse) error {
 		p.StopAndClear()
 		p.StopAndClear()
 
 
 		latest = response
 		latest = response
+		content := response.Response
 
 
-		termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
-		if opts.WordWrap && termWidth >= 10 {
-			for _, ch := range response.Response {
-				if currentLineLength+1 > termWidth-5 {
-					if len(wordBuffer) > termWidth-10 {
-						fmt.Printf("%s%c", wordBuffer, ch)
-						wordBuffer = ""
-						currentLineLength = 0
-						continue
-					}
-
-					// backtrack the length of the last word and clear to the end of the line
-					fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
-					fmt.Printf("%s%c", wordBuffer, ch)
-					currentLineLength = len(wordBuffer) + 1
-				} else {
-					fmt.Print(string(ch))
-					currentLineLength += 1
-
-					switch ch {
-					case ' ':
-						wordBuffer = ""
-					case '\n':
-						currentLineLength = 0
-					default:
-						wordBuffer += string(ch)
-					}
-				}
-			}
-		} else {
-			fmt.Printf("%s%s", wordBuffer, response.Response)
-			if len(wordBuffer) > 0 {
-				wordBuffer = ""
-			}
-		}
+		displayResponse(content, opts.WordWrap, state)
 
 
 		return nil
 		return nil
 	}
 	}
 
 
-	images := make([]api.ImageData, 0)
-	for _, i := range opts.Images {
-		images = append(images, api.ImageData(i))
-	}
 	request := api.GenerateRequest{
 	request := api.GenerateRequest{
 		Model:    opts.Model,
 		Model:    opts.Model,
 		Prompt:   opts.Prompt,
 		Prompt:   opts.Prompt,
@@ -565,7 +636,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 		System:   opts.System,
 		System:   opts.System,
 		Template: opts.Template,
 		Template: opts.Template,
 		Options:  opts.Options,
 		Options:  opts.Options,
-		Images:   images,
 	}
 	}
 
 
 	if err := client.Generate(ctx, &request, fn); err != nil {
 	if err := client.Generate(ctx, &request, fn); err != nil {

+ 31 - 18
cmd/interactive.go

@@ -1,7 +1,6 @@
 package cmd
 package cmd
 
 
 import (
 import (
-	"context"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
@@ -43,16 +42,16 @@ func modelIsMultiModal(cmd *cobra.Command, name string) bool {
 	return slices.Contains(resp.Details.Families, "clip")
 	return slices.Contains(resp.Details.Families, "clip")
 }
 }
 
 
-func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
+func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 	multiModal := modelIsMultiModal(cmd, opts.Model)
 	multiModal := modelIsMultiModal(cmd, opts.Model)
 
 
 	// load the model
 	// load the model
-	loadOpts := generateOptions{
-		Model:  opts.Model,
-		Prompt: "",
-		Images: []ImageData{},
+	loadOpts := runOptions{
+		Model:    opts.Model,
+		Prompt:   "",
+		Messages: []api.Message{},
 	}
 	}
-	if err := generate(cmd, loadOpts); err != nil {
+	if _, err := chat(cmd, loadOpts); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -141,6 +140,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 
 
 	var sb strings.Builder
 	var sb strings.Builder
 	var multiline MultilineState
 	var multiline MultilineState
+	opts.Messages = make([]api.Message, 0)
 
 
 	for {
 	for {
 		line, err := scanner.Readline()
 		line, err := scanner.Readline()
@@ -409,22 +409,26 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 		}
 		}
 
 
 		if sb.Len() > 0 && multiline == MultilineNone {
 		if sb.Len() > 0 && multiline == MultilineNone {
-			opts.Prompt = sb.String()
+			newMessage := api.Message{Role: "user", Content: sb.String()}
+
 			if multiModal {
 			if multiModal {
-				newPrompt, images, err := extractFileData(sb.String())
+				msg, images, err := extractFileData(sb.String())
 				if err != nil {
 				if err != nil {
 					return err
 					return err
 				}
 				}
-				opts.Prompt = newPrompt
+				newMessage.Content = msg
 
 
 				// reset the context if we find another image
 				// reset the context if we find another image
 				if len(images) > 0 {
 				if len(images) > 0 {
-					opts.Images = images
-					ctx := cmd.Context()
-					ctx = context.WithValue(ctx, generateContextKey("context"), []int{})
-					cmd.SetContext(ctx)
+					newMessage.Images = append(newMessage.Images, images...)
+					// reset the context for the new image
+					opts.Messages = []api.Message{}
+				} else {
+					if len(opts.Messages) > 1 {
+						newMessage.Images = append(newMessage.Images, opts.Messages[len(opts.Messages)-2].Images...)
+					}
 				}
 				}
-				if len(opts.Images) == 0 {
+				if len(newMessage.Images) == 0 {
 					fmt.Println("This model requires you to add a jpeg, png, or svg image.")
 					fmt.Println("This model requires you to add a jpeg, png, or svg image.")
 					fmt.Println()
 					fmt.Println()
 					sb.Reset()
 					sb.Reset()
@@ -432,9 +436,18 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 				}
 				}
 			}
 			}
 
 
-			if err := generate(cmd, opts); err != nil {
+			if opts.System != "" {
+				opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
+			}
+			opts.Messages = append(opts.Messages, newMessage)
+
+			assistant, err := chat(cmd, opts)
+			if err != nil {
 				return err
 				return err
 			}
 			}
+			if assistant != nil {
+				opts.Messages = append(opts.Messages, *assistant)
+			}
 
 
 			sb.Reset()
 			sb.Reset()
 		}
 		}
@@ -476,9 +489,9 @@ func extractFileNames(input string) []string {
 	return re.FindAllString(input, -1)
 	return re.FindAllString(input, -1)
 }
 }
 
 
-func extractFileData(input string) (string, []ImageData, error) {
+func extractFileData(input string) (string, []api.ImageData, error) {
 	filePaths := extractFileNames(input)
 	filePaths := extractFileNames(input)
-	var imgs []ImageData
+	var imgs []api.ImageData
 
 
 	for _, fp := range filePaths {
 	for _, fp := range filePaths {
 		nfp := normalizeFilePath(fp)
 		nfp := normalizeFilePath(fp)