
fix: relay request opts to loaded llm prediction (#1761)

Bruce MacDonald · 1 year ago
commit 0b3118e0af
5 changed files with 103 additions and 68 deletions
  1. llm/ext_server_common.go (+18 -18)
  2. llm/ext_server_default.go (+1 -1)
  3. llm/llama.go (+4 -3)
  4. llm/shim_ext_server.go (+1 -1)
  5. server/routes.go (+79 -45)
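
In short: before this change, request options were folded in only inside load(), so a runner that was already resident kept predicting with the options it was created with. The commit splits option resolution out into modelOptions(), threads the result through the new PredictOpts.Options field, and has predict() read every sampling parameter from there. A minimal sketch of the resulting handler flow, using only names introduced in this diff (error handling and streaming elided):

	// Sketch of the post-commit flow; GetModel, modelOptions, load, and
	// llm.PredictOpts are as defined in the diff below.
	model, err := GetModel(req.Model)             // resolve the model by name
	opts, err := modelOptions(model, req.Options) // defaults < Modelfile < request
	err = load(c, model, opts, defaultSessionDuration)

	predictReq := llm.PredictOpts{
		Prompt:  prompt,
		Format:  req.Format,
		Images:  req.Images,
		Options: opts, // options now travel with every prediction call
	}
	err = loaded.runner.Predict(c.Request.Context(), predictReq, fn)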

+ 18 - 18
llm/ext_server_common.go

@@ -153,7 +153,7 @@ func newExtServer(server extServer, model string, adapters, projectors []string,
 	return server, nil
 }
 
-func predict(llm extServer, opts api.Options, ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
+func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
 	var imageData []ImageData
@@ -167,23 +167,23 @@ func predict(llm extServer, opts api.Options, ctx context.Context, predict Predi
 	request := map[string]any{
 		"prompt":            predict.Prompt,
 		"stream":            true,
-		"n_predict":         opts.NumPredict,
-		"n_keep":            opts.NumKeep,
-		"temperature":       opts.Temperature,
-		"top_k":             opts.TopK,
-		"top_p":             opts.TopP,
-		"tfs_z":             opts.TFSZ,
-		"typical_p":         opts.TypicalP,
-		"repeat_last_n":     opts.RepeatLastN,
-		"repeat_penalty":    opts.RepeatPenalty,
-		"presence_penalty":  opts.PresencePenalty,
-		"frequency_penalty": opts.FrequencyPenalty,
-		"mirostat":          opts.Mirostat,
-		"mirostat_tau":      opts.MirostatTau,
-		"mirostat_eta":      opts.MirostatEta,
-		"penalize_nl":       opts.PenalizeNewline,
-		"seed":              opts.Seed,
-		"stop":              opts.Stop,
+		"n_predict":         predict.Options.NumPredict,
+		"n_keep":            predict.Options.NumKeep,
+		"temperature":       predict.Options.Temperature,
+		"top_k":             predict.Options.TopK,
+		"top_p":             predict.Options.TopP,
+		"tfs_z":             predict.Options.TFSZ,
+		"typical_p":         predict.Options.TypicalP,
+		"repeat_last_n":     predict.Options.RepeatLastN,
+		"repeat_penalty":    predict.Options.RepeatPenalty,
+		"presence_penalty":  predict.Options.PresencePenalty,
+		"frequency_penalty": predict.Options.FrequencyPenalty,
+		"mirostat":          predict.Options.Mirostat,
+		"mirostat_tau":      predict.Options.MirostatTau,
+		"mirostat_eta":      predict.Options.MirostatEta,
+		"penalize_nl":       predict.Options.PenalizeNewline,
+		"seed":              predict.Options.Seed,
+		"stop":              predict.Options.Stop,
 		"image_data":        imageData,
 		"image_data":        imageData,
 		"cache_prompt":      true,
 		"cache_prompt":      true,
 	}
 	}
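
predict() now reads every sampling parameter from the PredictOpts it receives instead of from api.Options captured when the external server was constructed. A hedged caller sketch from inside the llm package (server, ctx, and the prompt are illustrative stand-ins, not part of the diff):

	// Hypothetical call site; `server` is an existing extServer instance.
	pred := PredictOpts{
		Prompt:  "Why is the sky blue?", // illustrative prompt
		Options: api.DefaultOptions(),   // per-request options, not server state
	}
	if err := predict(ctx, server, pred, func(r PredictResult) {
		// stream each partial result back to the caller as it arrives
	}); err != nil {
		log.Println("predict failed:", err)
	}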

+ 1 - 1
llm/ext_server_default.go

@@ -60,7 +60,7 @@ func newDefaultExtServer(model string, adapters, projectors []string, numLayers
 }
 
 func (llm *llamaExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
-	return predict(llm, llm.Options, ctx, pred, fn)
+	return predict(ctx, llm, pred, fn)
 }
 
 func (llm *llamaExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {

+ 4 - 3
llm/llama.go

@@ -166,9 +166,10 @@ const maxRetries = 3
 const retryDelay = 1 * time.Second
 
 type PredictOpts struct {
-	Prompt string
-	Format string
-	Images []api.ImageData
+	Prompt  string
+	Format  string
+	Images  []api.ImageData
+	Options api.Options
 }
 
 type PredictResult struct {
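
Carrying api.Options inside PredictOpts makes each Predict call self-describing: the same resident runner can serve consecutive requests with different sampling settings, since nothing option-related is read from runner state anymore. Schematically (opts, prompt, and fn stand for values from the surrounding handler; the overrides are hypothetical):

	// Two back-to-back requests against one loaded runner may now differ.
	hot := opts
	hot.Temperature = 0.9 // hypothetical per-request override
	cold := opts
	cold.Temperature = 0.1

	_ = loaded.runner.Predict(ctx, llm.PredictOpts{Prompt: prompt, Options: hot}, fn)
	_ = loaded.runner.Predict(ctx, llm.PredictOpts{Prompt: prompt, Options: cold}, fn)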

+ 1 - 1
llm/shim_ext_server.go

@@ -92,7 +92,7 @@ func newDynamicShimExtServer(library, model string, adapters, projectors []strin
 }
 
 func (llm *shimExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
-	return predict(llm, llm.options, ctx, pred, fn)
+	return predict(ctx, llm, pred, fn)
 }
 
 func (llm *shimExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {

+ 79 - 45
server/routes.go

@@ -64,24 +64,9 @@ var loaded struct {
 var defaultSessionDuration = 5 * time.Minute
 
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
-	model, err := GetModel(modelName)
-	if err != nil {
-		return nil, err
-	}
-
+func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
 	workDir := c.GetString("workDir")
 
-	opts := api.DefaultOptions()
-	if err := opts.FromMap(model.Options); err != nil {
-		log.Printf("could not load model options: %v", err)
-		return nil, err
-	}
-
-	if err := opts.FromMap(reqOpts); err != nil {
-		return nil, err
-	}
-
 	needLoad := loaded.runner == nil || // is there a model loaded?
 		loaded.ModelPath != model.ModelPath || // has the base model changed?
 		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
@@ -105,7 +90,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
 				err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
 			}
 
-			return nil, err
+			return err
 		}
 
 		loaded.Model = model
@@ -135,7 +120,20 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
 	}
 
 	loaded.expireTimer.Reset(sessionDuration)
-	return model, nil
+	return nil
+}
+
+func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
+	opts := api.DefaultOptions()
+	if err := opts.FromMap(model.Options); err != nil {
+		return api.Options{}, err
+	}
+
+	if err := opts.FromMap(requestOpts); err != nil {
+		return api.Options{}, err
+	}
+
+	return opts, nil
 }
 
 func GenerateHandler(c *gin.Context) {
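
modelOptions applies settings in increasing precedence: api.DefaultOptions() first, then the options stored with the model (its Modelfile parameters), then the options from the request body. A hedged illustration (the Modelfile parameter shown is hypothetical):

	// Suppose the model was created with: PARAMETER temperature 0.2
	opts, err := modelOptions(model, map[string]interface{}{"temperature": 0.9})
	if err != nil {
		// api.ErrInvalidOpts reaches the client as a 400; anything else as a 500
	}
	// opts.Temperature == 0.9: request options override Modelfile options,
	// which in turn override the library defaults.
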
@@ -168,18 +166,30 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	model, err := load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -287,9 +297,10 @@
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt: prompt,
-			Format: req.Format,
-			Images: req.Images,
+			Prompt:  prompt,
+			Format:  req.Format,
+			Images:  req.Images,
+			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
@@ -347,18 +358,29 @@ func EmbeddingHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	_, err = load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -991,18 +1013,29 @@ func ChatHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	model, err := load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -1053,9 +1086,10 @@
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt: prompt,
-			Format: req.Format,
-			Images: images,
+			Prompt:  prompt,
+			Format:  req.Format,
+			Images:  images,
+			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}