ソースを参照

Cmd changes (#541)

Patrick Devine 1 年間 前
コミット
80dd44e80a
2 ファイル変更75 行追加52 行削除
  1. 69 50
      cmd/cmd.go
  2. 6 2
      server/routes.go

+ 69 - 50
cmd/cmd.go

@@ -33,6 +33,17 @@ import (
 	"github.com/jmorganca/ollama/version"
 )
 
+type Painter struct{}
+
+func (p Painter) Paint(line []rune, l int) []rune {
+	termType := os.Getenv("TERM")
+	if termType == "xterm-256color" && len(line) == 0 {
+		prompt := "Send a message (/? for help)"
+		return []rune(fmt.Sprintf("\033[38;5;245m%s\033[%dD\033[0m", prompt, len(prompt)))
+	}
+	return line
+}
+
 func CreateHandler(cmd *cobra.Command, args []string) error {
 	filename, _ := cmd.Flags().GetString("file")
 	filename, err := filepath.Abs(filename)
@@ -387,71 +398,71 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 type generateContextKey string
 
 func generate(cmd *cobra.Command, model, prompt string) error {
-	if len(strings.TrimSpace(prompt)) > 0 {
-		client, err := api.FromEnv()
-		if err != nil {
-			return err
-		}
+	client, err := api.FromEnv()
+	if err != nil {
+		return err
+	}
 
-		spinner := NewSpinner("")
-		go spinner.Spin(60 * time.Millisecond)
+	spinner := NewSpinner("")
+	go spinner.Spin(60 * time.Millisecond)
 
-		var latest api.GenerateResponse
+	var latest api.GenerateResponse
 
-		generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
-		if !ok {
-			generateContext = []int{}
-		}
+	generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
+	if !ok {
+		generateContext = []int{}
+	}
 
-		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
-		fn := func(response api.GenerateResponse) error {
-			if !spinner.IsFinished() {
-				spinner.Finish()
-			}
+	request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
+	fn := func(response api.GenerateResponse) error {
+		if !spinner.IsFinished() {
+			spinner.Finish()
+		}
 
-			latest = response
+		latest = response
 
-			fmt.Print(response.Response)
-			return nil
-		}
+		fmt.Print(response.Response)
+		return nil
+	}
 
-		if err := client.Generate(context.Background(), &request, fn); err != nil {
-			if strings.Contains(err.Error(), "failed to load model") {
-				// tell the user to check the server log, if it exists locally
-				home, nestedErr := os.UserHomeDir()
-				if nestedErr != nil {
-					// return the original error
-					return err
-				}
-				logPath := filepath.Join(home, ".ollama", "logs", "server.log")
-				if _, nestedErr := os.Stat(logPath); nestedErr == nil {
-					err = fmt.Errorf("%w\nFor more details, check the error logs at %s", err, logPath)
-				}
+	if err := client.Generate(context.Background(), &request, fn); err != nil {
+		if strings.Contains(err.Error(), "failed to load model") {
+			// tell the user to check the server log, if it exists locally
+			home, nestedErr := os.UserHomeDir()
+			if nestedErr != nil {
+				// return the original error
+				return err
+			}
+			logPath := filepath.Join(home, ".ollama", "logs", "server.log")
+			if _, nestedErr := os.Stat(logPath); nestedErr == nil {
+				err = fmt.Errorf("%w\nFor more details, check the error logs at %s", err, logPath)
 			}
-			return err
 		}
+		return err
+	}
 
+	if prompt != "" {
 		fmt.Println()
 		fmt.Println()
+	}
 
-		if !latest.Done {
-			return errors.New("unexpected end of response")
-		}
-
-		verbose, err := cmd.Flags().GetBool("verbose")
-		if err != nil {
-			return err
-		}
+	if !latest.Done {
+		return errors.New("unexpected end of response")
+	}
 
-		if verbose {
-			latest.Summary()
-		}
+	verbose, err := cmd.Flags().GetBool("verbose")
+	if err != nil {
+		return err
+	}
 
-		ctx := cmd.Context()
-		ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
-		cmd.SetContext(ctx)
+	if verbose {
+		latest.Summary()
 	}
 
+	ctx := cmd.Context()
+	ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
+	cmd.SetContext(ctx)
+
 	return nil
 }
 
@@ -461,6 +472,11 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 		return err
 	}
 
+	// load the model
+	if err := generate(cmd, model, ""); err != nil {
+		return err
+	}
+
 	completer := readline.NewPrefixCompleter(
 		readline.PcItem("/help"),
 		readline.PcItem("/list"),
@@ -492,6 +508,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 	}
 
 	config := readline.Config{
+		Painter:      Painter{},
 		Prompt:       ">>> ",
 		HistoryFile:  filepath.Join(home, ".ollama", "history"),
 		AutoComplete: completer,
@@ -621,8 +638,10 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 			return nil
 		}
 
-		if err := generate(cmd, model, line); err != nil {
-			return err
+		if len(line) > 0 && line[0] != '/' {
+			if err := generate(cmd, model, line); err != nil {
+				return err
+			}
 		}
 	}
 }

+ 6 - 2
server/routes.go

@@ -218,8 +218,12 @@ func GenerateHandler(c *gin.Context) {
 			ch <- r
 		}
 
-		if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
-			ch <- gin.H{"error": err.Error()}
+		if req.Prompt == "" {
+			ch <- api.GenerateResponse{Model: req.Model, Done: true}
+		} else {
+			if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
+				ch <- gin.H{"error": err.Error()}
+			}
 		}
 	}()