Explorar o código

allow setting the system and template for prompts in the repl (#1335)

Patrick Devine hai 1 ano
pai
achega
6681d37861
Modificáronse 1 ficheiros con 87 adicións e 21 borrados
  1. 87 21
      cmd/cmd.go

+ 87 - 21
cmd/cmd.go

@@ -464,6 +464,8 @@ type generateOptions struct {
 	Prompt   string
 	WordWrap bool
 	Format   string
+	System   string
+	Template string
 	Options  map[string]interface{}
 }
 
@@ -506,11 +508,13 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 	var wordBuffer string
 
 	request := api.GenerateRequest{
-		Model:   opts.Model,
-		Prompt:  opts.Prompt,
-		Context: generateContext,
-		Format:  opts.Format,
-		Options: opts.Options,
+		Model:    opts.Model,
+		Prompt:   opts.Prompt,
+		Context:  generateContext,
+		Format:   opts.Format,
+		System:   opts.System,
+		Template: opts.Template,
+		Options:  opts.Options,
 	}
 	fn := func(response api.GenerateResponse) error {
 		p.StopAndClear()
@@ -576,6 +580,15 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 	return nil
 }
 
+type MultilineState int
+
+const (
+	MultilineNone MultilineState = iota
+	MultilinePrompt
+	MultilineSystem
+	MultilineTemplate
+)
+
 func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 	// load the model
 	loadOpts := generateOptions{
@@ -599,15 +612,17 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 
 	usageSet := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
-		fmt.Fprintln(os.Stderr, "  /set parameter    Set a parameter")
-		fmt.Fprintln(os.Stderr, "  /set history      Enable history")
-		fmt.Fprintln(os.Stderr, "  /set nohistory    Disable history")
-		fmt.Fprintln(os.Stderr, "  /set wordwrap     Enable wordwrap")
-		fmt.Fprintln(os.Stderr, "  /set nowordwrap   Disable wordwrap")
-		fmt.Fprintln(os.Stderr, "  /set format json  Enable JSON mode")
-		fmt.Fprintln(os.Stderr, "  /set noformat     Disable formatting")
-		fmt.Fprintln(os.Stderr, "  /set verbose      Show LLM stats")
-		fmt.Fprintln(os.Stderr, "  /set quiet        Disable LLM stats")
+		fmt.Fprintln(os.Stderr, "  /set parameter ...     Set a parameter")
+		fmt.Fprintln(os.Stderr, "  /set system <string>   Set system prompt")
+		fmt.Fprintln(os.Stderr, "  /set template <string> Set prompt template")
+		fmt.Fprintln(os.Stderr, "  /set history           Enable history")
+		fmt.Fprintln(os.Stderr, "  /set nohistory         Disable history")
+		fmt.Fprintln(os.Stderr, "  /set wordwrap          Enable wordwrap")
+		fmt.Fprintln(os.Stderr, "  /set nowordwrap        Disable wordwrap")
+		fmt.Fprintln(os.Stderr, "  /set format json       Enable JSON mode")
+		fmt.Fprintln(os.Stderr, "  /set noformat          Disable formatting")
+		fmt.Fprintln(os.Stderr, "  /set verbose           Show LLM stats")
+		fmt.Fprintln(os.Stderr, "  /set quiet             Disable LLM stats")
 		fmt.Fprintln(os.Stderr, "")
 	}
 
@@ -650,6 +665,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 	fmt.Print(readline.StartBracketedPaste)
 	defer fmt.Printf(readline.EndBracketedPaste)
 
+	var multiline MultilineState
 	var prompt string
 
 	for {
@@ -684,8 +700,21 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 
 			prompt = strings.TrimPrefix(prompt, `"""`)
 			scanner.Prompt.UseAlt = false
+
+			switch multiline {
+			case MultilineSystem:
+				opts.System = prompt
+				prompt = ""
+				fmt.Println("Set system template.\n")
+			case MultilineTemplate:
+				opts.Template = prompt
+				prompt = ""
+				fmt.Println("Set model template.\n")
+			}
+			multiline = MultilineNone
 		case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
 			scanner.Prompt.UseAlt = true
+			multiline = MultilinePrompt
 			prompt += line + "\n"
 			continue
 		case scanner.Pasting:
@@ -742,6 +771,37 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 					}
 					fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", "))
 					opts.Options[args[2]] = fp[args[2]]
+				case "system", "template":
+					if len(args) < 3 {
+						usageSet()
+						continue
+					}
+					line := strings.Join(args[2:], " ")
+					line = strings.TrimPrefix(line, `"""`)
+					if strings.HasPrefix(args[2], `"""`) {
+						cut, found := strings.CutSuffix(line, `"""`)
+						prompt += cut + "\n"
+						if found {
+							opts.System = prompt
+							if args[1] == "system" {
+								fmt.Println("Set system template.\n")
+							} else {
+								fmt.Println("Set prompt template.\n")
+							}
+							prompt = ""
+						} else {
+							prompt = `"""` + prompt
+							if args[1] == "system" {
+								multiline = MultilineSystem
+							} else {
+								multiline = MultilineTemplate
+							}
+							scanner.Prompt.UseAlt = true
+						}
+					} else {
+						opts.System = line
+						fmt.Println("Set system template.\n")
+					}
 				default:
 					fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
 				}
@@ -786,16 +846,22 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 						fmt.Println(resp.Parameters)
 					}
 				case "system":
-					if resp.System == "" {
+					switch {
+					case opts.System != "":
+						fmt.Println(opts.System + "\n")
+					case resp.System != "":
+						fmt.Println(resp.System + "\n")
+					default:
 						fmt.Print("No system prompt was specified for this model.\n\n")
-					} else {
-						fmt.Println(resp.System)
 					}
 				case "template":
-					if resp.Template == "" {
-						fmt.Print("No prompt template was specified for this model.\n\n")
-					} else {
+					switch {
+					case opts.Template != "":
+						fmt.Println(opts.Template + "\n")
+					case resp.Template != "":
 						fmt.Println(resp.Template)
+					default:
+						fmt.Print("No prompt template was specified for this model.\n\n")
 					}
 				default:
 					fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
@@ -825,7 +891,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 			prompt += line
 		}
 
-		if len(prompt) > 0 && prompt[0] != '/' {
+		if len(prompt) > 0 && multiline == MultilineNone {
 			opts.Prompt = prompt
 			if err := generate(cmd, opts); err != nil {
 				return err