浏览代码

refac: pipelines

Timothy Jaeryang Baek 2 月之前
父节点
当前提交
19c340d3fb
共有 3 个文件被更改,包括 83 次插入70 次删除
  1. 59 51
      backend/open_webui/routers/pipelines.py
  2. 1 7
      backend/open_webui/utils/chat.py
  3. 23 12
      backend/open_webui/utils/middleware.py

+ 59 - 51
backend/open_webui/routers/pipelines.py

@@ -9,6 +9,7 @@ from fastapi import (
     status,
     APIRouter,
 )
+import aiohttp
 import os
 import logging
 import shutil
@@ -56,96 +57,103 @@ def get_sorted_filters(model_id, models):
     return sorted_filters
 
 
-def process_pipeline_inlet_filter(request, payload, user, models):
+async def process_pipeline_inlet_filter(request, payload, user, models):
     user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
     model_id = payload["model"]
-
     sorted_filters = get_sorted_filters(model_id, models)
     model = models[model_id]
 
     if "pipeline" in model:
         sorted_filters.append(model)
 
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
+    async with aiohttp.ClientSession() as session:
+        for filter in sorted_filters:
+            urlIdx = filter.get("urlIdx")
+            if urlIdx is None:
+                continue
 
             url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
             key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
 
-            if key == "":
+            if not key:
                 continue
 
             headers = {"Authorization": f"Bearer {key}"}
-            r = requests.post(
-                f"{url}/{filter['id']}/filter/inlet",
-                headers=headers,
-                json={
-                    "user": user,
-                    "body": payload,
-                },
-            )
-
-            r.raise_for_status()
-            payload = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
+            request_data = {
+                "user": user,
+                "body": payload,
+            }
 
-            if r is not None:
-                res = r.json()
+            try:
+                async with session.post(
+                    f"{url}/{filter['id']}/filter/inlet",
+                    headers=headers,
+                    json=request_data,
+                ) as response:
+                    response.raise_for_status()
+                    payload = await response.json()
+            except aiohttp.ClientResponseError as e:
+                res = (
+                    await response.json()
+                    if response.content_type == "application/json"
+                    else {}
+                )
                 if "detail" in res:
-                    raise Exception(r.status_code, res["detail"])
+                    raise Exception(response.status, res["detail"])
+            except Exception as e:
+                print(f"Connection error: {e}")
 
     return payload
 
 
-def process_pipeline_outlet_filter(request, payload, user, models):
+async def process_pipeline_outlet_filter(request, payload, user, models):
     user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
     model_id = payload["model"]
-
     sorted_filters = get_sorted_filters(model_id, models)
     model = models[model_id]
 
     if "pipeline" in model:
         sorted_filters = [model] + sorted_filters
 
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
+    async with aiohttp.ClientSession() as session:
+        for filter in sorted_filters:
+            urlIdx = filter.get("urlIdx")
+            if urlIdx is None:
+                continue
 
             url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
             key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
 
-            if key != "":
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/outlet",
-                    headers={"Authorization": f"Bearer {key}"},
-                    json={
-                        "user": user,
-                        "body": payload,
-                    },
-                )
+            if not key:
+                continue
 
-                r.raise_for_status()
-                data = r.json()
-                payload = data
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
+            headers = {"Authorization": f"Bearer {key}"}
+            request_data = {
+                "user": user,
+                "body": payload,
+            }
 
-            if r is not None:
+            try:
+                async with session.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers=headers,
+                    json=request_data,
+                ) as response:
+                    response.raise_for_status()
+                    payload = await response.json()
+            except aiohttp.ClientResponseError as e:
                 try:
-                    res = r.json()
+                    res = (
+                        await response.json()
+                        if "application/json" in response.content_type
+                        else {}
+                    )
                     if "detail" in res:
-                        return Exception(r.status_code, res)
+                        raise Exception(response.status, res)
                 except Exception:
                     pass
-
-            else:
-                pass
+            except Exception as e:
+                print(f"Connection error: {e}")
 
     return payload
 

+ 1 - 7
backend/open_webui/utils/chat.py

@@ -186,12 +186,6 @@ async def generate_chat_completion(
     if model_id not in models:
         raise Exception("Model not found")
 
-    # Process the form_data through the pipeline
-    try:
-        form_data = process_pipeline_inlet_filter(request, form_data, user, models)
-    except Exception as e:
-        raise e
-
     model = models[model_id]
 
     if getattr(request.state, "direct", False):
@@ -308,7 +302,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
     model = models[model_id]
 
     try:
-        data = process_pipeline_outlet_filter(request, data, user, models)
+        data = await process_pipeline_outlet_filter(request, data, user, models)
     except Exception as e:
         return Exception(f"Error: {e}")
 

+ 23 - 12
backend/open_webui/utils/middleware.py

@@ -39,7 +39,10 @@ from open_webui.routers.tasks import (
 )
 from open_webui.routers.retrieval import process_web_search, SearchForm
 from open_webui.routers.images import image_generations, GenerateImageForm
-
+from open_webui.routers.pipelines import (
+    process_pipeline_inlet_filter,
+    process_pipeline_outlet_filter,
+)
 
 from open_webui.utils.webhook import post_webhook
 
@@ -676,6 +679,25 @@ async def process_chat_payload(request, form_data, metadata, user, model):
 
     variables = form_data.pop("variables", None)
 
+    # Process the form_data through the pipeline
+    try:
+        form_data = await process_pipeline_inlet_filter(
+            request, form_data, user, models
+        )
+    except Exception as e:
+        raise e
+
+    try:
+        form_data, flags = await process_filter_functions(
+            request=request,
+            filter_ids=get_sorted_filter_ids(model),
+            filter_type="inlet",
+            form_data=form_data,
+            extra_params=extra_params,
+        )
+    except Exception as e:
+        raise Exception(f"Error: {e}")
+
     features = form_data.pop("features", None)
     if features:
         if "web_search" in features and features["web_search"]:
@@ -698,17 +720,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
                 form_data["messages"],
             )
 
-    try:
-        form_data, flags = await process_filter_functions(
-            request=request,
-            filter_ids=get_sorted_filter_ids(model),
-            filter_type="inlet",
-            form_data=form_data,
-            extra_params=extra_params,
-        )
-    except Exception as e:
-        raise Exception(f"Error: {e}")
-
     tool_ids = form_data.pop("tool_ids", None)
     files = form_data.pop("files", None)
     # Remove files duplicates