|
@@ -78,11 +78,13 @@ from open_webui.config import (
|
|
|
ENV,
|
|
|
FRONTEND_BUILD_DIR,
|
|
|
OAUTH_PROVIDERS,
|
|
|
- ENABLE_SEARCH_QUERY,
|
|
|
- SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
STATIC_DIR,
|
|
|
TASK_MODEL,
|
|
|
TASK_MODEL_EXTERNAL,
|
|
|
+ ENABLE_SEARCH_QUERY_GENERATION,
|
|
|
+ ENABLE_RETRIEVAL_QUERY_GENERATION,
|
|
|
+ QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
TAGS_GENERATION_PROMPT_TEMPLATE,
|
|
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
@@ -122,7 +124,7 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
|
|
from open_webui.utils.task import (
|
|
|
moa_response_generation_template,
|
|
|
tags_generation_template,
|
|
|
- search_query_generation_template,
|
|
|
+ query_generation_template,
|
|
|
emoji_generation_template,
|
|
|
title_generation_template,
|
|
|
tools_function_calling_generation_template,
|
|
@@ -206,10 +208,9 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
|
|
|
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
|
|
|
|
|
|
|
|
|
-app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
|
|
|
-app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
|
|
- SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
-)
|
|
|
+app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
|
|
|
+app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
|
|
|
+app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
|
|
|
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
|
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
@@ -492,14 +493,41 @@ async def chat_completion_tools_handler(
|
|
|
return body, {"contexts": contexts, "citations": citations}
|
|
|
|
|
|
|
|
|
-async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
|
|
|
+async def chat_completion_files_handler(
|
|
|
+ body: dict, user: UserModel
|
|
|
+) -> tuple[dict, dict[str, list]]:
|
|
|
contexts = []
|
|
|
citations = []
|
|
|
|
|
|
+ try:
|
|
|
+ queries_response = await generate_queries(
|
|
|
+ {
|
|
|
+ "model": body["model"],
|
|
|
+ "messages": body["messages"],
|
|
|
+ "type": "retrieval",
|
|
|
+ },
|
|
|
+ user,
|
|
|
+ )
|
|
|
+ queries_response = queries_response["choices"][0]["message"]["content"]
|
|
|
+
|
|
|
+ try:
|
|
|
+ queries_response = json.loads(queries_response)
|
|
|
+ except Exception as e:
|
|
|
+ queries_response = {"queries": []}
|
|
|
+
|
|
|
+ queries = queries_response.get("queries", [])
|
|
|
+ except Exception as e:
|
|
|
+ queries = []
|
|
|
+
|
|
|
+ if len(queries) == 0:
|
|
|
+ queries = [get_last_user_message(body["messages"])]
|
|
|
+
|
|
|
+ print(f"{queries=}")
|
|
|
+
|
|
|
if files := body.get("metadata", {}).get("files", None):
|
|
|
contexts, citations = get_rag_context(
|
|
|
files=files,
|
|
|
- messages=body["messages"],
|
|
|
+ queries=queries,
|
|
|
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
|
|
|
k=retrieval_app.state.config.TOP_K,
|
|
|
reranking_function=retrieval_app.state.sentence_transformer_rf,
|
|
@@ -643,7 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
log.exception(e)
|
|
|
|
|
|
try:
|
|
|
- body, flags = await chat_completion_files_handler(body)
|
|
|
+ body, flags = await chat_completion_files_handler(body, user)
|
|
|
contexts.extend(flags.get("contexts", []))
|
|
|
citations.extend(flags.get("citations", []))
|
|
|
except Exception as e:
|
|
@@ -1579,8 +1607,9 @@ async def get_task_config(user=Depends(get_verified_user)):
|
|
|
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
|
|
"ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
|
|
|
- "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
|
|
|
- "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
|
|
+ "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
|
|
+ "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
}
|
|
|
|
|
@@ -1591,8 +1620,9 @@ class TaskConfigForm(BaseModel):
|
|
|
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
|
|
TAGS_GENERATION_PROMPT_TEMPLATE: str
|
|
|
ENABLE_TAGS_GENERATION: bool
|
|
|
- SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
|
|
|
- ENABLE_SEARCH_QUERY: bool
|
|
|
+ ENABLE_SEARCH_QUERY_GENERATION: bool
|
|
|
+ ENABLE_RETRIEVAL_QUERY_GENERATION: bool
|
|
|
+ QUERY_GENERATION_PROMPT_TEMPLATE: str
|
|
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
|
|
|
|
|
|
|
|
@@ -1607,11 +1637,16 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
|
|
form_data.TAGS_GENERATION_PROMPT_TEMPLATE
|
|
|
)
|
|
|
app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
|
|
|
+ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
|
|
|
+ form_data.ENABLE_SEARCH_QUERY_GENERATION
|
|
|
+ )
|
|
|
+ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
|
|
|
+ form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
|
|
|
+ )
|
|
|
|
|
|
- app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
|
|
- form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
+ app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
|
|
|
+ form_data.QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
)
|
|
|
- app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY
|
|
|
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
|
|
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
|
)
|
|
@@ -1622,8 +1657,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
|
|
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
|
|
"ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
|
|
|
- "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
- "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
|
|
|
+ "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
|
|
+ "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
|
|
+ "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
}
|
|
|
|
|
@@ -1799,14 +1835,22 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
|
|
return await generate_chat_completions(form_data=payload, user=user)
|
|
|
|
|
|
|
|
|
-@app.post("/api/task/query/completions")
|
|
|
-async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
|
|
- print("generate_search_query")
|
|
|
- if not app.state.config.ENABLE_SEARCH_QUERY:
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
- detail=f"Search query generation is disabled",
|
|
|
- )
|
|
|
+@app.post("/api/task/queries/completions")
|
|
|
+async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
|
|
|
+ print("generate_queries")
|
|
|
+ type = form_data.get("type")
|
|
|
+ if type == "web_search":
|
|
|
+ if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ detail=f"Search query generation is disabled",
|
|
|
+ )
|
|
|
+ elif type == "retrieval":
|
|
|
+ if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ detail=f"Query generation is disabled",
|
|
|
+ )
|
|
|
|
|
|
model_list = await get_all_models()
|
|
|
models = {model["id"]: model for model in model_list}
|
|
@@ -1830,20 +1874,12 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
|
|
|
|
|
model = models[task_model_id]
|
|
|
|
|
|
- if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
|
|
|
- template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
+ if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "":
|
|
|
+ template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
else:
|
|
|
- template = """Given the user's message and interaction history, decide if a web search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. Today's date is {{CURRENT_DATE}}.
|
|
|
-
|
|
|
-User Message:
|
|
|
-{{prompt:end:4000}}
|
|
|
+ template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
|
|
|
-Interaction History:
|
|
|
-{{MESSAGES:END:6}}
|
|
|
-
|
|
|
-Search Query:"""
|
|
|
-
|
|
|
- content = search_query_generation_template(
|
|
|
+ content = query_generation_template(
|
|
|
template, form_data["messages"], {"name": user.name}
|
|
|
)
|
|
|
|
|
@@ -1851,13 +1887,6 @@ Search Query:"""
|
|
|
"model": task_model_id,
|
|
|
"messages": [{"role": "user", "content": content}],
|
|
|
"stream": False,
|
|
|
- **(
|
|
|
- {"max_tokens": 30}
|
|
|
- if models[task_model_id]["owned_by"] == "ollama"
|
|
|
- else {
|
|
|
- "max_completion_tokens": 30,
|
|
|
- }
|
|
|
- ),
|
|
|
"metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data},
|
|
|
}
|
|
|
log.debug(payload)
|