@@ -51,13 +51,13 @@ from apps.webui.internal.db import Session


 from pydantic import BaseModel
-from typing import Optional
+from typing import Optional, Callable, Awaitable

 from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
 from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
-from apps.webui.models.users import Users
+from apps.webui.models.users import Users, UserModel

 from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id

@@ -72,7 +72,7 @@ from utils.utils import (
 from utils.task import (
     title_generation_template,
     search_query_generation_template,
-    tools_function_calling_generation_template,
+    tool_calling_generation_template,
 )
 from utils.misc import (
     get_last_user_message,
@@ -261,6 +261,7 @@ def get_filter_function_ids(model):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0

@@ -282,164 +283,42 @@ def get_filter_function_ids(model):
     return filter_ids


-async def get_function_call_response(
-    messages,
-    files,
-    tool_id,
-    template,
-    task_model_id,
-    user,
-    __event_emitter__=None,
-    __event_call__=None,
-):
-    tool = Tools.get_tool_by_id(tool_id)
-    tools_specs = json.dumps(tool.specs, indent=2)
-    content = tools_function_calling_generation_template(template, tools_specs)
+async def get_content_from_response(response) -> Optional[str]:
+    content = None
+    if hasattr(response, "body_iterator"):
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
+
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
+    else:
+        content = response["choices"][0]["message"]["content"]
+    return content
+
+
+def get_tool_call_payload(messages, task_model_id, content):
     user_message = get_last_user_message(messages)
-    prompt = (
-        "History:\n"
-        + "\n".join(
-            [
-                f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
-                for message in messages[::-1][:4]
-            ]
-        )
-        + f"\nQuery: {user_message}"
+    history = "\n".join(
+        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
+        for message in messages[::-1][:4]
     )

-    print(prompt)
+    prompt = f"History:\n{history}\nQuery: {user_message}"

-    payload = {
+    return {
         "model": task_model_id,
         "messages": [
             {"role": "system", "content": content},
             {"role": "user", "content": f"Query: {prompt}"},
         ],
         "stream": False,
-        "task": str(TASKS.FUNCTION_CALLING),
+        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
     }

-    try:
-        payload = filter_pipeline(payload, user)
-    except Exception as e:
-        raise e
-
-    model = app.state.MODELS[task_model_id]
-
-    response = None
-    try:
-        response = await generate_chat_completions(form_data=payload, user=user)
-        content = None
-
-        if hasattr(response, "body_iterator"):
-            async for chunk in response.body_iterator:
-                data = json.loads(chunk.decode("utf-8"))
-                content = data["choices"][0]["message"]["content"]
-
-            # Cleanup any remaining background tasks if necessary
-            if response.background is not None:
-                await response.background()
-        else:
-            content = response["choices"][0]["message"]["content"]
-
-        if content is None:
-            return None, None, False
-
-        # Parse the function response
-        print(f"content: {content}")
-        result = json.loads(content)
-        print(result)
-
-        citation = None
-
-        if "name" not in result:
-            return None, None, False
-
-        # Call the function
-        if tool_id in webui_app.state.TOOLS:
-            toolkit_module = webui_app.state.TOOLS[tool_id]
-        else:
-            toolkit_module, _ = load_toolkit_module_by_id(tool_id)
-            webui_app.state.TOOLS[tool_id] = toolkit_module
-
-        file_handler = False
-        # check if toolkit_module has file_handler self variable
-        if hasattr(toolkit_module, "file_handler"):
-            file_handler = True
-            print("file_handler: ", file_handler)
-
-        if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
-            valves = Tools.get_tool_valves_by_id(tool_id)
-            toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
-
-        function = getattr(toolkit_module, result["name"])
-        function_result = None
-        try:
-            # Get the signature of the function
-            sig = inspect.signature(function)
-            params = result["parameters"]
-
-            # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": tool_id,
-                "__messages__": messages,
-                "__files__": files,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                # Call the function with the '__user__' parameter included
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
-                try:
-                    if hasattr(toolkit_module, "UserValves"):
-                        __user__["valves"] = toolkit_module.UserValves(
-                            **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
-                        )
-                except Exception as e:
-                    print(e)
-
-                params = {**params, "__user__": __user__}
-
-            if inspect.iscoroutinefunction(function):
-                function_result = await function(**params)
-            else:
-                function_result = function(**params)
-
-            if hasattr(toolkit_module, "citation") and toolkit_module.citation:
-                citation = {
-                    "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
-                    "document": [function_result],
-                    "metadata": [{"source": result["name"]}],
-                }
-        except Exception as e:
-            print(e)
-
-        # Add the function result to the system prompt
-        if function_result is not None:
-            return function_result, citation, file_handler
-    except Exception as e:
-        print(f"Error: {e}")
-
-    return None, None, False

-
-async def chat_completion_functions_handler(
-    body, model, user, __event_emitter__, __event_call__
-):
+async def chat_completion_inlets_handler(body, model, extra_params):
     skip_files = None

     filter_ids = get_filter_function_ids(model)
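[Note] The deleted monolith above is split in two: get_content_from_response unwraps the assistant message from either a streaming response or a plain completion dict, and get_tool_call_payload builds the function-calling prompt. A standalone sketch of the prompt the new helper assembles (sample messages are hypothetical, and messages[-1] stands in for get_last_user_message):

    messages = [
        {"role": "user", "content": "What's the weather in Berlin?"},
        {"role": "assistant", "content": "Let me check that for you."},
        {"role": "user", "content": "Use metric units, please."},
    ]
    user_message = messages[-1]["content"]
    # Newest-first history, capped at the last four messages, as in the diff
    history = "\n".join(
        f"{m['role'].upper()}: \"\"\"{m['content']}\"\"\""
        for m in messages[::-1][:4]
    )
    prompt = f"History:\n{history}\nQuery: {user_message}"
    print(prompt)
    # History:
    # USER: """Use metric units, please."""
    # ASSISTANT: """Let me check that for you."""
    # USER: """What's the weather in Berlin?"""
    # Query: Use metric units, please.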
@@ -475,37 +354,20 @@ async def chat_completion_functions_handler
                 params = {"body": body}

                 # Extra parameters to be passed to the function
-                extra_params = {
-                    "__model__": model,
-                    "__id__": filter_id,
-                    "__event_emitter__": __event_emitter__,
-                    "__event_call__": __event_call__,
-                }
-
-                # Add extra params in contained in function signature
-                for key, value in extra_params.items():
-                    if key in sig.parameters:
-                        params[key] = value
-
-                if "__user__" in sig.parameters:
-                    __user__ = {
-                        "id": user.id,
-                        "email": user.email,
-                        "name": user.name,
-                        "role": user.role,
-                    }
-
+                custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
+                if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
                     try:
-                        if hasattr(function_module, "UserValves"):
-                            __user__["valves"] = function_module.UserValves(
-                                **Functions.get_user_valves_by_id_and_user_id(
-                                    filter_id, user.id
-                                )
-                            )
+                        uid = custom_params["__user__"]["id"]
+                        custom_params["__user__"]["valves"] = function_module.UserValves(
+                            **Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
+                        )
                     except Exception as e:
                         print(e)

-                    params = {**params, "__user__": __user__}
+                # Add extra params contained in the function signature
+                for key, value in custom_params.items():
+                    if key in sig.parameters:
+                        params[key] = value

                 if inspect.iscoroutinefunction(inlet):
                     body = await inlet(**params)
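[Note] The inlets handler now receives one shared extra_params dict and forwards only the entries a filter's inlet signature actually names. A standalone sketch of that selection step (the example filter and values are hypothetical):

    import inspect

    def inlet(body: dict, __user__: dict) -> dict:
        # Example filter: tag the request with the calling user's role
        body["user_role"] = __user__["role"]
        return body

    custom_params = {
        "__user__": {"id": "u1", "role": "admin"},
        "__model__": {"id": "m1"},  # not in inlet's signature, so dropped
        "__event_emitter__": None,  # likewise dropped
    }

    params = {"body": {"messages": []}}
    sig = inspect.signature(inlet)
    for key, value in custom_params.items():
        if key in sig.parameters:
            params[key] = value

    print(inlet(**params))  # {'messages': [], 'user_role': 'admin'}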
@@ -516,74 +378,171 @@
             print(f"Error: {e}")
             raise e

-    if skip_files:
-        if "files" in body:
-            del body["files"]
+    if skip_files and "files" in body:
+        del body["files"]

     return body, {}


-async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
-    skip_files = None
+def get_tool_with_custom_params(
+    tool: Callable, custom_params: dict
+) -> Callable[..., Awaitable]:
+    sig = inspect.signature(tool)
+    extra_params = {
+        key: value for key, value in custom_params.items() if key in sig.parameters
+    }
+    is_coroutine = inspect.iscoroutinefunction(tool)

-    contexts = []
-    citations = None
+    async def new_tool(**kwargs):
+        extra_kwargs = kwargs | extra_params
+        if is_coroutine:
+            return await tool(**extra_kwargs)
+        return tool(**extra_kwargs)
+
+    return new_tool
+
+
+# Mutation on extra_params
+def get_configured_tools(
+    tool_ids: list[str], extra_params: dict, user: UserModel
+) -> dict[str, dict]:
+    tools = {}
+    for tool_id in tool_ids:
+        toolkit = Tools.get_tool_by_id(tool_id)
+        if toolkit is None:
+            continue
+
+        module = webui_app.state.TOOLS.get(tool_id, None)
+        if module is None:
+            module, _ = load_toolkit_module_by_id(tool_id)
+            webui_app.state.TOOLS[tool_id] = module
+
+        extra_params["__id__"] = tool_id
+        has_citation = hasattr(module, "citation") and module.citation
+        handles_files = hasattr(module, "file_handler") and module.file_handler
+        if hasattr(module, "valves") and hasattr(module, "Valves"):
+            valves = Tools.get_tool_valves_by_id(tool_id) or {}
+            module.valves = module.Valves(**valves)
+
+        if hasattr(module, "UserValves"):
+            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+            )
+
+        for spec in toolkit.specs:
+            # TODO: Fix hack for OpenAI API
+            for val in spec.get("parameters", {}).get("properties", {}).values():
+                if val["type"] == "str":
+                    val["type"] = "string"
+            name = spec["name"]
+            callable = getattr(module, name)
+
+            # convert to function that takes only model params and inserts custom params
+            custom_callable = get_tool_with_custom_params(callable, extra_params)
+
+            # TODO: This needs to be a pydantic model
+            tool_dict = {
+                "spec": spec,
+                "citation": has_citation,
+                "file_handler": handles_files,
+                "toolkit_id": tool_id,
+                "callable": custom_callable,
+            }
+            # TODO: if collision, prepend toolkit name
+            if name in tools:
+                log.warning(f"Tool {name} already exists in another toolkit!")
+                log.warning(f"Collision between {toolkit} and {tool_id}.")
+                log.warning(f"Discarding {toolkit}.{name}")
+            else:
+                tools[name] = tool_dict
+
+    return tools
+
+
+async def chat_completion_tools_handler(
+    body: dict, user: UserModel, extra_params: dict
+) -> tuple[dict, dict]:
+    skip_files = False
+    contexts = []
+    citations = []
     task_model_id = get_task_model_id(body["model"])

     # If tool_ids field is present, call the functions
-    if "tool_ids" in body:
-        print(body["tool_ids"])
-        for tool_id in body["tool_ids"]:
-            print(tool_id)
-            try:
-                response, citation, file_handler = await get_function_call_response(
-                    messages=body["messages"],
-                    files=body.get("files", []),
-                    tool_id=tool_id,
-                    template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                    task_model_id=task_model_id,
-                    user=user,
-                    __event_emitter__=__event_emitter__,
-                    __event_call__=__event_call__,
-                )
+    tool_ids = body.pop("tool_ids", None)
+    if not tool_ids:
+        return body, {}
+
+    log.debug(f"{tool_ids=}")
+    custom_params = {
+        **extra_params,
+        "__model__": app.state.MODELS[task_model_id],
+        "__messages__": body["messages"],
+        "__files__": body.get("files", []),
+    }
+    configured_tools = get_configured_tools(tool_ids, custom_params, user)

-                print(file_handler)
-                if isinstance(response, str):
-                    contexts.append(response)
+    log.info(f"{configured_tools=}")

-                if citation:
-                    if citations is None:
-                        citations = [citation]
-                    else:
-                        citations.append(citation)
+    specs = [tool["spec"] for tool in configured_tools.values()]
+    tools_specs = json.dumps(specs)
+    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+    content = tool_calling_generation_template(template, tools_specs)
+    payload = get_tool_call_payload(body["messages"], task_model_id, content)
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        raise e

-                if file_handler:
-                    skip_files = True
+    try:
+        response = await generate_chat_completions(form_data=payload, user=user)
+        log.debug(f"{response=}")
+        content = await get_content_from_response(response)
+        log.debug(f"{content=}")
+        if content is None:
+            return body, {}

-            except Exception as e:
-                print(f"Error: {e}")
-        del body["tool_ids"]
-        print(f"tool_contexts: {contexts}")
+        result = json.loads(content)
+        tool_name = result.get("name", None)
+        if tool_name not in configured_tools:
+            return body, {}

-    if skip_files:
-        if "files" in body:
-            del body["files"]
+        tool_params = result.get("parameters", {})
+        toolkit_id = configured_tools[tool_name]["toolkit_id"]
+        try:
+            tool_output = await configured_tools[tool_name]["callable"](**tool_params)
+        except Exception as e:
+            tool_output = str(e)
+        if configured_tools[tool_name]["citation"]:
+            citations.append(
+                {
+                    "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
+                    "document": [tool_output],
+                    "metadata": [{"source": tool_name}],
+                }
+            )
+        if configured_tools[tool_name]["file_handler"]:
+            skip_files = True

-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
+        if isinstance(tool_output, str):
+            contexts.append(tool_output)

+    except Exception as e:
+        print(f"Error: {e}")
+        content = None

-async def chat_completion_files_handler(body):
-    contexts = []
-    citations = None
+    log.debug(f"tool_contexts: {contexts}")

-    if "files" in body:
-        files = body["files"]
+    if skip_files and "files" in body:
         del body["files"]

+    return body, {"contexts": contexts, "citations": citations}
+
+
+async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
+    contexts = []
+    citations = []
+
+    if files := body.pop("files", None):
         contexts, citations = get_rag_context(
             files=files,
             messages=body["messages"],
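[Note] get_tool_with_custom_params closes over the framework-supplied parameters, so the callable stored in the tool registry only needs the arguments named in the tool's spec; whether the underlying tool is sync or async, callers always await it. A standalone sketch of the wrapping (the weather tool is hypothetical; the wrapper body mirrors the diff):

    import asyncio
    import inspect

    def get_tool_with_custom_params(tool, custom_params):
        sig = inspect.signature(tool)
        extra_params = {k: v for k, v in custom_params.items() if k in sig.parameters}
        is_coroutine = inspect.iscoroutinefunction(tool)

        async def new_tool(**kwargs):
            # model-supplied kwargs plus the injected framework params
            extra_kwargs = kwargs | extra_params
            if is_coroutine:
                return await tool(**extra_kwargs)
            return tool(**extra_kwargs)

        return new_tool

    def get_weather(city: str, __user__: dict) -> str:
        return f"{city}: sunny (requested by {__user__['name']})"

    wrapped = get_tool_with_custom_params(get_weather, {"__user__": {"name": "Ada"}})
    print(asyncio.run(wrapped(city="Berlin")))  # Berlin: sunny (requested by Ada)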
@@ -596,134 +555,130 @@

     log.debug(f"rag_contexts: {contexts}, citations: {citations}")

-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
-
+    return body, {"contexts": contexts, "citations": citations}

-class ChatCompletionMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
-
-            try:
-                body, model, user = await get_body_and_model_and_user(request)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
-
-            metadata = {
-                "chat_id": body.pop("chat_id", None),
-                "message_id": body.pop("id", None),
-                "session_id": body.pop("session_id", None),
-                "valves": body.pop("valves", None),
-            }
-
-            __event_emitter__ = get_event_emitter(metadata)
-            __event_call__ = get_event_call(metadata)
+def is_chat_completion_request(request):
+    return request.method == "POST" and any(
+        endpoint in request.url.path
+        for endpoint in ["/ollama/api/chat", "/chat/completions"]
+    )

-            # Initialize data_items to store additional data to be sent to the client
-            data_items = []
-
-            # Initialize context, and citations
-            contexts = []
-            citations = []
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if not is_chat_completion_request(request):
+            return await call_next(request)
+        log.debug(f"request.url.path: {request.url.path}")

-            try:
-                body, flags = await chat_completion_functions_handler(
-                    body, model, user, __event_emitter__, __event_call__
-                )
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
+        try:
+            body, model, user = await get_body_and_model_and_user(request)
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )

-            try:
-                body, flags = await chat_completion_tools_handler(
-                    body, user, __event_emitter__, __event_call__
-                )
+        metadata = {
+            "chat_id": body.pop("chat_id", None),
+            "message_id": body.pop("id", None),
+            "session_id": body.pop("session_id", None),
+            "valves": body.pop("valves", None),
+        }

-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+        __user__ = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }

-            try:
-                body, flags = await chat_completion_files_handler(body)
+        extra_params = {
+            "__user__": __user__,
+            "__event_emitter__": get_event_emitter(metadata),
+            "__event_call__": get_event_call(metadata),
+        }

-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+        # Initialize data_items to store additional data to be sent to the client
+        # Initialize contexts and citations
+        data_items = []
+        contexts = []
+        citations = []

-            # If context is not empty, insert it into the messages
-            if len(contexts) > 0:
-                context_string = "/n".join(contexts).strip()
-                prompt = get_last_user_message(body["messages"])
-
-                # Workaround for Ollama 2.0+ system prompt issue
-                # TODO: replace with add_or_update_system_message
-                if model["owned_by"] == "ollama":
-                    body["messages"] = prepend_to_first_user_message_content(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-                else:
-                    body["messages"] = add_or_update_system_message(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
+        try:
+            body, flags = await chat_completion_inlets_handler(
+                body, model, extra_params
+            )
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )

-            # If there are citations, add them to the data_items
-            if len(citations) > 0:
-                data_items.append({"citations": citations})
-
-            body["metadata"] = metadata
-            modified_body_bytes = json.dumps(body).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
-
-            response = await call_next(request)
-            if isinstance(response, StreamingResponse):
-                # If it's a streaming response, inject it as SSE event or NDJSON line
-                content_type = response.headers.get("Content-Type")
-                if "text/event-stream" in content_type:
-                    return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, data_items),
-                    )
-                if "application/x-ndjson" in content_type:
-                    return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, data_items),
-                    )
+        try:
+            body, flags = await chat_completion_tools_handler(body, user, extra_params)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)

-            return response
+        try:
+            body, flags = await chat_completion_files_handler(body)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
+
+        # If context is not empty, insert it into the messages
+        if len(contexts) > 0:
+            context_string = "\n".join(contexts).strip()
+            prompt = get_last_user_message(body["messages"])
+            if prompt is None:
+                raise Exception("No user message found")
+            # Workaround for Ollama 2.0+ system prompt issue
+            # TODO: replace with add_or_update_system_message
+            if model["owned_by"] == "ollama":
+                body["messages"] = prepend_to_first_user_message_content(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
             else:
-                return response
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
+
+        # If there are citations, add them to the data_items
+        if len(citations) > 0:
+            data_items.append({"citations": citations})
+
+        body["metadata"] = metadata
+        modified_body_bytes = json.dumps(body).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]

-        # If it's not a chat completion request, just pass it through
         response = await call_next(request)
+        if isinstance(response, StreamingResponse):
+            # If it's a streaming response, inject it as SSE event or NDJSON line
+            content_type = response.headers["Content-Type"]
+            if "text/event-stream" in content_type:
+                return StreamingResponse(
+                    self.openai_stream_wrapper(response.body_iterator, data_items),
+                )
+            if "application/x-ndjson" in content_type:
+                return StreamingResponse(
+                    self.ollama_stream_wrapper(response.body_iterator, data_items),
+                )
+
         return response

     async def _receive(self, body: bytes):
@@ -790,19 +745,21 @@ def filter_pipeline(payload, user):
             url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
             key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/inlet",
-                    headers=headers,
-                    json={
-                        "user": user,
-                        "body": payload,
-                    },
-                )
+            if key == "":
+                continue
+
+            headers = {"Authorization": f"Bearer {key}"}
+            r = requests.post(
+                f"{url}/{filter['id']}/filter/inlet",
+                headers=headers,
+                json={
+                    "user": user,
+                    "body": payload,
+                },
+            )

-                r.raise_for_status()
-                payload = r.json()
+            r.raise_for_status()
+            payload = r.json()
         except Exception as e:
             # Handle connection error here
             print(f"Connection error: {e}")
@@ -817,44 +774,39 @@

 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and (
-            "/ollama/api/chat" in request.url.path
-            or "/chat/completions" in request.url.path
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
-
-            # Read the original request body
-            body = await request.body()
-            # Decode body to string
-            body_str = body.decode("utf-8")
-            # Parse string to JSON
-            data = json.loads(body_str) if body_str else {}
-
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
+        if not is_chat_completion_request(request):
+            return await call_next(request)

-            try:
-                data = filter_pipeline(data, user)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=e.args[0],
-                    content={"detail": e.args[1]},
-                )
+        log.debug(f"request.url.path: {request.url.path}")

-            modified_body_bytes = json.dumps(data).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
+        # Read the original request body
+        body = await request.body()
+        # Decode body to string
+        body_str = body.decode("utf-8")
+        # Parse string to JSON
+        data = json.loads(body_str) if body_str else {}
+
+        user = get_current_user(
+            request,
+            get_http_authorization_cred(request.headers["Authorization"]),
+        )
+
+        try:
+            data = filter_pipeline(data, user)
+        except Exception as e:
+            return JSONResponse(
+                status_code=e.args[0],
+                content={"detail": e.args[1]},
+            )
+
+        modified_body_bytes = json.dumps(data).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]

         response = await call_next(request)
         return response
@@ -1019,6 +971,8 @@ async def get_all_models():
             model["actions"] = []
         for action_id in action_ids:
             action = Functions.get_function_by_id(action_id)
+            if action is None:
+                raise Exception(f"Action not found: {action_id}")

             if action_id in webui_app.state.FUNCTIONS:
                 function_module = webui_app.state.FUNCTIONS[action_id]
@@ -1099,22 +1053,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         )
     model = app.state.MODELS[model_id]

-    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
-    task = None
-    if "task" in form_data:
-        task = form_data["task"]
-        del form_data["task"]
-
-    if task:
-        if "metadata" in form_data:
-            form_data["metadata"]["task"] = task
-        else:
-            form_data["metadata"] = {"task": task}
-
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
-        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)
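[Note] With every task endpoint now writing its task type under metadata at the call site, the pop-and-reattach shim removed above becomes dead code. The shape change, with a hypothetical model id (the exact task string is whatever str(TASKS.TITLE_GENERATION) yields):

    # Before: top-level field that generate_chat_completions had to strip
    # off and tuck into metadata itself
    payload = {"model": "my-task-model", "stream": False, "task": "title_generation"}

    # After: namespaced under metadata from the start and passed through untouched
    payload = {
        "model": "my-task-model",
        "stream": False,
        "metadata": {"task": "title_generation"},
    }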
@@ -1198,6 +1139,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel to include valves
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0

@@ -1487,7 +1429,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.TITLE_GENERATION),
+        "metadata": {"task": str(TASKS.TITLE_GENERATION)},
     }

     log.debug(payload)
@@ -1540,7 +1482,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "max_tokens": 30,
-        "task": str(TASKS.QUERY_GENERATION),
+        "metadata": {"task": str(TASKS.QUERY_GENERATION)},
     }

     print(payload)
@@ -1597,7 +1539,7 @@ Message: """{{prompt}}"""
         "stream": False,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.EMOJI_GENERATION),
+        "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
     }

     log.debug(payload)
@@ -1616,41 +1558,6 @@ Message: """{{prompt}}"""
     return await generate_chat_completions(form_data=payload, user=user)


-@app.post("/api/task/tools/completions")
-async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
-    print("get_tools_function_calling")
-
-    model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Model not found",
-        )
-
-    # Check if the user has a custom task model
-    # If the user has a custom task model, use that model
-    model_id = get_task_model_id(model_id)
-
-    print(model_id)
-    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
-
-    try:
-        context, _, _ = await get_function_call_response(
-            form_data["messages"],
-            form_data.get("files", []),
-            form_data["tool_id"],
-            template,
-            model_id,
-            user,
-        )
-        return context
-    except Exception as e:
-        return JSONResponse(
-            status_code=e.args[0],
-            content={"detail": e.args[1]},
-        )
-
-
 ##################################
 #
 # Pipelines Endpoints
@@ -1689,7 +1596,7 @@ async def upload_pipeline(
 ):
     print("upload_pipeline", urlIdx, file.filename)
     # Check if the uploaded file is a python file
-    if not file.filename.endswith(".py"):
+    if not (file.filename and file.filename.endswith(".py")):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Only Python (.py) files are allowed.",
@@ -2138,7 +2045,10 @@ async def oauth_login(provider: str, request: Request):
     redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
         "oauth_callback", provider=provider
     )
-    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
+    client = oauth.create_client(provider)
+    if client is None:
+        raise HTTPException(404)
+    return await client.authorize_redirect(request, redirect_uri)


 # OAuth login logic is as follows:
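[Note] Authlib's oauth.create_client(provider) returns None when no client was registered under that name, so the old one-liner would raise AttributeError (an HTTP 500) for unknown providers; the added guard turns that into a 404. The same pattern in isolation (the registry dict is a hypothetical stand-in for the OAuth instance):

    from fastapi import HTTPException

    registered_clients = {"google": object()}  # stand-in registry

    def create_client(provider: str):
        # mirrors authlib: unknown names come back as None
        return registered_clients.get(provider)

    def oauth_login(provider: str):
        client = create_client(provider)
        if client is None:
            raise HTTPException(404)
        return client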