Browse Source

limit `num_predict` to `num_ctx`

Jeffrey Morgan 1 year ago
parent
commit
ca7c3f7e0f
1 changed files with 13 additions and 0 deletions
  1. 13 0
      llm/dyn_ext_server.go

+ 13 - 0
llm/dyn_ext_server.go

@@ -172,6 +172,19 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 		slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
 	}
 
+	// Limit the number of predictions to the maximum context length
+	// this will cause no more than two context shifts
+	// TODO: limit this further to num_ctx - len(prompt) to avoid
+	// any context shifts at all
+	if predict.Options.NumPredict > llm.options.NumCtx {
+		slog.Warn(fmt.Sprintf("requested num_predict is greater than the context length (%d > %d), using %d instead", predict.Options.NumPredict, llm.options.NumCtx, llm.options.NumCtx))
+		predict.Options.NumPredict = llm.options.NumCtx
+	}
+
+	if predict.Options.NumPredict == -1 {
+		predict.Options.NumPredict = llm.options.NumCtx
+	}
+
 	request := map[string]any{
 		"prompt":            predict.Prompt,
 		"stream":            true,