Explorar o código

Allow setting parameters in the REPL (#1294)

Patrick Devine hai 1 ano
pai
achega
cde31cb220
Modificáronse 3 ficheiros con 153 adicións e 86 borrados
  1. 61 0
      api/types.go
  2. 91 26
      cmd/cmd.go
  3. 1 60
      server/images.go

+ 61 - 0
api/types.go

@@ -6,6 +6,7 @@ import (
 	"math"
 	"os"
 	"reflect"
+	"strconv"
 	"strings"
 	"time"
 )
@@ -360,3 +361,63 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 
 	return nil
 }
+
+// FormatParams converts specified parameter options to their correct types
+func FormatParams(params map[string][]string) (map[string]interface{}, error) {
+	opts := Options{}
+	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
+	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
+
+	// build map of json struct tags to their types
+	jsonOpts := make(map[string]reflect.StructField)
+	for _, field := range reflect.VisibleFields(typeOpts) {
+		jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
+		if jsonTag != "" {
+			jsonOpts[jsonTag] = field
+		}
+	}
+
+	out := make(map[string]interface{})
+	// iterate params and set values based on json struct tags
+	for key, vals := range params {
+		if opt, ok := jsonOpts[key]; !ok {
+			return nil, fmt.Errorf("unknown parameter '%s'", key)
+		} else {
+			field := valueOpts.FieldByName(opt.Name)
+			if field.IsValid() && field.CanSet() {
+				switch field.Kind() {
+				case reflect.Float32:
+					floatVal, err := strconv.ParseFloat(vals[0], 32)
+					if err != nil {
+						return nil, fmt.Errorf("invalid float value %s", vals)
+					}
+
+					out[key] = float32(floatVal)
+				case reflect.Int:
+					intVal, err := strconv.ParseInt(vals[0], 10, 64)
+					if err != nil {
+						return nil, fmt.Errorf("invalid int value %s", vals)
+					}
+
+					out[key] = intVal
+				case reflect.Bool:
+					boolVal, err := strconv.ParseBool(vals[0])
+					if err != nil {
+						return nil, fmt.Errorf("invalid bool value %s", vals)
+					}
+
+					out[key] = boolVal
+				case reflect.String:
+					out[key] = vals[0]
+				case reflect.Slice:
+					// TODO: only string slices are supported right now
+					out[key] = vals
+				default:
+					return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
+				}
+			}
+		}
+	}
+
+	return out, nil
+}

+ 91 - 26
cmd/cmd.go

@@ -412,10 +412,19 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 }
 
 func RunGenerate(cmd *cobra.Command, args []string) error {
+	interactive := true
+
+	opts := generateOptions{
+		Model:    args[0],
+		WordWrap: os.Getenv("TERM") == "xterm-256color",
+		Options:  map[string]interface{}{},
+	}
+
 	format, err := cmd.Flags().GetString("format")
 	if err != nil {
 		return err
 	}
+	opts.Format = format
 
 	prompts := args[1:]
 
@@ -427,34 +436,38 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 		}
 
 		prompts = append([]string{string(in)}, prompts...)
+		opts.WordWrap = false
+		interactive = false
 	}
