@@ -14,6 +14,7 @@ import requests
 import mimetypes
 import shutil
 import inspect
+from typing import Optional

 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi.staticfiles import StaticFiles
@@ -51,15 +52,13 @@ from apps.webui.internal.db import Session


 from pydantic import BaseModel
-from typing import Optional, Callable, Awaitable

 from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
-from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
 from apps.webui.models.users import Users, UserModel

-from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
+from apps.webui.utils import load_function_module_by_id

 from utils.utils import (
     get_admin_user,
@@ -76,6 +75,8 @@ from utils.task import (
     tools_function_calling_generation_template,
     moa_response_generation_template,
 )
+
+from utils.tools import get_tools
 from utils.misc import (
     get_last_user_message,
     add_or_update_system_message,
@@ -325,8 +326,8 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
             print(f"Error: {e}")
             raise e

-    if skip_files and "files" in body:
-        del body["files"]
+    if skip_files and "files" in body.get("metadata", {}):
+        del body["metadata"]["files"]

     return body, {}

@@ -351,80 +352,6 @@ def get_tools_function_calling_payload(messages, task_model_id, content):
     }


-def apply_extra_params_to_tool_function(
-    function: Callable, extra_params: dict
-) -> Callable[..., Awaitable]:
-    sig = inspect.signature(function)
-    extra_params = {
-        key: value for key, value in extra_params.items() if key in sig.parameters
-    }
-    is_coroutine = inspect.iscoroutinefunction(function)
-
-    async def new_function(**kwargs):
-        extra_kwargs = kwargs | extra_params
-        if is_coroutine:
-            return await function(**extra_kwargs)
-        return function(**extra_kwargs)
-
-    return new_function
-
-
-# Mutation on extra_params
-def get_tools(
-    tool_ids: list[str], user: UserModel, extra_params: dict
-) -> dict[str, dict]:
-    tools = {}
-    for tool_id in tool_ids:
-        toolkit = Tools.get_tool_by_id(tool_id)
-        if toolkit is None:
-            continue
-
-        module = webui_app.state.TOOLS.get(tool_id, None)
-        if module is None:
-            module, _ = load_toolkit_module_by_id(tool_id)
-            webui_app.state.TOOLS[tool_id] = module
-
-        extra_params["__id__"] = tool_id
-        if hasattr(module, "valves") and hasattr(module, "Valves"):
-            valves = Tools.get_tool_valves_by_id(tool_id) or {}
-            module.valves = module.Valves(**valves)
-
-        if hasattr(module, "UserValves"):
-            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
-                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
-            )
-
-        for spec in toolkit.specs:
-            # TODO: Fix hack for OpenAI API
-            for val in spec.get("parameters", {}).get("properties", {}).values():
-                if val["type"] == "str":
-                    val["type"] = "string"
-            function_name = spec["name"]
-
-            # convert to function that takes only model params and inserts custom params
-            callable = apply_extra_params_to_tool_function(
-                getattr(module, function_name), extra_params
-            )
-
-            # TODO: This needs to be a pydantic model
-            tool_dict = {
-                "toolkit_id": tool_id,
-                "callable": callable,
-                "spec": spec,
-                "file_handler": hasattr(module, "file_handler") and module.file_handler,
-                "citation": hasattr(module, "citation") and module.citation,
-            }
-
-            # TODO: if collision, prepend toolkit name
-            if function_name in tools:
-                log.warning(f"Tool {function_name} already exists in another toolkit!")
-                log.warning(f"Collision between {toolkit} and {tool_id}.")
-                log.warning(f"Discarding {toolkit}.{function_name}")
-            else:
-                tools[function_name] = tool_dict
-    return tools
-
-
 async def get_content_from_response(response) -> Optional[str]:
     content = None
     if hasattr(response, "body_iterator"):
@@ -443,15 +370,17 @@
 async def chat_completion_tools_handler(
     body: dict, user: UserModel, extra_params: dict
 ) -> tuple[dict, dict]:
