|
@@ -81,6 +81,7 @@ from config import (
|
|
|
TASK_MODEL_EXTERNAL,
|
|
|
TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
|
|
AppConfig,
|
|
|
)
|
|
|
from constants import ERROR_MESSAGES
|
|
@@ -144,6 +145,9 @@ app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMP
|
|
|
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
|
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
)
|
|
|
+app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
|
|
|
+ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
|
|
|
+)
|
|
|
|
|
|
app.state.MODELS = {}
|
|
|
|
|
@@ -596,6 +600,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|
|
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
|
|
print("generate_search_query")
|
|
|
|
|
|
+ if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
|
|
|
+ )
|
|
|
+
|
|
|
model_id = form_data["model"]
|
|
|
if model_id not in app.state.MODELS:
|
|
|
raise HTTPException(
|