-
-	// output is being piped
-	if !term.IsTerminal(int(os.Stdout.Fd())) {
-		return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
+	opts.Prompt = strings.Join(prompts, " ")
+	if len(prompts) > 0 {
+		interactive = false
 	}
 
-	wordWrap := os.Getenv("TERM") == "xterm-256color"
-
 	nowrap, err := cmd.Flags().GetBool("nowordwrap")
 	if err != nil {
 		return err
 	}
-	if nowrap {
-		wordWrap = false
-	}
+	opts.WordWrap = !nowrap
 
-	// prompts are provided via stdin or args so don't enter interactive mode
-	if len(prompts) > 0 {
-		return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
+	if !interactive {
+		return generate(cmd, opts)
 	}
 
-	return generateInteractive(cmd, args[0], wordWrap, format)
+	return generateInteractive(cmd, opts)
 }
 
 type generateContextKey string
 
-func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error {
+type generateOptions struct {
+	Model    string
+	Prompt   string
+	WordWrap bool
+	Format   string
+	Options  map[string]interface{}
+}
+
+func generate(cmd *cobra.Command, opts generateOptions) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
@@ -475,7 +488,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 
 	termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
 	if err != nil {
-		wordWrap = false
+		opts.WordWrap = false
 	}
 
 	cancelCtx, cancel := context.WithCancel(context.Background())
@@ -494,13 +507,19 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 	var currentLineLength int
 	var wordBuffer string
 
-	request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format}
+	request := api.GenerateRequest{
+		Model:   opts.Model,
+		Prompt:  opts.Prompt,
+		Context: generateContext,
+		Format:  opts.Format,
+		Options: opts.Options,
+	}
 	fn := func(response api.GenerateResponse) error {
 		p.StopAndClear()
 
 		latest = response
 
-		if wordWrap {
+		if opts.WordWrap {
 			for _, ch := range response.Response {
 				if currentLineLength+1 > termWidth-5 {
 					// backtrack the length of the last word and clear to the end of the line
@@ -534,7 +553,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 		}
 		return err
 	}
-	if prompt != "" {
+	if opts.Prompt != "" {
 		fmt.Println()
 		fmt.Println()
 	}
@@ -562,9 +581,13 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 	return nil
 }
 
-func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
+func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 	// load the model
-	if err := generate(cmd, model, "", false, ""); err != nil {
+	loadOpts := generateOptions{
+		Model:  opts.Model,
+		Prompt: "",
+	}
+	if err := generate(cmd, loadOpts); err != nil {
 		return err
 	}
 
@@ -581,6 +604,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
 
 	usageSet := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
+		fmt.Fprintln(os.Stderr, "  /set parameter    Set a parameter")
 		fmt.Fprintln(os.Stderr, "  /set history      Enable history")
 		fmt.Fprintln(os.Stderr, "  /set nohistory    Disable history")
 		fmt.Fprintln(os.Stderr, "  /set wordwrap     Enable wordwrap")
@@ -602,6 +626,22 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
 		fmt.Fprintln(os.Stderr, "")
 	}
 
+	// only list out the most common parameters
+	usageParameters := func() {
+		fmt.Fprintln(os.Stderr, "Available Parameters:")
+		fmt.Fprintln(os.Stderr, "  /set parameter seed <int>             Random number seed")
+		fmt.Fprintln(os.Stderr, "  /set parameter num_predict <int>      Max number of tokens to predict")
+		fmt.Fprintln(os.Stderr, "  /set parameter top_k <int>            Pick from top k num of tokens")
+		fmt.Fprintln(os.Stderr, "  /set parameter top_p <float>          Pick token based on sum of probabilities")
+		fmt.Fprintln(os.Stderr, "  /set parameter num_ctx <int>          Set the context size")
+		fmt.Fprintln(os.Stderr, "  /set parameter temperature <float>    Set creativity level")
+		fmt.Fprintln(os.Stderr, "  /set parameter repeat_penalty <float> How strongly to penalize repetitions")
+		fmt.Fprintln(os.Stderr, "  /set parameter repeat_last_n <int>    Set how far back to look for repetitions")
+		fmt.Fprintln(os.Stderr, "  /set parameter num_gpu <int>          The number of layers to send to the GPU")
+		fmt.Fprintln(os.Stderr, "  /set parameter stop \"<string>\", ...   Set the stop parameters")
+		fmt.Fprintln(os.Stderr, "")
+	}
+
 	scanner, err := readline.New(readline.Prompt{
 		Prompt:         ">>> ",
 		AltPrompt:      "... ",
@@ -670,10 +710,10 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
 				case "nohistory":
 					scanner.HistoryDisable()
 				case "wordwrap":
-					wordWrap = true
+					opts.WordWrap = true
 					fmt.Println("Set 'wordwrap' mode.")
 				case "nowordwrap":
-					wordWrap = false
+					opts.WordWrap = false
 					fmt.Println("Set 'nowordwrap' mode.")
 				case "verbose":
 					cmd.Flags().Set("verbose", "true")
@@ -685,12 +725,28 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
 					if len(args) < 3 || args[2] != "json" {
 						fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
 					} else {
-						format = args[2]
+						opts.Format = args[2]
 						fmt.Printf("Set format to '%s' mode.\n", args[2])
 					}
 				case "noformat":
-					format = ""
+					opts.Format = ""
 					fmt.Println("Disabled format.")
+				case "parameter":
+					if len(args) < 4 {
+						usageParameters()
+						continue
+					}
+					var params []string
+					for _, p := range args[3:] {
+						params = append(params, p)
+					}
+					fp, err := api.FormatParams(map[string][]string{args[2]: params})
+					if err != nil {
+						fmt.Printf("Couldn't set parameter: %q\n\n", err)
+						continue
+					}
+					fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", "))
+					opts.Options[args[2]] = fp[args[2]]
 				default:
 					fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
 				}
@@ -705,7 +761,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
 					fmt.Println("error: couldn't connect to ollama server")
 					return err
 				}
-				resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model})
+				resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: opts.Model})
 				if err != nil {
 					fmt.Println("error: couldn't get model")
 					return err
@@ -724,6 +780,14 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
 					if resp.Parameters == "" {
 						fmt.Print("No parameters were specified for this model.\n\n")
 					} else {
+						if len(opts.Options) > 0 {
+							fmt.Println("User defined parameters:")
+							for k, v := range opts.Options {
+								fmt.Printf("%-*s %v\n", 30, k, v)
+							}
+							fmt.Println()
+						}
+						fmt.Println("Model defined parameters:")
 						fmt.Println(resp.Parameters)
 					}
 				case "system":
@@ -767,7 +831,8 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
 		}
 
 		if len(prompt) > 0 && prompt[0] != '/' {
-			if err := generate(cmd, model, prompt, wordWrap, format); err != nil {
+			opts.Prompt = prompt
+			if err := generate(cmd, opts); err != nil {
 				return err
 			}
 

+ 1 - 60
server/images.go

@@ -14,7 +14,6 @@ import (
 	"net/url"
 	"os"
 	"path/filepath"
-	"reflect"
 	"runtime"
 	"strconv"
 	"strings"
@@ -426,7 +425,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 	if len(params) > 0 {
 		fn(api.ProgressResponse{Status: "creating parameters layer"})
 
-		formattedParams, err := formatParams(params)
+		formattedParams, err := api.FormatParams(params)
 		if err != nil {
 			return err
 		}
@@ -581,64 +580,6 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
 	return newLayer, nil
 }
 
-// formatParams converts specified parameter options to their correct types
-func formatParams(params map[string][]string) (map[string]interface{}, error) {
-	opts := api.Options{}
-	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
-	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
-
-	// build map of json struct tags to their types
-	jsonOpts := make(map[string]reflect.StructField)
-	for _, field := range reflect.VisibleFields(typeOpts) {
-		jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
-		if jsonTag != "" {
-			jsonOpts[jsonTag] = field
-		}
-	}
-
-	out := make(map[string]interface{})
-	// iterate params and set values based on json struct tags
-	for key, vals := range params {
-		if opt, ok := jsonOpts[key]; ok {
-			field := valueOpts.FieldByName(opt.Name)
-			if field.IsValid() && field.CanSet() {
-				switch field.Kind() {
-				case reflect.Float32:
-					floatVal, err := strconv.ParseFloat(vals[0], 32)
-					if err != nil {
-						return nil, fmt.Errorf("invalid float value %s", vals)
-					}
-
-					out[key] = float32(floatVal)
-				case reflect.Int:
-					intVal, err := strconv.ParseInt(vals[0], 10, 64)
-					if err != nil {
-						return nil, fmt.Errorf("invalid int value %s", vals)
-					}
-
-					out[key] = intVal
-				case reflect.Bool:
-					boolVal, err := strconv.ParseBool(vals[0])
-					if err != nil {
-						return nil, fmt.Errorf("invalid bool value %s", vals)
-					}
-
-					out[key] = boolVal
-				case reflect.String:
-					out[key] = vals[0]
-				case reflect.Slice:
-					// TODO: only string slices are supported right now
-					out[key] = vals
-				default:
-					return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
-				}
-			}
-		}
-	}
-
-	return out, nil
-}
-
 func getLayerDigests(layers []*LayerReader) ([]string, error) {
 	var digests []string
 	for _, l := range layers {