|
@@ -50,7 +50,9 @@ from typing import List, Optional
|
|
|
|
|
|
from apps.webui.models.models import Models, ModelModel
|
|
from apps.webui.models.models import Models, ModelModel
|
|
from apps.webui.models.tools import Tools
|
|
from apps.webui.models.tools import Tools
|
|
-from apps.webui.utils import load_toolkit_module_by_id
|
|
|
|
|
|
+from apps.webui.models.functions import Functions
|
|
|
|
+
|
|
|
|
+from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
|
|
|
|
|
|
|
|
|
from utils.utils import (
|
|
from utils.utils import (
|
|
@@ -318,9 +320,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
async def dispatch(self, request: Request, call_next):
|
|
async def dispatch(self, request: Request, call_next):
|
|
data_items = []
|
|
data_items = []
|
|
|
|
|
|
- if request.method == "POST" and (
|
|
|
|
- "/ollama/api/chat" in request.url.path
|
|
|
|
- or "/chat/completions" in request.url.path
|
|
|
|
|
|
+ if request.method == "POST" and any(
|
|
|
|
+ endpoint in request.url.path
|
|
|
|
+ for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
|
):
|
|
):
|
|
log.debug(f"request.url.path: {request.url.path}")
|
|
log.debug(f"request.url.path: {request.url.path}")
|
|
|
|
|
|
@@ -328,23 +330,62 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
body = await request.body()
|
|
body = await request.body()
|
|
body_str = body.decode("utf-8")
|
|
body_str = body.decode("utf-8")
|
|
data = json.loads(body_str) if body_str else {}
|
|
data = json.loads(body_str) if body_str else {}
|
|
-
|
|
|
|
- model_id = data["model"]
|
|
|
|
user = get_current_user(
|
|
user = get_current_user(
|
|
request,
|
|
request,
|
|
get_http_authorization_cred(request.headers.get("Authorization")),
|
|
get_http_authorization_cred(request.headers.get("Authorization")),
|
|
)
|
|
)
|
|
|
|
|
|
- # Set the task model
|
|
|
|
- task_model_id = model_id
|
|
|
|
- if task_model_id not in app.state.MODELS:
|
|
|
|
|
|
+ # Flag to skip RAG completions if file_handler is present in tools/functions
|
|
|
|
+ skip_files = False
|
|
|
|
+
|
|
|
|
+ model_id = data["model"]
|
|
|
|
+ if model_id not in app.state.MODELS:
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="Model not found",
|
|
detail="Model not found",
|
|
)
|
|
)
|
|
|
|
+ model = app.state.MODELS[model_id]
|
|
|
|
+
|
|
|
|
+ print(":", data)
|
|
|
|
+
|
|
|
|
+ # Check if the model has any filters
|
|
|
|
+ for filter_id in model["info"]["meta"].get("filterIds", []):
|
|
|
|
+ filter = Functions.get_function_by_id(filter_id)
|
|
|
|
+ if filter:
|
|
|
|
+ if filter_id in webui_app.state.FUNCTIONS:
|
|
|
|
+ function_module = webui_app.state.FUNCTIONS[filter_id]
|
|
|
|
+ else:
|
|
|
|
+ function_module, function_type = load_function_module_by_id(
|
|
|
|
+ filter_id
|
|
|
|
+ )
|
|
|
|
+ webui_app.state.FUNCTIONS[filter_id] = function_module
|
|
|
|
+
|
|
|
|
+ # Check if the function has a file_handler variable
|
|
|
|
+ if getattr(function_module, "file_handler"):
|
|
|
|
+ skip_files = True
|
|
|
|
|
|
- # Check if the user has a custom task model
|
|
|
|
- # If the user has a custom task model, use that model
|
|
|
|
|
|
+ try:
|
|
|
|
+ if hasattr(function_module, "inlet"):
|
|
|
|
+ data = function_module.inlet(
|
|
|
|
+ data,
|
|
|
|
+ {
|
|
|
|
+ "id": user.id,
|
|
|
|
+ "email": user.email,
|
|
|
|
+ "name": user.name,
|
|
|
|
+ "role": user.role,
|
|
|
|
+ },
|
|
|
|
+ )
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(f"Error: {e}")
|
|
|
|
+ raise HTTPException(
|
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
|
+ detail=e,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ print("Filtered:", data)
|
|
|
|
+ # Set the task model
|
|
|
|
+ task_model_id = data["model"]
|
|
|
|
+ # Check if the user has a custom task model and use that model
|
|
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
if (
|
|
if (
|
|
app.state.config.TASK_MODEL
|
|
app.state.config.TASK_MODEL
|
|
@@ -358,7 +399,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
):
|
|
):
|
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
|
|
- skip_files = False
|
|
|
|
prompt = get_last_user_message(data["messages"])
|
|
prompt = get_last_user_message(data["messages"])
|
|
context = ""
|
|
context = ""
|
|
|
|
|
|
@@ -409,8 +449,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
|
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
|
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
|
|
|
|
|
- if citations:
|
|
|
|
|
|
+ if citations and data.get("citations"):
|
|
data_items.append({"citations": citations})
|
|
data_items.append({"citations": citations})
|
|
|
|
+ del data["citations"]
|
|
|
|
|
|
del data["files"]
|
|
del data["files"]
|
|
|
|
|