+    # If tool_ids field is present, call the functions
+    metadata = body.get("metadata", {})
+    tool_ids = metadata.get("tool_ids", None)
+    if not tool_ids:
+        return body, {}
+
     skip_files = False
     contexts = []
     citations = []

     task_model_id = get_task_model_id(body["model"])
-    # If tool_ids field is present, call the functions
-    tool_ids = body.pop("tool_ids", None)
-    if not tool_ids:
-        return body, {}

     log.debug(f"{tool_ids=}")

@@ -459,9 +388,9 @@
         **extra_params,
         "__model__": app.state.MODELS[task_model_id],
         "__messages__": body["messages"],
-        "__files__": body.get("files", []),
+        "__files__": metadata.get("files", []),
     }
-    tools = get_tools(tool_ids, user, custom_params)
+    tools = get_tools(webui_app, tool_ids, user, custom_params)
     log.info(f"{tools=}")

     specs = [tool["spec"] for tool in tools.values()]
@@ -486,7 +415,7 @@
     content = await get_content_from_response(response)
     log.debug(f"{content=}")

-    if content is None:
+    if not content:
         return body, {}

     result = json.loads(content)
@@ -521,13 +450,13 @@
                 contexts.append(tool_output)

         except Exception as e:
-            print(f"Error: {e}")
+            log.exception(f"Error: {e}")
             content = None

     log.debug(f"tool_contexts: {contexts}")

-    if skip_files and "files" in body:
-        del body["files"]
+    if skip_files and "files" in body.get("metadata", {}):
+        del body["metadata"]["files"]

     return body, {"contexts": contexts, "citations": citations}

@@ -536,7 +465,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
     contexts = []
     citations = []

-    if files := body.pop("files", None):
+    if files := body.get("metadata", {}).get("files", None):
         contexts, citations = get_rag_context(
             files=files,
             messages=body["messages"],
@@ -597,6 +526,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             "message_id": body.pop("id", None),
             "session_id": body.pop("session_id", None),
             "valves": body.pop("valves", None),
+            "tool_ids": body.pop("tool_ids", None),
+            "files": body.pop("files", None),
         }

         __user__ = {
@@ -680,36 +611,29 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         ]

         response = await call_next(request)
-        if isinstance(response, StreamingResponse):
-            # If it's a streaming response, inject it as SSE event or NDJSON line
-            content_type = response.headers["Content-Type"]
-            if "text/event-stream" in content_type:
-                return StreamingResponse(
-                    self.openai_stream_wrapper(response.body_iterator, data_items),
-                )
-            if "application/x-ndjson" in content_type:
-                return StreamingResponse(
-                    self.ollama_stream_wrapper(response.body_iterator, data_items),
-                )
+        if not isinstance(response, StreamingResponse):
+            return response

-        return response
+        content_type = response.headers["Content-Type"]
+        is_openai = "text/event-stream" in content_type
+        is_ollama = "application/x-ndjson" in content_type
+        if not is_openai and not is_ollama:
+            return response

-    async def _receive(self, body: bytes):
-        return {"type": "http.request", "body": body, "more_body": False}
+        def wrap_item(item):
+            return f"data: {item}\n\n" if is_openai else f"{item}\n"

-    async def openai_stream_wrapper(self, original_generator, data_items):
-        for item in data_items:
-            yield f"data: {json.dumps(item)}\n\n"
+        async def stream_wrapper(original_generator, data_items):
+            for item in data_items:
+                yield wrap_item(json.dumps(item))

-        async for data in original_generator:
-            yield data
+            async for data in original_generator:
+                yield data

-    async def ollama_stream_wrapper(self, original_generator, data_items):
-        for item in data_items:
-            yield f"{json.dumps(item)}\n"
+        return StreamingResponse(stream_wrapper(response.body_iterator, data_items))

-        async for data in original_generator:
-            yield data
+    async def _receive(self, body: bytes):
+        return {"type": "http.request", "body": body, "more_body": False}


 app.add_middleware(ChatCompletionMiddleware)
@@ -1065,7 +989,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             detail="Model not found",
         )
     model = app.state.MODELS[model_id]
-
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
|