Browse Source

feat: filter inlet support

Timothy J. Baek 10 months ago
parent
commit
c4bd60114e

+ 54 - 13
backend/main.py

@@ -50,7 +50,9 @@ from typing import List, Optional
 
 from apps.webui.models.models import Models, ModelModel
 from apps.webui.models.tools import Tools
-from apps.webui.utils import load_toolkit_module_by_id
+from apps.webui.models.functions import Functions
+
+from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
 
 
 from utils.utils import (
@@ -318,9 +320,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         data_items = []
 
-        if request.method == "POST" and (
-            "/ollama/api/chat" in request.url.path
-            or "/chat/completions" in request.url.path
+        if request.method == "POST" and any(
+            endpoint in request.url.path
+            for endpoint in ["/ollama/api/chat", "/chat/completions"]
         ):
             log.debug(f"request.url.path: {request.url.path}")
 
@@ -328,23 +330,62 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             body = await request.body()
             body_str = body.decode("utf-8")
             data = json.loads(body_str) if body_str else {}
-
-            model_id = data["model"]
             user = get_current_user(
                 request,
                 get_http_authorization_cred(request.headers.get("Authorization")),
             )
 
-            # Set the task model
-            task_model_id = model_id
-            if task_model_id not in app.state.MODELS:
+            # Flag to skip RAG completions if file_handler is present in tools/functions
+            skip_files = False
+
+            model_id = data["model"]
+            if model_id not in app.state.MODELS:
                 raise HTTPException(
                     status_code=status.HTTP_404_NOT_FOUND,
                     detail="Model not found",
                 )
+            model = app.state.MODELS[model_id]
+
+            print(":", data)
+
+            # Check if the model has any filters
+            for filter_id in model["info"]["meta"].get("filterIds", []):
+                filter = Functions.get_function_by_id(filter_id)
+                if filter:
+                    if filter_id in webui_app.state.FUNCTIONS:
+                        function_module = webui_app.state.FUNCTIONS[filter_id]
+                    else:
+                        function_module, function_type = load_function_module_by_id(
+                            filter_id
+                        )
+                        webui_app.state.FUNCTIONS[filter_id] = function_module
+
+                    # Check if the function has a file_handler variable
+                    if getattr(function_module, "file_handler"):
+                        skip_files = True
 
-            # Check if the user has a custom task model
-            # If the user has a custom task model, use that model
+                    try:
+                        if hasattr(function_module, "inlet"):
+                            data = function_module.inlet(
+                                data,
+                                {
+                                    "id": user.id,
+                                    "email": user.email,
+                                    "name": user.name,
+                                    "role": user.role,
+                                },
+                            )
+                    except Exception as e:
+                        print(f"Error: {e}")
+                        raise HTTPException(
+                            status_code=status.HTTP_400_BAD_REQUEST,
+                            detail=e,
+                        )
+
+            print("Filtered:", data)
+            # Set the task model
+            task_model_id = data["model"]
+            # Check if the user has a custom task model and use that model
             if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
                 if (
                     app.state.config.TASK_MODEL
@@ -358,7 +399,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 ):
                     task_model_id = app.state.config.TASK_MODEL_EXTERNAL
 
-            skip_files = False
             prompt = get_last_user_message(data["messages"])
             context = ""
 
@@ -409,8 +449,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
                     log.debug(f"rag_context: {rag_context}, citations: {citations}")
 
-                    if citations:
+                    if citations and data.get("citations"):
                         data_items.append({"citations": citations})
+                        del data["citations"]
 
                 del data["files"]
 

+ 3 - 2
src/lib/components/chat/Chat.svelte

@@ -630,7 +630,7 @@
 			keep_alive: $settings.keepAlive ?? undefined,
 			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 			files: files.length > 0 ? files : undefined,
-			citations: files.length > 0,
+			citations: files.length > 0 ? true : undefined,
 			chat_id: $chatId
 		});
 
@@ -928,7 +928,8 @@
 					max_tokens: $settings?.params?.max_tokens ?? undefined,
 					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 					files: files.length > 0 ? files : undefined,
-					citations: files.length > 0,
+					citations: files.length > 0 ? true : undefined,
+
 					chat_id: $chatId
 				},
 				`${OPENAI_API_BASE_URL}`

+ 1 - 0
src/lib/components/workspace/Models/FiltersSelector.svelte

@@ -31,6 +31,7 @@
 		{$i18n.t('To select filters here, add them to the "Functions" workspace first.')}
 	</div>
 
+	<!-- TODO: Filer order matters -->
 	<div class="flex flex-col">
 		{#if filters.length > 0}
 			<div class=" flex items-center mt-2 flex-wrap">