浏览代码

run prompts

Michael Yang 1 年之前
父节点
当前提交
3d6009aae3
共有 4 个文件被更改,包括 91 次插入32 次删除
  1. 10 12
      api/client.go
  2. 78 20
      cmd/cmd.go
  3. 1 0
      go.mod
  4. 2 0
      go.sum

+ 10 - 12
api/client.go

@@ -5,6 +5,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
@@ -63,20 +64,19 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData
 
 	for {
 		line, err := reader.ReadBytes('\n')
-		if err != nil {
-			if err == io.EOF {
-				break
-			} else {
-				return err // Handle other errors
-			}
+		switch {
+		case errors.Is(err, io.EOF):
+			return nil
+		case err != nil:
+			return err
 		}
+
 		if err := checkError(res, line); err != nil {
 			return err
 		}
+
 		callback(bytes.TrimSuffix(line, []byte("\n")))
 	}
-
-	return nil
 }
 
 func (c *Client) do(ctx context.Context, method string, path string, reqData any, respData any) error {
@@ -124,11 +124,9 @@ func (c *Client) do(ctx context.Context, method string, path string, reqData any
 	return nil
 }
 
-func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback func(token string)) (*GenerateResponse, error) {
+func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback func(bts []byte)) (*GenerateResponse, error) {
 	var res GenerateResponse
-	if err := c.stream(ctx, http.MethodPost, "/api/generate", req, func(token []byte) {
-		callback(string(token))
-	}); err != nil {
+	if err := c.stream(ctx, http.MethodPost, "/api/generate", req, callback); err != nil {
 		return nil, err
 	}
 

+ 78 - 20
cmd/cmd.go

@@ -1,7 +1,9 @@
 package cmd
 
 import (
+	"bufio"
 	"context"
+	"encoding/json"
 	"fmt"
 	"log"
 	"net"
@@ -10,9 +12,11 @@ import (
 	"sync"
 
 	"github.com/gosuri/uiprogress"
+	"github.com/spf13/cobra"
+	"golang.org/x/term"
+
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/server"
-	"github.com/spf13/cobra"
 )
 
 func cacheDir() string {
@@ -28,13 +32,13 @@ func bytesToGB(bytes int) float64 {
 	return float64(bytes) / float64(1<<30)
 }
 
-func run(model string) error {
+func RunRun(cmd *cobra.Command, args []string) error {
 	client, err := NewAPIClient()
 	if err != nil {
 		return err
 	}
 	pr := api.PullRequest{
-		Model: model,
+		Model: args[0],
 	}
 	var bar *uiprogress.Bar
 	mutex := &sync.Mutex{}
@@ -60,10 +64,71 @@ func run(model string) error {
 		return err
 	}
 	fmt.Println("Up to date.")
+	return RunGenerate(cmd, args)
+}
+
+// RunGenerate picks how prompts are supplied to the model named by args[0].
+//
+// When further arguments are present they are used as the prompts. Otherwise
+// prompts are read from stdin: interactively if stdin is a terminal, and in
+// batch mode if it is not.
+func RunGenerate(_ *cobra.Command, args []string) error {
+	model := args[0]
+
+	switch {
+	case len(args) > 1:
+		// Prompts were passed on the command line.
+		return generate(model, args[1:]...)
+	case term.IsTerminal(int(os.Stdin.Fd())):
+		return generateInteractive(model)
+	default:
+		return generateBatch(model)
+	}
+}
+
+// generate sends each prompt to the model in turn and prints the response
+// text to stdout as it is streamed back.
+//
+// It returns an error if the API client cannot be created or if a generate
+// request fails; prompts after a failed one are not sent. Two newlines are
+// printed once every prompt has completed.
+func generate(model string, prompts ...string) error {
+	client, err := NewAPIClient()
+	if err != nil {
+		return err
+	}
+
+	// printChunk decodes one streamed payload as an api.GenerateResponse and
+	// prints its response text.
+	printChunk := func(bts []byte) {
+		var resp api.GenerateResponse
+		if err := json.Unmarshal(bts, &resp); err != nil {
+			// Best effort: skip payloads that do not decode rather than
+			// aborting the stream.
+			return
+		}
+
+		fmt.Print(resp.Response)
+	}
+
+	for _, prompt := range prompts {
+		req := &api.GenerateRequest{Model: model, Prompt: prompt}
+		// Surface request and stream failures instead of silently dropping them.
+		if _, err := client.Generate(context.Background(), req, printChunk); err != nil {
+			return err
+		}
+	}
+
+	fmt.Println()
+	fmt.Println()
+	return nil
+}
+
+// generateInteractive runs a simple prompt loop on stdin: it prints a ">>> "
+// prompt, reads one line, passes it to generate, and repeats until input ends
+// or generate returns an error.
+//
+// NOTE(review): when the loop ends the function returns nil without
+// consulting scanner.Err(), so a read failure is indistinguishable from a
+// normal end of input.
+func generateInteractive(model string) error {
+	fmt.Print(">>> ")
+	scanner := bufio.NewScanner(os.Stdin)
+	for scanner.Scan() {
+		if err := generate(model, scanner.Text()); err != nil {
+			return err
+		}
+
+		// Show the prompt again before waiting for the next line.
+		fmt.Print(">>> ")
+	}
+
 	return nil
 }
 
-func serve() error {
+// generateBatch reads prompts from stdin, one per line, echoing each prompt
+// with a ">>> " prefix before generating a response for it.
+//
+// It stops at the first generate error and returns it. When input ends it
+// returns any error the scanner encountered; reaching the end of input
+// cleanly yields nil.
+func generateBatch(model string) error {
+	scanner := bufio.NewScanner(os.Stdin)
+	for scanner.Scan() {
+		prompt := scanner.Text()
+		fmt.Printf(">>> %s\n", prompt)
+		if err := generate(model, prompt); err != nil {
+			return err
+		}
+	}
+
+	// Scan also returns false on a read error or an over-long line; report
+	// that instead of treating it as a normal end of input.
+	return scanner.Err()
+}
+
+func RunServer(_ *cobra.Command, _ []string) error {
 	ln, err := net.Listen("tcp", "127.0.0.1:11434")
 	if err != nil {
 		return err
@@ -82,39 +147,32 @@ func NewCLI() *cobra.Command {
 	log.SetFlags(log.LstdFlags | log.Lshortfile)
 
 	rootCmd := &cobra.Command{
-		Use:   "ollama",
-		Short: "Large language model runner",
+		Use:          "ollama",
+		Short:        "Large language model runner",
+		SilenceUsage: true,
 		CompletionOptions: cobra.CompletionOptions{
 			DisableDefaultCmd: true,
 		},
-		PersistentPreRun: func(cmd *cobra.Command, args []string) {
-			// Disable usage printing on errors
-			cmd.SilenceUsage = true
+		PersistentPreRunE: func(_ *cobra.Command, args []string) error {
 			// create the models directory and its parent
-			if err := os.MkdirAll(path.Join(cacheDir(), "models"), 0o700); err != nil {
-				panic(err)
-			}
+			return os.MkdirAll(path.Join(cacheDir(), "models"), 0o700)
 		},
 	}
 
 	cobra.EnableCommandSorting = false
 
 	runCmd := &cobra.Command{
-		Use:   "run MODEL",
+		Use:   "run MODEL [PROMPT]",
 		Short: "Run a model",
-		Args:  cobra.ExactArgs(1),
-		RunE: func(cmd *cobra.Command, args []string) error {
-			return run(args[0])
-		},
+		Args:  cobra.MinimumNArgs(1),
+		RunE:  RunRun,
 	}
 
 	serveCmd := &cobra.Command{
 		Use:     "serve",
 		Aliases: []string{"start"},
 		Short:   "Start ollama",
-		RunE: func(cmd *cobra.Command, args []string) error {
-			return serve()
-		},
+		RunE:    RunServer,
 	}
 
 	rootCmd.AddCommand(

+ 1 - 0
go.mod

@@ -35,6 +35,7 @@ require (
 	golang.org/x/crypto v0.10.0 // indirect
 	golang.org/x/net v0.10.0 // indirect
 	golang.org/x/sys v0.10.0 // indirect
+	golang.org/x/term v0.10.0
 	golang.org/x/text v0.10.0 // indirect
 	google.golang.org/protobuf v1.30.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect

+ 2 - 0
go.sum

@@ -106,6 +106,8 @@ golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
+golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
+golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=