Quellcode durchsuchen

feat: filter func outlet

Timothy J. Baek vor 10 Monaten
Ursprung
Commit
afd270523c
2 geänderte Dateien mit 45 neuen und 12 gelöschten Zeilen
  1. 42 11
      backend/main.py
  2. 3 1
      src/lib/components/chat/Chat.svelte

+ 42 - 11
backend/main.py

@@ -474,10 +474,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 ],
             ]
 
-        response = await call_next(request)
-
-        # If there are data_items to inject into the response
-        if len(data_items) > 0:
+            response = await call_next(request)
             if isinstance(response, StreamingResponse):
                 # If it's a streaming response, inject it as SSE event or NDJSON line
                 content_type = response.headers.get("Content-Type")
@@ -489,7 +486,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     return StreamingResponse(
                         self.ollama_stream_wrapper(response.body_iterator, data_items),
                     )
+            else:
+                return response
 
+        # If it's not a chat completion request, just pass it through
+        response = await call_next(request)
         return response
 
     async def _receive(self, body: bytes):
@@ -800,6 +801,12 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
 async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     data = form_data
     model_id = data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+    model = app.state.MODELS[model_id]
 
     filters = [
         model
@@ -815,14 +822,10 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
             )
         )
     ]
-    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
 
-    print(model_id)
-
-    if model_id in app.state.MODELS:
-        model = app.state.MODELS[model_id]
-        if "pipeline" in model:
-            sorted_filters = [model] + sorted_filters
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+    if "pipeline" in model:
+        sorted_filters = [model] + sorted_filters
 
     for filter in sorted_filters:
         r = None
@@ -863,6 +866,34 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
             else:
                 pass
 
+    # Check if the model has any filters
+    for filter_id in model["info"]["meta"].get("filterIds", []):
+        filter = Functions.get_function_by_id(filter_id)
+        if filter:
+            if filter_id in webui_app.state.FUNCTIONS:
+                function_module = webui_app.state.FUNCTIONS[filter_id]
+            else:
+                function_module, function_type = load_function_module_by_id(filter_id)
+                webui_app.state.FUNCTIONS[filter_id] = function_module
+
+            try:
+                if hasattr(function_module, "outlet"):
+                    data = function_module.outlet(
+                        data,
+                        {
+                            "id": user.id,
+                            "email": user.email,
+                            "name": user.name,
+                            "role": user.role,
+                        },
+                    )
+            except Exception as e:
+                print(f"Error: {e}")
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content={"detail": str(e)},
+                )
+
     return data
 
 

+ 3 - 1
src/lib/components/chat/Chat.svelte

@@ -278,7 +278,9 @@
 			})),
 			chat_id: $chatId
 		}).catch((error) => {
-			console.error(error);
+			toast.error(error);
+			messages.at(-1).error = { content: error };
+
 			return null;
 		});