@@ -41,8 +41,6 @@ from starlette.responses import Response, StreamingResponse
 from open_webui.socket.main import (
     app as socket_app,
     periodic_usage_pool_cleanup,
-    get_event_call,
-    get_event_emitter,
 )
 from open_webui.routers import (
     audio,
@@ -74,12 +72,6 @@ from open_webui.routers.retrieval import (
     get_ef,
     get_rf,
 )
-from open_webui.routers.pipelines import (
-    process_pipeline_inlet_filter,
-)
-
-from open_webui.retrieval.utils import get_sources_from_files
-
 from open_webui.internal.db import Session


@@ -87,8 +79,6 @@ from open_webui.models.functions import Functions
 from open_webui.models.models import Models
 from open_webui.models.users import UserModel, Users

-
-from open_webui.constants import TASKS
 from open_webui.config import (
     # Ollama
     ENABLE_OLLAMA_API,
@@ -274,43 +264,22 @@ from open_webui.env import (
 )

-from open_webui.utils.models import get_all_models, get_all_base_models
+from open_webui.utils.models import (
+    get_all_models,
+    get_all_base_models,
+    check_model_access,
+)
 from open_webui.utils.chat import (
     generate_chat_completion as chat_completion_handler,
     chat_completed as chat_completed_handler,
     chat_action as chat_action_handler,
 )
-
-
-from open_webui.utils.plugin import load_function_module_by_id
-from open_webui.utils.misc import (
-    add_or_update_system_message,
-    get_last_user_message,
-    prepend_to_first_user_message_content,
-    openai_chat_chunk_message_template,
-    openai_chat_completion_message_template,
-)
-
-
-from open_webui.utils.payload import convert_payload_openai_to_ollama
-from open_webui.utils.response import (
-    convert_response_ollama_to_openai,
-    convert_streaming_response_ollama_to_openai,
-)
-
-from open_webui.utils.task import (
-    get_task_model_id,
-    rag_template,
-    tools_function_calling_generation_template,
-)
-from open_webui.utils.tools import get_tools

+from open_webui.utils.middleware import process_chat_payload, process_chat_response
 from open_webui.utils.access_control import has_access

 from open_webui.utils.auth import (
     decode_token,
     get_admin_user,
-    get_current_user,
-    get_http_authorization_cred,
     get_verified_user,
 )
 from open_webui.utils.oauth import oauth_manager
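
Note: the two helpers imported from `open_webui.utils.middleware` are defined in the new `open_webui/utils/middleware.py`, which is outside this diff. A minimal sketch of the contract the route code below assumes, inferred only from the call sites visible in the `+` lines (the bodies are illustrative placeholders, not the shipped implementation):

    from fastapi import Request
    from starlette.responses import Response


    async def process_chat_payload(
        request: Request, form_data: dict, user, model
    ) -> tuple[dict, list]:
        # Consolidates the filter-function, tool and file/RAG preprocessing
        # removed from ChatCompletionMiddleware below; returns the modified
        # payload plus a list of events (e.g. {"sources": [...]}) to merge
        # into the response later.
        events: list[dict] = []
        # ... preprocessing elided in this sketch ...
        return form_data, events


    async def process_chat_response(response: Response, events: list) -> Response:
        # Merges queued events into a streaming response; a non-streaming
        # response would pass through unchanged.
        # ... postprocessing elided in this sketch ...
        return response
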
@@ -665,634 +634,6 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (

 app.state.MODELS = {}

-##################################
-#
-# ChatCompletion Middleware
-#
-##################################
-
-
-async def chat_completion_filter_functions_handler(body, model, extra_params):
-    skip_files = None
-
-    def get_filter_function_ids(model):
-        def get_priority(function_id):
-            function = Functions.get_function_by_id(function_id)
-            if function is not None and hasattr(function, "valves"):
-                # TODO: Fix FunctionModel
-                return (function.valves if function.valves else {}).get("priority", 0)
-            return 0
-
-        filter_ids = [
-            function.id for function in Functions.get_global_filter_functions()
-        ]
-        if "info" in model and "meta" in model["info"]:
-            filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-            filter_ids = list(set(filter_ids))
-
-        enabled_filter_ids = [
-            function.id
-            for function in Functions.get_functions_by_type("filter", active_only=True)
-        ]
-
-        filter_ids = [
-            filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-        ]
-
-        filter_ids.sort(key=get_priority)
-        return filter_ids
-
-    filter_ids = get_filter_function_ids(model)
-    for filter_id in filter_ids:
-        filter = Functions.get_function_by_id(filter_id)
-        if not filter:
-            continue
-
-        if filter_id in app.state.FUNCTIONS:
-            function_module = app.state.FUNCTIONS[filter_id]
-        else:
-            function_module, _, _ = load_function_module_by_id(filter_id)
-            app.state.FUNCTIONS[filter_id] = function_module
-
-        # Check if the function has a file_handler variable
-        if hasattr(function_module, "file_handler"):
-            skip_files = function_module.file_handler
-
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-            valves = Functions.get_function_valves_by_id(filter_id)
-            function_module.valves = function_module.Valves(
-                **(valves if valves else {})
-            )
-
-        if not hasattr(function_module, "inlet"):
-            continue
-
-        try:
-            inlet = function_module.inlet
-
-            # Get the signature of the function
-            sig = inspect.signature(inlet)
-            params = {"body": body} | {
-                k: v
-                for k, v in {
-                    **extra_params,
-                    "__model__": model,
-                    "__id__": filter_id,
-                }.items()
-                if k in sig.parameters
-            }
-
-            if "__user__" in params and hasattr(function_module, "UserValves"):
-                try:
-                    params["__user__"]["valves"] = function_module.UserValves(
-                        **Functions.get_user_valves_by_id_and_user_id(
-                            filter_id, params["__user__"]["id"]
-                        )
-                    )
-                except Exception as e:
-                    print(e)
-
-            if inspect.iscoroutinefunction(inlet):
-                body = await inlet(**params)
-            else:
-                body = inlet(**params)
-
-        except Exception as e:
-            print(f"Error: {e}")
-            raise e
-
-    if skip_files and "files" in body.get("metadata", {}):
-        del body["metadata"]["files"]
-
-    return body, {}
-
-
-async def chat_completion_tools_handler(
-    request: Request, body: dict, user: UserModel, models, extra_params: dict
-) -> tuple[dict, dict]:
-    async def get_content_from_response(response) -> Optional[str]:
-        content = None
-        if hasattr(response, "body_iterator"):
-            async for chunk in response.body_iterator:
-                data = json.loads(chunk.decode("utf-8"))
-                content = data["choices"][0]["message"]["content"]
-
-            # Cleanup any remaining background tasks if necessary
-            if response.background is not None:
-                await response.background()
-        else:
-            content = response["choices"][0]["message"]["content"]
-        return content
-
-    def get_tools_function_calling_payload(messages, task_model_id, content):
-        user_message = get_last_user_message(messages)
-        history = "\n".join(
-            f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
-            for message in messages[::-1][:4]
-        )
-
-        prompt = f"History:\n{history}\nQuery: {user_message}"
-
-        return {
-            "model": task_model_id,
-            "messages": [
-                {"role": "system", "content": content},
-                {"role": "user", "content": f"Query: {prompt}"},
-            ],
-            "stream": False,
-            "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
-        }
-
-    # If tool_ids field is present, call the functions
-    metadata = body.get("metadata", {})
-
-    tool_ids = metadata.get("tool_ids", None)
-    log.debug(f"{tool_ids=}")
-    if not tool_ids:
-        return body, {}
-
-    skip_files = False
-    sources = []
-
-    task_model_id = get_task_model_id(
-        body["model"],
-        request.app.state.config.TASK_MODEL,
-        request.app.state.config.TASK_MODEL_EXTERNAL,
-        models,
-    )
-    tools = get_tools(
-        request,
-        tool_ids,
-        user,
-        {
-            **extra_params,
-            "__model__": models[task_model_id],
-            "__messages__": body["messages"],
-            "__files__": metadata.get("files", []),
-        },
-    )
-    log.info(f"{tools=}")
-
-    specs = [tool["spec"] for tool in tools.values()]
-    tools_specs = json.dumps(specs)
-
-    if app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
-        template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
-    else:
-        template = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""
-
-    tools_function_calling_prompt = tools_function_calling_generation_template(
-        template, tools_specs
-    )
-    log.info(f"{tools_function_calling_prompt=}")
-    payload = get_tools_function_calling_payload(
-        body["messages"], task_model_id, tools_function_calling_prompt
-    )
-
-    try:
-        payload = process_pipeline_inlet_filter(request, payload, user, models)
-    except Exception as e:
-        raise e
-
-    try:
-        response = await generate_chat_completions(form_data=payload, user=user)
-        log.debug(f"{response=}")
-        content = await get_content_from_response(response)
-        log.debug(f"{content=}")
-
-        if not content:
-            return body, {}
-
-        try:
-            content = content[content.find("{") : content.rfind("}") + 1]
-            if not content:
-                raise Exception("No JSON object found in the response")
-
-            result = json.loads(content)
-
-            tool_function_name = result.get("name", None)
-            if tool_function_name not in tools:
-                return body, {}
-
-            tool_function_params = result.get("parameters", {})
-
-            try:
-                required_params = (
-                    tools[tool_function_name]
-                    .get("spec", {})
-                    .get("parameters", {})
-                    .get("required", [])
-                )
-                tool_function = tools[tool_function_name]["callable"]
-                tool_function_params = {
-                    k: v
-                    for k, v in tool_function_params.items()
-                    if k in required_params
-                }
-                tool_output = await tool_function(**tool_function_params)
-
-            except Exception as e:
-                tool_output = str(e)
-
-            if isinstance(tool_output, str):
-                if tools[tool_function_name]["citation"]:
-                    sources.append(
-                        {
-                            "source": {
-                                "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
-                            },
-                            "document": [tool_output],
-                            "metadata": [
-                                {
-                                    "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
-                                }
-                            ],
-                        }
-                    )
-                else:
-                    sources.append(
-                        {
-                            "source": {},
-                            "document": [tool_output],
-                            "metadata": [
-                                {
-                                    "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
-                                }
-                            ],
-                        }
-                    )
-
-                if tools[tool_function_name]["file_handler"]:
-                    skip_files = True
-
-        except Exception as e:
-            log.exception(f"Error: {e}")
-            content = None
-    except Exception as e:
-        log.exception(f"Error: {e}")
-        content = None
-
-    log.debug(f"tool_contexts: {sources}")
-
-    if skip_files and "files" in body.get("metadata", {}):
-        del body["metadata"]["files"]
-
-    return body, {"sources": sources}
-
-
-async def chat_completion_files_handler(
-    request: Request, body: dict, user: UserModel
-) -> tuple[dict, dict[str, list]]:
-    sources = []
-
-    if files := body.get("metadata", {}).get("files", None):
-        try:
-            queries_response = await generate_queries(
-                {
-                    "model": body["model"],
-                    "messages": body["messages"],
-                    "type": "retrieval",
-                },
-                user,
-            )
-            queries_response = queries_response["choices"][0]["message"]["content"]
-
-            try:
-                bracket_start = queries_response.find("{")
-                bracket_end = queries_response.rfind("}") + 1
-
-                if bracket_start == -1 or bracket_end == -1:
-                    raise Exception("No JSON object found in the response")
-
-                queries_response = queries_response[bracket_start:bracket_end]
-                queries_response = json.loads(queries_response)
-            except Exception as e:
-                queries_response = {"queries": [queries_response]}
-
-            queries = queries_response.get("queries", [])
-        except Exception as e:
-            queries = []
-
-        if len(queries) == 0:
-            queries = [get_last_user_message(body["messages"])]
-
-        sources = get_sources_from_files(
-            files=files,
-            queries=queries,
-            embedding_function=request.app.state.EMBEDDING_FUNCTION,
-            k=request.app.state.config.TOP_K,
-            reranking_function=request.app.state.rf,
-            r=request.app.state.config.RELEVANCE_THRESHOLD,
-            hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-        )
-
-        log.debug(f"rag_contexts:sources: {sources}")
-    return body, {"sources": sources}
-
-
-class ChatCompletionMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request: Request, call_next):
-        if not (
-            request.method == "POST"
-            and any(
-                endpoint in request.url.path
-                for endpoint in ["/ollama/api/chat", "/chat/completions"]
-            )
-        ):
-            return await call_next(request)
-        log.debug(f"request.url.path: {request.url.path}")
-
-        await get_all_models(request)
-        models = app.state.MODELS
-
-        async def get_body_and_model_and_user(request, models):
-            # Read the original request body
-            body = await request.body()
-            body_str = body.decode("utf-8")
-            body = json.loads(body_str) if body_str else {}
-
-            model_id = body["model"]
-            if model_id not in models:
-                raise Exception("Model not found")
-            model = models[model_id]
-
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
-
-            return body, model, user
-
-        try:
-            body, model, user = await get_body_and_model_and_user(request, models)
-        except Exception as e:
-            return JSONResponse(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                content={"detail": str(e)},
-            )
-
-        model_info = Models.get_model_by_id(model["id"])
-        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
-            if model.get("arena"):
-                if not has_access(
-                    user.id,
-                    type="read",
-                    access_control=model.get("info", {})
-                    .get("meta", {})
-                    .get("access_control", {}),
-                ):
-                    raise HTTPException(
-                        status_code=403,
-                        detail="Model not found",
-                    )
-            else:
-                if not model_info:
-                    return JSONResponse(
-                        status_code=status.HTTP_404_NOT_FOUND,
-                        content={"detail": "Model not found"},
-                    )
-                elif not (
-                    user.id == model_info.user_id
-                    or has_access(
-                        user.id, type="read", access_control=model_info.access_control
-                    )
-                ):
-                    return JSONResponse(
-                        status_code=status.HTTP_403_FORBIDDEN,
-                        content={"detail": "User does not have access to the model"},
-                    )
-
-        metadata = {
-            "chat_id": body.pop("chat_id", None),
-            "message_id": body.pop("id", None),
-            "session_id": body.pop("session_id", None),
-            "tool_ids": body.get("tool_ids", None),
-            "files": body.get("files", None),
-        }
-        body["metadata"] = metadata
-
-        extra_params = {
-            "__event_emitter__": get_event_emitter(metadata),
-            "__event_call__": get_event_call(metadata),
-            "__user__": {
-                "id": user.id,
-                "email": user.email,
-                "name": user.name,
-                "role": user.role,
-            },
-            "__metadata__": metadata,
-        }
-
-        # Initialize data_items to store additional data to be sent to the client
-        # Initialize contexts and citation
-        data_items = []
-        sources = []
-
-        try:
-            body, flags = await chat_completion_filter_functions_handler(
-                body, model, extra_params
-            )
-        except Exception as e:
-            return JSONResponse(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                content={"detail": str(e)},
-            )
-
-        tool_ids = body.pop("tool_ids", None)
-        files = body.pop("files", None)
-
-        metadata = {
-            **metadata,
-            "tool_ids": tool_ids,
-            "files": files,
-        }
-        body["metadata"] = metadata
-
-        try:
-            body, flags = await chat_completion_tools_handler(
-                request, body, user, models, extra_params
-            )
-            sources.extend(flags.get("sources", []))
-        except Exception as e:
-            log.exception(e)
-
-        try:
-            body, flags = await chat_completion_files_handler(request, body, user)
-            sources.extend(flags.get("sources", []))
-        except Exception as e:
-            log.exception(e)
-
-        # If context is not empty, insert it into the messages
-        if len(sources) > 0:
-            context_string = ""
-            for source_idx, source in enumerate(sources):
-                source_id = source.get("source", {}).get("name", "")
-
-                if "document" in source:
-                    for doc_idx, doc_context in enumerate(source["document"]):
-                        metadata = source.get("metadata")
-                        doc_source_id = None
-
-                        if metadata:
-                            doc_source_id = metadata[doc_idx].get("source", source_id)
-
-                        if source_id:
-                            context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
-                        else:
-                            # If there is no source_id, then do not include the source_id tag
-                            context_string += f"<source><source_context>{doc_context}</source_context></source>\n"
-
-            context_string = context_string.strip()
-            prompt = get_last_user_message(body["messages"])
-
-            if prompt is None:
-                raise Exception("No user message found")
-            if (
-                app.state.config.RELEVANCE_THRESHOLD == 0
-                and context_string.strip() == ""
-            ):
-                log.debug(
-                    f"With a 0 relevancy threshold for RAG, the context cannot be empty"
-                )
-
-            # Workaround for Ollama 2.0+ system prompt issue
-            # TODO: replace with add_or_update_system_message
-            if model["owned_by"] == "ollama":
-                body["messages"] = prepend_to_first_user_message_content(
-                    rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
-                    body["messages"],
-                )
-            else:
-                body["messages"] = add_or_update_system_message(
-                    rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
-                    body["messages"],
-                )
-
-        # If there are citations, add them to the data_items
-        sources = [
-            source for source in sources if source.get("source", {}).get("name", "")
-        ]
-        if len(sources) > 0:
-            data_items.append({"sources": sources})
-
-        modified_body_bytes = json.dumps(body).encode("utf-8")
-        # Replace the request body with the modified one
-        request._body = modified_body_bytes
-        # Set custom header to ensure content-length matches new body length
-        request.headers.__dict__["_list"] = [
-            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
-        ]
-
-        response = await call_next(request)
-        if not isinstance(response, StreamingResponse):
-            return response
-
-        content_type = response.headers["Content-Type"]
-        is_openai = "text/event-stream" in content_type
-        is_ollama = "application/x-ndjson" in content_type
-        if not is_openai and not is_ollama:
-            return response
-
-        def wrap_item(item):
-            return f"data: {item}\n\n" if is_openai else f"{item}\n"
-
-        async def stream_wrapper(original_generator, data_items):
-            for item in data_items:
-                yield wrap_item(json.dumps(item))
-
-            async for data in original_generator:
-                yield data
-
-        return StreamingResponse(
-            stream_wrapper(response.body_iterator, data_items),
-            headers=dict(response.headers),
-        )
-
-    async def _receive(self, body: bytes):
-        return {"type": "http.request", "body": body, "more_body": False}
-
-
-app.add_middleware(ChatCompletionMiddleware)
-
-
-class PipelineMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request: Request, call_next):
-        if not (
-            request.method == "POST"
-            and any(
-                endpoint in request.url.path
-                for endpoint in ["/ollama/api/chat", "/chat/completions"]
-            )
-        ):
-            return await call_next(request)
-
-        log.debug(f"request.url.path: {request.url.path}")
-
-        # Read the original request body
-        body = await request.body()
-        # Decode body to string
-        body_str = body.decode("utf-8")
-        # Parse string to JSON
-        data = json.loads(body_str) if body_str else {}
-
-        try:
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers["Authorization"]),
-            )
-        except KeyError as e:
-            if len(e.args) > 1:
-                return JSONResponse(
-                    status_code=e.args[0],
-                    content={"detail": e.args[1]},
-                )
-            else:
-                return JSONResponse(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    content={"detail": "Not authenticated"},
-                )
-        except HTTPException as e:
-            return JSONResponse(
-                status_code=e.status_code,
-                content={"detail": e.detail},
-            )
-
-        await get_all_models(request)
-        models = app.state.MODELS
-
-        try:
-            data = process_pipeline_inlet_filter(request, data, user, models)
-        except Exception as e:
-            if len(e.args) > 1:
-                return JSONResponse(
-                    status_code=e.args[0],
-                    content={"detail": e.args[1]},
-                )
-            else:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
-
-        modified_body_bytes = json.dumps(data).encode("utf-8")
-        # Replace the request body with the modified one
-        request._body = modified_body_bytes
-        # Set custom header to ensure content-length matches new body length
-        request.headers.__dict__["_list"] = [
-            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
-        ]
-
-        response = await call_next(request)
-        return response
-
-    async def _receive(self, body: bytes):
-        return {"type": "http.request", "body": body, "more_body": False}
-
-
-app.add_middleware(PipelineMiddleware)
-

 class RedirectMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
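
For reviewers: the one externally visible behaviour of the deleted `ChatCompletionMiddleware` is the stream-prepend step, where queued `data_items` (such as `{"sources": [...]}`) are emitted ahead of the upstream model stream, framed as SSE for OpenAI-style responses and as NDJSON for Ollama. Distilled from the removed `wrap_item`/`stream_wrapper` code above (the `is_openai` flag is a closure variable in the original; it is a parameter here so the snippet stands alone):

    import json


    def wrap_item(item: str, is_openai: bool) -> str:
        # SSE frame ("data: ...\n\n") for OpenAI-style streams,
        # newline-delimited JSON for Ollama.
        return f"data: {item}\n\n" if is_openai else f"{item}\n"


    async def stream_wrapper(original_generator, data_items: list, is_openai: bool):
        # Emit queued metadata before forwarding the model's own chunks.
        for item in data_items:
            yield wrap_item(json.dumps(item), is_openai)
        async for data in original_generator:
            yield data
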
@@ -1471,8 +812,32 @@ async def chat_completion(
     user=Depends(get_verified_user),
     bypass_filter: bool = False,
 ):
+
+    try:
+        model_id = form_data.get("model", None)
+        if model_id not in request.app.state.MODELS:
+            raise Exception("Model not found")
+        model = request.app.state.MODELS[model_id]
+
+        # Check if user has access to the model
+        if not bypass_filter and user.role == "user":
+            try:
+                check_model_access(user, model)
+            except Exception as e:
+                raise e
+
+        form_data, events = await process_chat_payload(request, form_data, user, model)
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e),
+        )
+
     try:
-        return await chat_completion_handler(request, form_data, user, bypass_filter)
+        response = await chat_completion_handler(
+            request, form_data, user, bypass_filter
+        )
+        return await process_chat_response(response, events)
     except Exception as e:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -1480,6 +845,7 @@ async def chat_completion(
         )


+# Alias for chat_completion (Legacy)
 generate_chat_completions = chat_completion
 generate_chat_completion = chat_completion

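
`check_model_access` (now imported from `open_webui.utils.models`) stands in for the inline arena/ownership checks deleted from `ChatCompletionMiddleware`. Its implementation is not part of this diff; the sketch below is a hypothetical reconstruction from the removed logic, assuming it raises `HTTPException` on denial and returns `None` otherwise:

    from fastapi import HTTPException

    from open_webui.models.models import Models
    from open_webui.utils.access_control import has_access


    def check_model_access(user, model) -> None:
        # Hypothetical reconstruction of the deleted inline checks.
        if model.get("arena"):
            access_control = (
                model.get("info", {}).get("meta", {}).get("access_control", {})
            )
            if not has_access(user.id, type="read", access_control=access_control):
                raise HTTPException(status_code=403, detail="Model not found")
        else:
            model_info = Models.get_model_by_id(model["id"])
            if not model_info:
                raise HTTPException(status_code=404, detail="Model not found")
            if user.id != model_info.user_id and not has_access(
                user.id, type="read", access_control=model_info.access_control
            ):
                raise HTTPException(
                    status_code=403,
                    detail="User does not have access to the model",
                )
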