import time
import logging
import sys

import asyncio
from aiocache import cached
from typing import Any, Optional
import random
import json
import inspect
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor

from fastapi import Request
from fastapi import BackgroundTasks

from starlette.responses import Response, StreamingResponse

from open_webui.models.chats import Chats
from open_webui.models.users import Users
from open_webui.socket.main import (
    get_event_call,
    get_event_emitter,
    get_active_status_by_user_id,
)
from open_webui.routers.tasks import (
    generate_queries,
    generate_title,
    generate_image_prompt,
    generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.routers.images import image_generations, GenerateImageForm
from open_webui.utils.webhook import post_webhook

from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models

from open_webui.retrieval.utils import get_sources_from_files

from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
    get_task_model_id,
    rag_template,
    tools_function_calling_generation_template,
)
from open_webui.utils.misc import (
    get_message_list,
    add_or_update_system_message,
    get_last_user_message,
    get_last_assistant_message,
    prepend_to_first_user_message_content,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id

from open_webui.tasks import create_task

from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
from open_webui.env import (
    SRC_LOG_LEVELS,
    GLOBAL_LOG_LEVEL,
    BYPASS_MODEL_ACCESS_CONTROL,
    ENABLE_REALTIME_CHAT_SAVE,
)
from open_webui.constants import TASKS

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


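# Runs every enabled "filter" function's inlet() hook over the request body,
# in ascending "priority" valve order, letting filters rewrite the payload
# before it reaches the model. A filter that sets file_handler = True claims
# file processing for itself, so the files metadata is dropped afterwards.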
async def chat_completion_filter_functions_handler(request, body, model, extra_params):
    skip_files = None

    def get_filter_function_ids(model):
        def get_priority(function_id):
            function = Functions.get_function_by_id(function_id)
            if function is not None and hasattr(function, "valves"):
                # TODO: Fix FunctionModel
                return (function.valves if function.valves else {}).get("priority", 0)
            return 0

        filter_ids = [
            function.id for function in Functions.get_global_filter_functions()
        ]
        if "info" in model and "meta" in model["info"]:
            filter_ids.extend(model["info"]["meta"].get("filterIds", []))
            filter_ids = list(set(filter_ids))

        enabled_filter_ids = [
            function.id
            for function in Functions.get_functions_by_type("filter", active_only=True)
        ]

        filter_ids = [
            filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
        ]

        filter_ids.sort(key=get_priority)
        return filter_ids

    filter_ids = get_filter_function_ids(model)
    for filter_id in filter_ids:
        filter = Functions.get_function_by_id(filter_id)
        if not filter:
            continue

        if filter_id in request.app.state.FUNCTIONS:
            function_module = request.app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            request.app.state.FUNCTIONS[filter_id] = function_module

        # Check if the function has a file_handler variable
        if hasattr(function_module, "file_handler"):
            skip_files = function_module.file_handler

        # Apply valves to the function
        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        if hasattr(function_module, "inlet"):
            try:
                inlet = function_module.inlet

                # Create a dictionary of parameters to be passed to the function,
                # keeping only the ones the inlet signature actually accepts
                params = {"body": body} | {
                    k: v
                    for k, v in {
                        **extra_params,
                        "__model__": model,
                        "__id__": filter_id,
                    }.items()
                    if k in inspect.signature(inlet).parameters
                }

                if "__user__" in params and hasattr(function_module, "UserValves"):
                    try:
                        params["__user__"]["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                filter_id, params["__user__"]["id"]
                            )
                        )
                    except Exception as e:
                        log.exception(e)

                if inspect.iscoroutinefunction(inlet):
                    body = await inlet(**params)
                else:
                    body = inlet(**params)
            except Exception as e:
                log.exception(f"Error: {e}")
                raise e

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {}


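# Prompt-based tool calling: asks the configured task model to emit a JSON
# tool call, executes the matching callable from the registered tools, and
# collects string outputs as "sources" that are later folded into the RAG
# context. Returns the (possibly modified) body plus the collected sources.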
async def chat_completion_tools_handler(
    request: Request, body: dict, user: UserModel, models, extra_params: dict
) -> tuple[dict, dict]:
    async def get_content_from_response(response) -> Optional[str]:
        content = None
        if hasattr(response, "body_iterator"):
            async for chunk in response.body_iterator:
                data = json.loads(chunk.decode("utf-8"))
                content = data["choices"][0]["message"]["content"]

            # Cleanup any remaining background tasks if necessary
            if response.background is not None:
                await response.background()
        else:
            content = response["choices"][0]["message"]["content"]
        return content

    def get_tools_function_calling_payload(messages, task_model_id, content):
        user_message = get_last_user_message(messages)
        history = "\n".join(
            f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
            for message in messages[::-1][:4]
        )

        prompt = f"History:\n{history}\nQuery: {user_message}"

        return {
            "model": task_model_id,
            "messages": [
                {"role": "system", "content": content},
                {"role": "user", "content": f"Query: {prompt}"},
            ],
            "stream": False,
            "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
        }

    # If the tool_ids field is present, call the functions
    metadata = body.get("metadata", {})

    tool_ids = metadata.get("tool_ids", None)
    log.debug(f"{tool_ids=}")
    if not tool_ids:
        return body, {}

    skip_files = False
    sources = []

    task_model_id = get_task_model_id(
        body["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )
    tools = get_tools(
        request,
        tool_ids,
        user,
        {
            **extra_params,
            "__model__": models[task_model_id],
            "__messages__": body["messages"],
            "__files__": metadata.get("files", []),
        },
    )
    log.info(f"{tools=}")

    specs = [tool["spec"] for tool in tools.values()]
    tools_specs = json.dumps(specs)

    if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    tools_function_calling_prompt = tools_function_calling_generation_template(
        template, tools_specs
    )
    log.info(f"{tools_function_calling_prompt=}")
    payload = get_tools_function_calling_payload(
        body["messages"], task_model_id, tools_function_calling_prompt
    )

    try:
        response = await generate_chat_completion(request, form_data=payload, user=user)
        log.debug(f"{response=}")
        content = await get_content_from_response(response)
        log.debug(f"{content=}")

        if not content:
            return body, {}

        try:
            content = content[content.find("{") : content.rfind("}") + 1]
            if not content:
                raise Exception("No JSON object found in the response")

            result = json.loads(content)

            async def tool_call_handler(tool_call):
                # without nonlocal, assigning skip_files below would only
                # create a local and never update the outer flag
                nonlocal skip_files

                log.debug(f"{tool_call=}")

                tool_function_name = tool_call.get("name", None)
                if tool_function_name not in tools:
                    return body, {}

                tool_function_params = tool_call.get("parameters", {})

                try:
                    required_params = (
                        tools[tool_function_name]
                        .get("spec", {})
                        .get("parameters", {})
                        .get("required", [])
                    )
                    tool_function = tools[tool_function_name]["callable"]
                    tool_function_params = {
                        k: v
                        for k, v in tool_function_params.items()
                        if k in required_params
                    }
                    tool_output = await tool_function(**tool_function_params)
                except Exception as e:
                    tool_output = str(e)

                if isinstance(tool_output, str):
                    if tools[tool_function_name]["citation"]:
                        sources.append(
                            {
                                "source": {
                                    "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                                },
                                "document": [tool_output],
                                "metadata": [
                                    {
                                        "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                                    }
                                ],
                            }
                        )
                    else:
                        sources.append(
                            {
                                "source": {},
                                "document": [tool_output],
                                "metadata": [
                                    {
                                        "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                                    }
                                ],
                            }
                        )

                    if tools[tool_function_name]["file_handler"]:
                        skip_files = True

            # check if "tool_calls" is in the result
            if result.get("tool_calls"):
                for tool_call in result.get("tool_calls"):
                    await tool_call_handler(tool_call)
            else:
                await tool_call_handler(result)

        except Exception as e:
            log.exception(f"Error: {e}")
            content = None
    except Exception as e:
        log.exception(f"Error: {e}")
        content = None

    log.debug(f"tool_contexts: {sources}")

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {"sources": sources}


async def chat_web_search_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    event_emitter = extra_params["__event_emitter__"]
    await event_emitter(
        {
            "type": "status",
            "data": {
                "action": "web_search",
                "description": "Generating search query",
                "done": False,
            },
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    queries = []
    try:
        res = await generate_queries(
            request,
            {
                "model": form_data["model"],
                "messages": messages,
                "prompt": user_message,
                "type": "web_search",
            },
            user,
        )

        response = res["choices"][0]["message"]["content"]

        try:
            bracket_start = response.find("{")
            bracket_end = response.rfind("}") + 1

            # rfind returns -1 when "}" is missing, so bracket_end is 0 then
            if bracket_start == -1 or bracket_end == 0:
                raise Exception("No JSON object found in the response")

            response = response[bracket_start:bracket_end]
            queries = json.loads(response)
            queries = queries.get("queries", [])
        except Exception as e:
            queries = [response]

    except Exception as e:
        log.exception(e)
        queries = [user_message]

    if len(queries) == 0:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": "No search query generated",
                    "done": True,
                },
            }
        )
        # return the form data unchanged; the caller assigns the result
        return form_data

    searchQuery = queries[0]

    await event_emitter(
        {
            "type": "status",
            "data": {
                "action": "web_search",
                "description": 'Searching "{{searchQuery}}"',
                "query": searchQuery,
                "done": False,
            },
        }
    )

    try:
        # Offload process_web_search to a separate thread
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            results = await loop.run_in_executor(
                executor,
                lambda: process_web_search(
                    request,
                    SearchForm(
                        **{
                            "query": searchQuery,
                        }
                    ),
                    user,
                ),
            )

        if results:
            await event_emitter(
                {
                    "type": "status",
                    "data": {
                        "action": "web_search",
                        "description": "Searched {{count}} sites",
                        "query": searchQuery,
                        "urls": results["filenames"],
                        "done": True,
                    },
                }
            )

            files = form_data.get("files", [])
            files.append(
                {
                    "collection_name": results["collection_name"],
                    "name": searchQuery,
                    "type": "web_search_results",
                    "urls": results["filenames"],
                }
            )
            form_data["files"] = files
        else:
            await event_emitter(
                {
                    "type": "status",
                    "data": {
                        "action": "web_search",
                        "description": "No search results found",
                        "query": searchQuery,
                        "done": True,
                        "error": True,
                    },
                }
            )
    except Exception as e:
        log.exception(e)
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": 'Error searching "{{searchQuery}}"',
                    "query": searchQuery,
                    "done": True,
                    "error": True,
                },
            }
        )

    return form_data


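# Image generation feature: optionally rewrites the last user message into an
# image prompt via the task model, calls the configured image backend, emits
# each generated image as a chat message, and injects a system message so the
# model can acknowledge the result (or the failure) in its reply.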
async def chat_image_generation_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    __event_emitter__ = extra_params["__event_emitter__"]
    await __event_emitter__(
        {
            "type": "status",
            "data": {"description": "Generating an image", "done": False},
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    prompt = user_message
    negative_prompt = ""

    if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
        try:
            res = await generate_image_prompt(
                request,
                {
                    "model": form_data["model"],
                    "messages": messages,
                },
                user,
            )

            response = res["choices"][0]["message"]["content"]

            try:
                bracket_start = response.find("{")
                bracket_end = response.rfind("}") + 1

                if bracket_start == -1 or bracket_end == 0:
                    raise Exception("No JSON object found in the response")

                response = response[bracket_start:bracket_end]
                response = json.loads(response)
                prompt = response.get("prompt", [])
            except Exception as e:
                prompt = user_message
        except Exception as e:
            log.exception(e)
            prompt = user_message

    system_message_content = ""

    try:
        images = await image_generations(
            request=request,
            form_data=GenerateImageForm(**{"prompt": prompt}),
            user=user,
        )

        await __event_emitter__(
            {
                "type": "status",
                "data": {"description": "Generated an image", "done": True},
            }
        )

        for image in images:
            await __event_emitter__(
                {
                    "type": "message",
                    # markdown image reference for the client to render
                    "data": {"content": f"![Generated Image]({image['url']})\n"},
                }
            )

        system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
    except Exception as e:
        log.exception(e)
        await __event_emitter__(
            {
                "type": "status",
                "data": {
                    "description": "An error occurred while generating an image",
                    "done": True,
                },
            }
        )

        system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"

    if system_message_content:
        form_data["messages"] = add_or_update_system_message(
            system_message_content, form_data["messages"]
        )

    return form_data


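# Retrieval (RAG) step: derives retrieval queries from the conversation, then
# searches the attached files/collections in a worker thread and returns the
# matching chunks as "sources" alongside the untouched body.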
async def chat_completion_files_handler(
    request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    sources = []

    if files := body.get("metadata", {}).get("files", None):
        try:
            queries_response = await generate_queries(
                request,
                {
                    "model": body["model"],
                    "messages": body["messages"],
                    "type": "retrieval",
                },
                user,
            )
            queries_response = queries_response["choices"][0]["message"]["content"]

            try:
                bracket_start = queries_response.find("{")
                bracket_end = queries_response.rfind("}") + 1

                if bracket_start == -1 or bracket_end == 0:
                    raise Exception("No JSON object found in the response")

                queries_response = queries_response[bracket_start:bracket_end]
                queries_response = json.loads(queries_response)
            except Exception as e:
                queries_response = {"queries": [queries_response]}

            queries = queries_response.get("queries", [])
        except Exception as e:
            queries = []

        if len(queries) == 0:
            queries = [get_last_user_message(body["messages"])]

        try:
            # Offload get_sources_from_files to a separate thread
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as executor:
                sources = await loop.run_in_executor(
                    executor,
                    lambda: get_sources_from_files(
                        files=files,
                        queries=queries,
                        embedding_function=request.app.state.EMBEDDING_FUNCTION,
                        k=request.app.state.config.TOP_K,
                        reranking_function=request.app.state.rf,
                        r=request.app.state.config.RELEVANCE_THRESHOLD,
                        hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                    ),
                )
        except Exception as e:
            log.exception(e)

    log.debug(f"rag_contexts:sources: {sources}")
    return body, {"sources": sources}


def apply_params_to_form_data(form_data, model):
    params = form_data.pop("params", {})
    if model.get("ollama"):
        form_data["options"] = params

        if "format" in params:
            form_data["format"] = params["format"]

        if "keep_alive" in params:
            form_data["keep_alive"] = params["keep_alive"]
    else:
        if "seed" in params:
            form_data["seed"] = params["seed"]

        if "stop" in params:
            form_data["stop"] = params["stop"]

        if "temperature" in params:
            form_data["temperature"] = params["temperature"]

        if "max_tokens" in params:
            form_data["max_tokens"] = params["max_tokens"]

        if "top_p" in params:
            form_data["top_p"] = params["top_p"]

        if "frequency_penalty" in params:
            form_data["frequency_penalty"] = params["frequency_penalty"]

        if "reasoning_effort" in params:
            form_data["reasoning_effort"] = params["reasoning_effort"]

    return form_data


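# Pre-processing pipeline for a chat completion request, applied in order:
# model params, model knowledge (attached as files), the web-search and
# image-generation features, filter inlets, tool calling, and retrieval.
# Collected sources are folded into the prompt via the RAG template, and any
# citation events are returned for the response phase to emit.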
async def process_chat_payload(request, form_data, metadata, user, model):
    form_data = apply_params_to_form_data(form_data, model)
    log.debug(f"form_data: {form_data}")

    event_emitter = get_event_emitter(metadata)
    event_call = get_event_call(metadata)

    extra_params = {
        "__event_emitter__": event_emitter,
        "__event_call__": event_call,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
    }

    # Initialize events to store additional events to be sent to the client
    # Initialize contexts and citations
    models = request.app.state.MODELS

    events = []
    sources = []

    user_message = get_last_user_message(form_data["messages"])
    model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": False,
                },
            }
        )

        knowledge_files = []
        for item in model_knowledge:
            if item.get("collection_name"):
                knowledge_files.append(
                    {
                        "id": item.get("collection_name"),
                        "name": item.get("name"),
                        "legacy": True,
                    }
                )
            elif item.get("collection_names"):
                knowledge_files.append(
                    {
                        "name": item.get("name"),
                        "type": "collection",
                        "collection_names": item.get("collection_names"),
                        "legacy": True,
                    }
                )
            else:
                knowledge_files.append(item)

        files = form_data.get("files", [])
        files.extend(knowledge_files)
        form_data["files"] = files

    variables = form_data.pop("variables", None)
    features = form_data.pop("features", None)

    if features:
        if "web_search" in features and features["web_search"]:
            form_data = await chat_web_search_handler(
                request, form_data, extra_params, user
            )

        if "image_generation" in features and features["image_generation"]:
            form_data = await chat_image_generation_handler(
                request, form_data, extra_params, user
            )

    try:
        form_data, flags = await chat_completion_filter_functions_handler(
            request, form_data, model, extra_params
        )
    except Exception as e:
        raise Exception(f"Error: {e}")

    tool_ids = form_data.pop("tool_ids", None)
    files = form_data.pop("files", None)

    # Remove duplicate files
    if files:
        files = list({json.dumps(f, sort_keys=True): f for f in files}.values())

    metadata = {
        **metadata,
        "tool_ids": tool_ids,
        "files": files,
    }
    form_data["metadata"] = metadata

    try:
        form_data, flags = await chat_completion_tools_handler(
            request, form_data, user, models, extra_params
        )
        sources.extend(flags.get("sources", []))
    except Exception as e:
        log.exception(e)

    try:
        form_data, flags = await chat_completion_files_handler(request, form_data, user)
        sources.extend(flags.get("sources", []))
    except Exception as e:
        log.exception(e)

    # If the context is not empty, insert it into the messages
    if len(sources) > 0:
        context_string = ""
        for source_idx, source in enumerate(sources):
            source_id = source.get("source", {}).get("name", "")

            if "document" in source:
                for doc_idx, doc_context in enumerate(source["document"]):
                    metadata = source.get("metadata")
                    doc_source_id = None

                    if metadata:
                        doc_source_id = metadata[doc_idx].get("source", source_id)

                    if source_id:
                        context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
                    else:
                        # If there is no source_id, then do not include the source_id tag
                        context_string += f"<source><source_context>{doc_context}</source_context></source>\n"

        context_string = context_string.strip()
        prompt = get_last_user_message(form_data["messages"])

        if prompt is None:
            raise Exception("No user message found")
        if (
            request.app.state.config.RELEVANCE_THRESHOLD == 0
            and context_string.strip() == ""
        ):
            log.debug(
                "With a 0 relevancy threshold for RAG, the context cannot be empty"
            )

        # Workaround for Ollama 2.0+ system prompt issue
        # TODO: replace with add_or_update_system_message
        if model["owned_by"] == "ollama":
            form_data["messages"] = prepend_to_first_user_message_content(
                rag_template(
                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
                ),
                form_data["messages"],
            )
        else:
            form_data["messages"] = add_or_update_system_message(
                rag_template(
                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
                ),
                form_data["messages"],
            )

    # If there are citations, add them to the data_items
    sources = [source for source in sources if source.get("source", {}).get("name", "")]
    if len(sources) > 0:
        events.append({"sources": sources})

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": True,
                    "hidden": True,
                },
            }
        )

    return form_data, events


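# Post-processing for a completion response. Non-streaming responses are
# emitted and persisted in one shot; SSE/NDJSON streams are consumed in a
# detached task that accumulates deltas, renders reasoning tags, saves the
# message, notifies inactive users via webhook, and finally triggers title
# and tag generation.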
async def process_chat_response(
    request, response, form_data, user, events, metadata, tasks
):
    async def background_tasks_handler():
        message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
        message = message_map.get(metadata["message_id"]) if message_map else None

        if message:
            messages = get_message_list(message_map, message.get("id"))

            if tasks:
                if TASKS.TITLE_GENERATION in tasks:
                    if tasks[TASKS.TITLE_GENERATION]:
                        res = await generate_title(
                            request,
                            {
                                "model": message["model"],
                                "messages": messages,
                                "chat_id": metadata["chat_id"],
                            },
                            user,
                        )

                        if res and isinstance(res, dict):
                            if len(res.get("choices", [])) == 1:
                                title_string = (
                                    res.get("choices", [])[0]
                                    .get("message", {})
                                    .get("content", message.get("content", "New Chat"))
                                )
                            else:
                                title_string = ""

                            title_string = title_string[
                                title_string.find("{") : title_string.rfind("}") + 1
                            ]

                            try:
                                title = json.loads(title_string).get(
                                    "title", "New Chat"
                                )
                            except Exception as e:
                                title = ""

                            if not title:
                                title = messages[0].get("content", "New Chat")

                            Chats.update_chat_title_by_id(metadata["chat_id"], title)

                            await event_emitter(
                                {
                                    "type": "chat:title",
                                    "data": title,
                                }
                            )
                    elif len(messages) == 2:
                        title = messages[0].get("content", "New Chat")

                        Chats.update_chat_title_by_id(metadata["chat_id"], title)

                        await event_emitter(
                            {
                                "type": "chat:title",
                                # emit the fallback title just stored above,
                                # not the raw assistant message content
                                "data": title,
                            }
                        )

                if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
                    res = await generate_chat_tags(
                        request,
                        {
                            "model": message["model"],
                            "messages": messages,
                            "chat_id": metadata["chat_id"],
                        },
                        user,
                    )

                    if res and isinstance(res, dict):
                        if len(res.get("choices", [])) == 1:
                            tags_string = (
                                res.get("choices", [])[0]
                                .get("message", {})
                                .get("content", "")
                            )
                        else:
                            tags_string = ""

                        tags_string = tags_string[
                            tags_string.find("{") : tags_string.rfind("}") + 1
                        ]

                        try:
                            tags = json.loads(tags_string).get("tags", [])
                            Chats.update_chat_tags_by_id(
                                metadata["chat_id"], tags, user
                            )

                            await event_emitter(
                                {
                                    "type": "chat:tags",
                                    "data": tags,
                                }
                            )
                        except Exception as e:
                            pass

    event_emitter = None
    if (
        "session_id" in metadata
        and metadata["session_id"]
        and "chat_id" in metadata
        and metadata["chat_id"]
        and "message_id" in metadata
        and metadata["message_id"]
    ):
        event_emitter = get_event_emitter(metadata)

    if not isinstance(response, StreamingResponse):
        if event_emitter:
            if "selected_model_id" in response:
                Chats.upsert_message_to_chat_by_id_and_message_id(
                    metadata["chat_id"],
                    metadata["message_id"],
                    {
                        "selectedModelId": response["selected_model_id"],
                    },
                )

            if response.get("choices", [])[0].get("message", {}).get("content"):
                content = response["choices"][0]["message"]["content"]

                if content:
                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": response,
                        }
                    )

                    title = Chats.get_chat_title_by_id(metadata["chat_id"])

                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": {
                                "done": True,
                                "content": content,
                                "title": title,
                            },
                        }
                    )

                    # Save message in the database
                    Chats.upsert_message_to_chat_by_id_and_message_id(
                        metadata["chat_id"],
                        metadata["message_id"],
                        {
                            "content": content,
                        },
                    )

                    # Send a webhook notification if the user is not active
                    if get_active_status_by_user_id(user.id) is None:
                        webhook_url = Users.get_user_webhook_url_by_id(user.id)
                        if webhook_url:
                            post_webhook(
                                webhook_url,
                                f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
                                {
                                    "action": "chat",
                                    "message": content,
                                    "title": title,
                                    "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
                                },
                            )

                    await background_tasks_handler()

            return response
        else:
            return response

    if not any(
        content_type in response.headers["Content-Type"]
        for content_type in ["text/event-stream", "application/x-ndjson"]
    ):
        return response

    if event_emitter:
        task_id = str(uuid4())  # Create a unique task ID.

        # Handle as a background task
        async def post_response_handler(response, events):
            message = Chats.get_message_by_id_and_message_id(
                metadata["chat_id"], metadata["message_id"]
            )
            content = message.get("content", "") if message else ""

            try:
                for event in events:
                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": event,
                        }
                    )

                    # Save message in the database
                    Chats.upsert_message_to_chat_by_id_and_message_id(
                        metadata["chat_id"],
                        metadata["message_id"],
                        {
                            **event,
                        },
                    )

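                # Reasoning-tag handling: some models wrap their
                # chain-of-thought in tags such as <think>…</think>. The loop
                # below strips the opening tag, buffers the tagged text
                # separately, and re-renders it as a collapsible
                # <details type="reasoning"> block (quoted line by line, with
                # the measured thinking duration) once the closing tag arrives.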
                # We might want to disable this by default
                detect_reasoning = True
                reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
                current_tag = None

                reasoning_start_time = None
                reasoning_content = ""
                ongoing_content = ""

                async for line in response.body_iterator:
                    line = line.decode("utf-8") if isinstance(line, bytes) else line
                    data = line

                    # Skip empty lines
                    if not data.strip():
                        continue

                    # "data:" is the prefix for each event
                    if not data.startswith("data:"):
                        continue

                    # Remove the prefix
                    data = data[len("data:") :].strip()

                    try:
                        data = json.loads(data)

                        if "selected_model_id" in data:
                            Chats.upsert_message_to_chat_by_id_and_message_id(
                                metadata["chat_id"],
                                metadata["message_id"],
                                {
                                    "selectedModelId": data["selected_model_id"],
                                },
                            )
                        else:
                            value = (
                                data.get("choices", [])[0]
                                .get("delta", {})
                                .get("content")
                            )

                            if value:
                                content = f"{content}{value}"

                                if detect_reasoning:
                                    for tag in reasoning_tags:
                                        start_tag = f"<{tag}>\n"
                                        end_tag = f"</{tag}>\n"

                                        if start_tag in content:
                                            # Remove the start tag
                                            content = content.replace(start_tag, "")
                                            ongoing_content = content

                                            reasoning_start_time = time.time()
                                            reasoning_content = ""

                                            current_tag = tag
                                            break

                                    if reasoning_start_time is not None:
                                        # Remove the last value from the content
                                        content = content[: -len(value)]

                                        reasoning_content += value

                                        end_tag = f"</{current_tag}>\n"
                                        if end_tag in reasoning_content:
                                            reasoning_end_time = time.time()
                                            reasoning_duration = int(
                                                reasoning_end_time
                                                - reasoning_start_time
                                            )
                                            reasoning_content = (
                                                reasoning_content.strip(
                                                    f"<{current_tag}>\n"
                                                )
                                                .strip(end_tag)
                                                .strip()
                                            )

                                            if reasoning_content:
                                                reasoning_display_content = "\n".join(
                                                    (
                                                        f"> {line}"
                                                        if not line.startswith(">")
                                                        else line
                                                    )
                                                    for line in reasoning_content.splitlines()
                                                )

                                                # Format reasoning with <details> tag
                                                content = f'{ongoing_content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
                                            else:
                                                content = ""

                                            reasoning_start_time = None
                                        else:
                                            reasoning_display_content = "\n".join(
                                                (
                                                    f"> {line}"
                                                    if not line.startswith(">")
                                                    else line
                                                )
                                                for line in reasoning_content.splitlines()
                                            )

                                            # Show ongoing thought process
                                            content = f'{ongoing_content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'

                                if ENABLE_REALTIME_CHAT_SAVE:
                                    # Save message in the database
                                    Chats.upsert_message_to_chat_by_id_and_message_id(
                                        metadata["chat_id"],
                                        metadata["message_id"],
                                        {
                                            "content": content,
                                        },
                                    )
                                else:
                                    data = {
                                        "content": content,
                                    }

                        await event_emitter(
                            {
                                "type": "chat:completion",
                                "data": data,
                            }
                        )
                    except Exception as e:
                        done = "data: [DONE]" in line
                        if done:
                            pass
                        else:
                            continue

                title = Chats.get_chat_title_by_id(metadata["chat_id"])
                data = {"done": True, "content": content, "title": title}

                if not ENABLE_REALTIME_CHAT_SAVE:
                    # Save message in the database
                    Chats.upsert_message_to_chat_by_id_and_message_id(
                        metadata["chat_id"],
                        metadata["message_id"],
                        {
                            "content": content,
                        },
                    )

                # Send a webhook notification if the user is not active
                if get_active_status_by_user_id(user.id) is None:
                    webhook_url = Users.get_user_webhook_url_by_id(user.id)
                    if webhook_url:
                        post_webhook(
                            webhook_url,
                            f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
                            {
                                "action": "chat",
                                "message": content,
                                "title": title,
                                "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
                            },
                        )

                await event_emitter(
                    {
                        "type": "chat:completion",
                        "data": data,
                    }
                )

                await background_tasks_handler()
            except asyncio.CancelledError:
                log.warning("Task was cancelled!")
                await event_emitter({"type": "task-cancelled"})

                if not ENABLE_REALTIME_CHAT_SAVE:
                    # Save message in the database
                    Chats.upsert_message_to_chat_by_id_and_message_id(
                        metadata["chat_id"],
                        metadata["message_id"],
                        {
                            "content": content,
                        },
                    )

            if response.background is not None:
                await response.background()

        # background_tasks.add_task(post_response_handler, response, events)
        task_id, _ = create_task(post_response_handler(response, events))
        return {"status": True, "task_id": task_id}

    else:
        # Fallback to the original response
        async def stream_wrapper(original_generator, events):
            def wrap_item(item):
                return f"data: {item}\n\n"

            for event in events:
                yield wrap_item(json.dumps(event))

            async for data in original_generator:
                yield data

        return StreamingResponse(
            stream_wrapper(response.body_iterator, events),
            headers=dict(response.headers),
            background=response.background,
        )