|
@@ -29,6 +29,7 @@ from open_webui.routers.tasks import (
|
|
|
generate_title,
|
|
|
generate_chat_tags,
|
|
|
)
|
|
|
+from open_webui.routers.retrieval import process_web_search, SearchForm
|
|
|
from open_webui.utils.webhook import post_webhook
|
|
|
|
|
|
|
|
@@ -333,6 +334,149 @@ async def chat_completion_tools_handler(
|
|
|
return body, {"sources": sources}
|
|
|
|
|
|
|
|
|
+async def chat_web_search_handler(
|
|
|
+ request: Request, form_data: dict, extra_params: dict, user
|
|
|
+):
|
|
|
+ event_emitter = extra_params["__event_emitter__"]
|
|
|
+ await event_emitter(
|
|
|
+ {
|
|
|
+ "type": "status",
|
|
|
+ "data": {
|
|
|
+ "action": "web_search",
|
|
|
+ "description": "Generating search query",
|
|
|
+ "done": False,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ messages = form_data["messages"]
|
|
|
+ user_message = get_last_user_message(messages)
|
|
|
+
|
|
|
+ queries = []
|
|
|
+ try:
|
|
|
+ res = await generate_queries(
|
|
|
+ request,
|
|
|
+ {
|
|
|
+ "model": form_data["model"],
|
|
|
+ "messages": messages,
|
|
|
+ "prompt": user_message,
|
|
|
+ "type": "web_search",
|
|
|
+ },
|
|
|
+ user,
|
|
|
+ )
|
|
|
+
|
|
|
+ response = res["choices"][0]["message"]["content"]
|
|
|
+
|
|
|
+ try:
|
|
|
+ bracket_start = response.find("{")
|
|
|
+ bracket_end = response.rfind("}") + 1
|
|
|
+
|
|
|
+ if bracket_start == -1 or bracket_end == -1:
|
|
|
+ raise Exception("No JSON object found in the response")
|
|
|
+
|
|
|
+ response = response[bracket_start:bracket_end]
|
|
|
+ queries = json.loads(response)
|
|
|
+ queries = queries.get("queries", [])
|
|
|
+ except Exception as e:
|
|
|
+ queries = [response]
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ log.exception(e)
|
|
|
+ queries = [user_message]
|
|
|
+
|
|
|
+ if len(queries) == 0:
|
|
|
+ await event_emitter(
|
|
|
+ {
|
|
|
+ "type": "status",
|
|
|
+ "data": {
|
|
|
+ "action": "web_search",
|
|
|
+ "description": "No search query generated",
|
|
|
+ "done": True,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ )
|
|
|
+ return
|
|
|
+
|
|
|
+ searchQuery = queries[0]
|
|
|
+
|
|
|
+ await event_emitter(
|
|
|
+ {
|
|
|
+ "type": "status",
|
|
|
+ "data": {
|
|
|
+ "action": "web_search",
|
|
|
+ "description": 'Searching "{{searchQuery}}"',
|
|
|
+ "query": searchQuery,
|
|
|
+ "done": False,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ try:
|
|
|
+ results = await process_web_search(
|
|
|
+ request,
|
|
|
+ SearchForm(
|
|
|
+ **{
|
|
|
+ "query": searchQuery,
|
|
|
+ }
|
|
|
+ ),
|
|
|
+ user,
|
|
|
+ )
|
|
|
+
|
|
|
+ if results:
|
|
|
+ await event_emitter(
|
|
|
+ {
|
|
|
+ "type": "status",
|
|
|
+ "data": {
|
|
|
+ "action": "web_search",
|
|
|
+ "description": "Searched {{count}} sites",
|
|
|
+ "query": searchQuery,
|
|
|
+ "urls": results["filenames"],
|
|
|
+ "done": True,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ files = form_data.get("files", [])
|
|
|
+ files.append(
|
|
|
+ {
|
|
|
+ "collection_name": results["collection_name"],
|
|
|
+ "name": searchQuery,
|
|
|
+ "type": "web_search_results",
|
|
|
+ "urls": results["filenames"],
|
|
|
+ }
|
|
|
+ )
|
|
|
+ form_data["files"] = files
|
|
|
+ else:
|
|
|
+ await event_emitter(
|
|
|
+ {
|
|
|
+ "type": "status",
|
|
|
+ "data": {
|
|
|
+ "action": "web_search",
|
|
|
+ "description": "No search results found",
|
|
|
+ "query": searchQuery,
|
|
|
+ "done": True,
|
|
|
+ "error": True,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ log.exception(e)
|
|
|
+ await event_emitter(
|
|
|
+ {
|
|
|
+ "type": "status",
|
|
|
+ "data": {
|
|
|
+ "action": "web_search",
|
|
|
+ "description": 'Error searching "{{searchQuery}}"',
|
|
|
+ "query": searchQuery,
|
|
|
+ "done": True,
|
|
|
+ "error": True,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ return form_data
|
|
|
+
|
|
|
+
|
|
|
async def chat_completion_files_handler(
|
|
|
request: Request, body: dict, user: UserModel
|
|
|
) -> tuple[dict, dict[str, list]]:
|
|
@@ -456,7 +600,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|
|
|
|
|
knowledge_files = []
|
|
|
for item in model_knowledge:
|
|
|
- print(item)
|
|
|
if item.get("collection_name"):
|
|
|
knowledge_files.append(
|
|
|
{
|
|
@@ -481,6 +624,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|
|
files.extend(knowledge_files)
|
|
|
form_data["files"] = files
|
|
|
|
|
|
+ features = form_data.pop("features", None)
|
|
|
+ if features:
|
|
|
+ if "web_search" in features and features["web_search"]:
|
|
|
+ form_data = await chat_web_search_handler(
|
|
|
+ request, form_data, extra_params, user
|
|
|
+ )
|
|
|
+
|
|
|
try:
|
|
|
form_data, flags = await chat_completion_filter_functions_handler(
|
|
|
request, form_data, model, extra_params
|