Explorar o código

treat `ollama run model < file` as entire prompt, not prompt-per-line (#1126)

Previously, `ollama run` treated a non-terminal stdin (such as `ollama run model < file`) as containing one prompt per line. To run inference on a multi-line prompt, the only non-API workaround was to run `ollama run` interactively and wrap the prompt in `"""..."""`.

Now, `ollama run` treats a non-terminal stdin as containing a single prompt. For example, if `myprompt.txt` is a multi-line file, then `ollama run model < myprompt.txt` would treat `myprompt.txt`'s entire contents as the prompt.

Co-authored-by: Quinn Slack <quinn@slack.org>
Jeffrey Morgan hai 1 ano
pai
achega
423862042a
Modificáronse 1 ficheiros con 29 adicións e 53 borrados
  1. 29 53
      cmd/cmd.go

+ 29 - 53
cmd/cmd.go

@@ -1,7 +1,6 @@
 package cmd
 package cmd
 
 
 import (
 import (
-	"bufio"
 	"context"
 	"context"
 	"crypto/ed25519"
 	"crypto/ed25519"
 	"crypto/rand"
 	"crypto/rand"
@@ -350,34 +349,44 @@ func pull(model string, insecure bool) error {
 }
 }
 
 
 func RunGenerate(cmd *cobra.Command, args []string) error {
 func RunGenerate(cmd *cobra.Command, args []string) error {
-	if len(args) > 1 {
-		// join all args into a single prompt
-		wordWrap := false
-		if term.IsTerminal(int(os.Stdout.Fd())) {
-			wordWrap = true
-		}
+	format, err := cmd.Flags().GetString("format")
+	if err != nil {
+		return err
+	}
 
 
-		nowrap, err := cmd.Flags().GetBool("nowordwrap")
-		if err != nil {
-			return err
-		}
-		if nowrap {
-			wordWrap = false
-		}
+	prompts := args[1:]
 
 
-		format, err := cmd.Flags().GetString("format")
+	// prepend stdin to the prompt if provided
+	if !term.IsTerminal(int(os.Stdin.Fd())) {
+		in, err := io.ReadAll(os.Stdin)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
 
 
-		return generate(cmd, args[0], strings.Join(args[1:], " "), wordWrap, format)
+		prompts = append([]string{string(in)}, prompts...)
 	}
 	}
 
 
-	if readline.IsTerminal(int(os.Stdin.Fd())) {
-		return generateInteractive(cmd, args[0])
+	// output is being piped
+	if !term.IsTerminal(int(os.Stdout.Fd())) {
+		return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
 	}
 	}
 
 
-	return generateBatch(cmd, args[0])
+	wordWrap := os.Getenv("TERM") == "xterm-256color"
+
+	nowrap, err := cmd.Flags().GetBool("nowordwrap")
+	if err != nil {
+		return err
+	}
+	if nowrap {
+		wordWrap = false
+	}
+
+	// prompts are provided via stdin or args so don't enter interactive mode
+	if len(prompts) > 0 {
+		return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
+	}
+
+	return generateInteractive(cmd, args[0], wordWrap, format)
 }
 }
 
 
 type generateContextKey string
 type generateContextKey string
@@ -490,7 +499,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 	return nil
 	return nil
 }
 }
 
 
-func generateInteractive(cmd *cobra.Command, model string) error {
+func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
 	// load the model
 	// load the model
 	if err := generate(cmd, model, "", false, ""); err != nil {
 	if err := generate(cmd, model, "", false, ""); err != nil {
 		return err
 		return err
@@ -542,26 +551,6 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 		return err
 		return err
 	}
 	}
 
 
-	format, err := cmd.Flags().GetString("format")
-	if err != nil {
-		return err
-	}
-
-	var wordWrap bool
-	termType := os.Getenv("TERM")
-	if termType == "xterm-256color" {
-		wordWrap = true
-	}
-
-	// override wrapping if the user turned it off
-	nowrap, err := cmd.Flags().GetBool("nowordwrap")
-	if err != nil {
-		return err
-	}
-	if nowrap {
-		wordWrap = false
-	}
-
 	fmt.Print(readline.StartBracketedPaste)
 	fmt.Print(readline.StartBracketedPaste)
 	defer fmt.Printf(readline.EndBracketedPaste)
 	defer fmt.Printf(readline.EndBracketedPaste)
 
 
@@ -715,19 +704,6 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 	}
 	}
 }
 }
 
 
-func generateBatch(cmd *cobra.Command, model string) error {
-	scanner := bufio.NewScanner(os.Stdin)
-	for scanner.Scan() {
-		prompt := scanner.Text()
-		fmt.Printf(">>> %s\n", prompt)
-		if err := generate(cmd, model, prompt, false, ""); err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
 func RunServer(cmd *cobra.Command, _ []string) error {
 func RunServer(cmd *cobra.Command, _ []string) error {
 	host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
 	host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
 	if err != nil {
 	if err != nil {