فهرست منبع

Merge pull request #1614 from jmorganca/mxyng/fix-set-template

fix: set template without triple quotes
Michael Yang 1 سال پیش
والد
کامیت
62023177f6
1فایلهای تغییر یافته به همراه67 افزوده شده و 55 حذف شده
  1. 67 55
      cmd/interactive.go

+ 67 - 55
cmd/interactive.go

@@ -139,8 +139,8 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 	fmt.Print(readline.StartBracketedPaste)
 	defer fmt.Printf(readline.EndBracketedPaste)
 
+	var sb strings.Builder
 	var multiline MultilineState
-	var prompt string
 
 	for {
 		line, err := scanner.Readline()
@@ -154,7 +154,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 			}
 
 			scanner.Prompt.UseAlt = false
-			prompt = ""
+			sb.Reset()
 
 			continue
 		case err != nil:
@@ -162,38 +162,41 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 		}
 
 		switch {
-		case strings.HasPrefix(prompt, `"""`):
-			// if the prompt so far starts with """ then we're in multiline mode
-			// and we need to keep reading until we find a line that ends with """
-			cut, found := strings.CutSuffix(line, `"""`)
-			prompt += cut
-
-			if !found {
-				prompt += "\n"
+		case multiline != MultilineNone:
+			// check if there's a multiline terminating string
+			before, ok := strings.CutSuffix(line, `"""`)
+			sb.WriteString(before)
+			if !ok {
+				fmt.Fprintln(&sb)
 				continue
 			}
 
-			prompt = strings.TrimPrefix(prompt, `"""`)
-			scanner.Prompt.UseAlt = false
-
 			switch multiline {
 			case MultilineSystem:
-				opts.System = prompt
-				prompt = ""
+				opts.System = sb.String()
 				fmt.Println("Set system message.")
+				sb.Reset()
 			case MultilineTemplate:
-				opts.Template = prompt
-				prompt = ""
+				opts.Template = sb.String()
 				fmt.Println("Set prompt template.")
+				sb.Reset()
 			}
+
 			multiline = MultilineNone
-		case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
-			scanner.Prompt.UseAlt = true
-			multiline = MultilinePrompt
-			prompt += line + "\n"
-			continue
+			scanner.Prompt.UseAlt = false
+		case strings.HasPrefix(line, `"""`):
+			line := strings.TrimPrefix(line, `"""`)
+			line, ok := strings.CutSuffix(line, `"""`)
+			sb.WriteString(line)
+			if !ok {
+				// no multiline terminating string; need more input
+				fmt.Fprintln(&sb)
+				multiline = MultilinePrompt
+				scanner.Prompt.UseAlt = true
+				break
+			}
 		case scanner.Pasting:
-			prompt += line + "\n"
+			fmt.Fprintln(&sb, line)
 			continue
 		case strings.HasPrefix(line, "/list"):
 			args := strings.Fields(line)
@@ -251,33 +254,41 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 						usageSet()
 						continue
 					}
+
+					if args[1] == "system" {
+						multiline = MultilineSystem
+					} else if args[1] == "template" {
+						multiline = MultilineTemplate
+					}
+
 					line := strings.Join(args[2:], " ")
-					line = strings.TrimPrefix(line, `"""`)
-					if strings.HasPrefix(args[2], `"""`) {
-						cut, found := strings.CutSuffix(line, `"""`)
-						prompt += cut
-						if found {
-							if args[1] == "system" {
-								opts.System = prompt
-								fmt.Println("Set system message.")
-							} else {
-								opts.Template = prompt
-								fmt.Println("Set prompt template.")
-							}
-							prompt = ""
-						} else {
-							prompt = `"""` + prompt + "\n"
-							if args[1] == "system" {
-								multiline = MultilineSystem
-							} else {
-								multiline = MultilineTemplate
-							}
-							scanner.Prompt.UseAlt = true
-						}
+					line, ok := strings.CutPrefix(line, `"""`)
+					if !ok {
+						multiline = MultilineNone
 					} else {
-						opts.System = line
+						// only cut suffix if the line is multiline
+						line, ok = strings.CutSuffix(line, `"""`)
+						if ok {
+							multiline = MultilineNone
+						}
+					}
+
+					sb.WriteString(line)
+					if multiline != MultilineNone {
+						scanner.Prompt.UseAlt = true
+						continue
+					}
+
+					if args[1] == "system" {
+						opts.System = sb.String()
 						fmt.Println("Set system message.")
+					} else if args[1] == "template" {
+						opts.Template = sb.String()
+						fmt.Println("Set prompt template.")
 					}
+
+					sb.Reset()
+					continue
 				default:
 					fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
 				}
@@ -390,20 +401,20 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 				}
 			}
 
-			if isFile {
-				prompt += line
-			} else {
+			if !isFile {
 				fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
 				continue
 			}
+
+			sb.WriteString(line)
 		default:
-			prompt += line
+			sb.WriteString(line)
 		}
 
-		if len(prompt) > 0 && multiline == MultilineNone {
-			opts.Prompt = prompt
+		if sb.Len() > 0 && multiline == MultilineNone {
+			opts.Prompt = sb.String()
 			if multiModal {
-				newPrompt, images, err := extractFileData(prompt)
+				newPrompt, images, err := extractFileData(sb.String())
 				if err != nil {
 					return err
 				}
@@ -419,15 +430,16 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 				if len(opts.Images) == 0 {
 					fmt.Println("This model requires you to add a jpeg, png, or svg image.")
 					fmt.Println()
-					prompt = ""
+					sb.Reset()
 					continue
 				}
 			}
+
 			if err := generate(cmd, opts); err != nil {
 				return err
 			}
 
-			prompt = ""
+			sb.Reset()
 		}
 	}
 }