浏览代码

revert cli to use /api/generate (#1383)

Patrick Devine 1 年之前
父节点
当前提交
bf704423c5
共有 1 个文件被更改,包括 119 次插入115 次删除
  1. 119 115
      cmd/cmd.go

+ 119 - 115
cmd/cmd.go

@@ -159,54 +159,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	interactive := true
-
-	opts := runOptions{
-		Model:    name,
-		WordWrap: os.Getenv("TERM") == "xterm-256color",
-		Options:  map[string]interface{}{},
-	}
-
-	format, err := cmd.Flags().GetString("format")
-	if err != nil {
-		return err
-	}
-	opts.Format = format
-
-	prompts := args[1:]
-
-	// prepend stdin to the prompt if provided
-	if !term.IsTerminal(int(os.Stdin.Fd())) {
-		in, err := io.ReadAll(os.Stdin)
-		if err != nil {
-			return err
-		}
-
-		prompts = append([]string{string(in)}, prompts...)
-		opts.WordWrap = false
-		interactive = false
-	}
-	msg := api.Message{
-		Role:    "user",
-		Content: strings.Join(prompts, " "),
-	}
-	opts.Messages = append(opts.Messages, msg)
-	if len(prompts) > 0 {
-		interactive = false
-	}
-
-	nowrap, err := cmd.Flags().GetBool("nowordwrap")
-	if err != nil {
-		return err
-	}
-	opts.WordWrap = !nowrap
-
-	if !interactive {
-		_, err := chat(cmd, opts)
-		return err
-	}
-
-	return chatInteractive(cmd, opts)
+	return RunGenerate(cmd, args)
 }
 
 func PushHandler(cmd *cobra.Command, args []string) error {
@@ -458,26 +411,83 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-type runOptions struct {
+func RunGenerate(cmd *cobra.Command, args []string) error {
+	interactive := true
+
+	opts := generateOptions{
+		Model:    args[0],
+		WordWrap: os.Getenv("TERM") == "xterm-256color",
+		Options:  map[string]interface{}{},
+	}
+
+	format, err := cmd.Flags().GetString("format")
+	if err != nil {
+		return err
+	}
+	opts.Format = format
+
+	prompts := args[1:]
+
+	// prepend stdin to the prompt if provided
+	if !term.IsTerminal(int(os.Stdin.Fd())) {
+		in, err := io.ReadAll(os.Stdin)
+		if err != nil {
+			return err
+		}
+
+		prompts = append([]string{string(in)}, prompts...)
+		opts.WordWrap = false
+		interactive = false
+	}
+	opts.Prompt = strings.Join(prompts, " ")
+	if len(prompts) > 0 {
+		interactive = false
+	}
+
+	nowrap, err := cmd.Flags().GetBool("nowordwrap")
+	if err != nil {
+		return err
+	}
+	opts.WordWrap = !nowrap
+
+	if !interactive {
+		return generate(cmd, opts)
+	}
+
+	return generateInteractive(cmd, opts)
+}
+
+type generateContextKey string
+
+type generateOptions struct {
 	Model    string
-	Messages []api.Message
+	Prompt   string
 	WordWrap bool
 	Format   string
+	System   string
 	Template string
 	Options  map[string]interface{}
 }
 
-func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
+func generate(cmd *cobra.Command, opts generateOptions) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	p := progress.NewProgress(os.Stderr)
 	defer p.StopAndClear()
+
 	spinner := progress.NewSpinner("")
 	p.Add("", spinner)
 
+	var latest api.GenerateResponse
+
+	generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
+	if !ok {
+		generateContext = []int{}
+	}
+
 	termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
 	if err != nil {
 		opts.WordWrap = false
@@ -496,24 +506,24 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 
 	var currentLineLength int
 	var wordBuffer string
-	var latest api.ChatResponse
-	var fullResponse strings.Builder
-	var role string
 
-	fn := func(response api.ChatResponse) error {
+	request := api.GenerateRequest{
+		Model:    opts.Model,
+		Prompt:   opts.Prompt,
+		Context:  generateContext,
+		Format:   opts.Format,
+		System:   opts.System,
+		Template: opts.Template,
+		Options:  opts.Options,
+	}
+	fn := func(response api.GenerateResponse) error {
 		p.StopAndClear()
+
 		latest = response
-		if response.Message == nil {
-			// warm-up response or done
-			return nil
-		}
-		role = response.Message.Role
-		content := response.Message.Content
-		fullResponse.WriteString(content)
 
 		termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
 		if opts.WordWrap && termWidth >= 10 {
-			for _, ch := range content {
+			for _, ch := range response.Response {
 				if currentLineLength+1 > termWidth-5 {
 					if len(wordBuffer) > termWidth-10 {
 						fmt.Printf("%s%c", wordBuffer, ch)
@@ -541,7 +551,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 				}
 			}
 		} else {
-			fmt.Printf("%s%s", wordBuffer, content)
+			fmt.Printf("%s%s", wordBuffer, response.Response)
 			if len(wordBuffer) > 0 {
 				wordBuffer = ""
 			}
@@ -550,35 +560,35 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		return nil
 	}
 
-	req := &api.ChatRequest{
-		Model:    opts.Model,
-		Messages: opts.Messages,
-		Format:   opts.Format,
-		Template: opts.Template,
-		Options:  opts.Options,
-	}
-	if err := client.Chat(cancelCtx, req, fn); err != nil {
+	if err := client.Generate(cancelCtx, &request, fn); err != nil {
 		if errors.Is(err, context.Canceled) {
-			return nil, nil
+			return nil
 		}
-		return nil, err
+		return err
 	}
-
-	if len(opts.Messages) > 0 {
+	if opts.Prompt != "" {
 		fmt.Println()
 		fmt.Println()
 	}
 
+	if !latest.Done {
+		return nil
+	}
+
 	verbose, err := cmd.Flags().GetBool("verbose")
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	if verbose {
 		latest.Summary()
 	}
 
-	return &api.Message{Role: role, Content: fullResponse.String()}, nil
+	ctx := cmd.Context()
+	ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
+	cmd.SetContext(ctx)
+
+	return nil
 }
 
 type MultilineState int
@@ -590,10 +600,13 @@ const (
 	MultilineTemplate
 )
 
-func chatInteractive(cmd *cobra.Command, opts runOptions) error {
+func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 	// load the model
-	loadOpts := runOptions{Model: opts.Model}
-	if _, err := chat(cmd, loadOpts); err != nil {
+	loadOpts := generateOptions{
+		Model:  opts.Model,
+		Prompt: "",
+	}
+	if err := generate(cmd, loadOpts); err != nil {
 		return err
 	}
 
@@ -664,9 +677,7 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error {
 	defer fmt.Printf(readline.EndBracketedPaste)
 
 	var multiline MultilineState
-	var content string
-	var systemContent string
-	opts.Messages = make([]api.Message, 0)
+	var prompt string
 
 	for {
 		line, err := scanner.Readline()
@@ -680,7 +691,7 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error {
 			}
 
 			scanner.Prompt.UseAlt = false
-			content = ""
+			prompt = ""
 
 			continue
 		case err != nil:
@@ -688,37 +699,37 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error {
 		}
 
 		switch {
-		case strings.HasPrefix(content, `"""`):
+		case strings.HasPrefix(prompt, `"""`):
 			// if the prompt so far starts with """ then we're in multiline mode
 			// and we need to keep reading until we find a line that ends with """
 			cut, found := strings.CutSuffix(line, `"""`)
-			content += cut + "\n"
+			prompt += cut + "\n"
 
 			if !found {
 				continue
 			}
 
-			content = strings.TrimPrefix(content, `"""`)
+			prompt = strings.TrimPrefix(prompt, `"""`)
 			scanner.Prompt.UseAlt = false
 
 			switch multiline {
 			case MultilineSystem:
-				systemContent = content
-				content = ""
+				opts.System = prompt
+				prompt = ""
 				fmt.Println("Set system template.\n")
 			case MultilineTemplate:
-				opts.Template = content
-				content = ""
+				opts.Template = prompt
+				prompt = ""
 				fmt.Println("Set model template.\n")
 			}
 			multiline = MultilineNone
-		case strings.HasPrefix(line, `"""`) && len(content) == 0:
+		case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
 			scanner.Prompt.UseAlt = true
 			multiline = MultilinePrompt
-			content += line + "\n"
+			prompt += line + "\n"
 			continue
 		case scanner.Pasting:
-			content += line + "\n"
+			prompt += line + "\n"
 			continue
 		case strings.HasPrefix(line, "/list"):
 			args := strings.Fields(line)
@@ -780,17 +791,17 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error {
 					line = strings.TrimPrefix(line, `"""`)
 					if strings.HasPrefix(args[2], `"""`) {
 						cut, found := strings.CutSuffix(line, `"""`)
-						content += cut + "\n"
+						prompt += cut + "\n"
 						if found {
-							systemContent = content
+							opts.System = prompt
 							if args[1] == "system" {
 								fmt.Println("Set system template.\n")
 							} else {
 								fmt.Println("Set prompt template.\n")
 							}
-							content = ""
+							prompt = ""
 						} else {
-							content = `"""` + content
+							prompt = `"""` + prompt
 							if args[1] == "system" {
 								multiline = MultilineSystem
 							} else {
@@ -799,7 +810,7 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error {
 							scanner.Prompt.UseAlt = true
 						}
 					} else {
-						systemContent = line
+						opts.System = line
 						fmt.Println("Set system template.\n")
 					}
 				default:
@@ -847,8 +858,8 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error {
 					}
 				case "system":
 					switch {
-					case systemContent != "":
-						fmt.Println(systemContent + "\n")
+					case opts.System != "":
+						fmt.Println(opts.System + "\n")
 					case resp.System != "":
 						fmt.Println(resp.System + "\n")
 					default:
@@ -888,23 +899,16 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error {
 			fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
 			continue
 		default:
-			content += line
+			prompt += line
 		}
 
-		if len(content) > 0 && multiline == MultilineNone {
-			if systemContent != "" {
-				opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: systemContent})
-			}
-			opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: content})
-			assistant, err := chat(cmd, opts)
-			if err != nil {
+		if len(prompt) > 0 && multiline == MultilineNone {
+			opts.Prompt = prompt
+			if err := generate(cmd, opts); err != nil {
 				return err
 			}
-			if assistant != nil {
-				opts.Messages = append(opts.Messages, *assistant)
-			}
 
-			content = ""
+			prompt = ""
 		}
 	}
 }