浏览代码

feat: pipeline filter wildcard support

Timothy J. Baek 11 月之前
父节点
当前提交
ec36493d61
共有 1 个文件被更改,包括 16 次插入12 次删除
  1. 16 12
      backend/main.py

+ 16 - 12
backend/main.py

@@ -249,21 +249,23 @@ class PipelineMiddleware(BaseHTTPMiddleware):
             data = json.loads(body_str) if body_str else {}
 
             model_id = data["model"]
-            valves = [
+            filters = [
                 model
                 for model in app.state.MODELS.values()
                 if "pipeline" in model
                 and model["pipeline"]["type"] == "filter"
-                and model_id
-                in [
-                    target_model["id"]
-                    for target_model in model["pipeline"]["pipelines"]
-                ]
+                and (
+                    model["pipeline"]["pipelines"] == ["*"]
+                    or any(
+                        model_id == target_model["id"]
+                        for target_model in model["pipeline"]["pipelines"]
+                    )
+                )
             ]
-            sorted_valves = sorted(valves, key=lambda x: x["pipeline"]["priority"])
+            sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
 
             user = None
-            if len(sorted_valves) > 0:
+            if len(sorted_filters) > 0:
                 try:
                     user = get_current_user(
                         get_http_authorization_cred(
@@ -274,10 +276,12 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                 except:
                     pass
 
-            for valve in sorted_valves:
+            print(sorted_filters)
+
+            for filter in sorted_filters:
 
                 try:
-                    urlIdx = valve["urlIdx"]
+                    urlIdx = filter["urlIdx"]
 
                     url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
                     key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
@@ -289,7 +293,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                             headers=headers,
                             json={
                                 "user": user,
-                                "model": valve["id"],
+                                "model": filter["id"],
                                 "body": data,
                             },
                         )
@@ -298,7 +302,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                         data = r.json()
                 except Exception as e:
                     # Handle connection error here
-                    log.error(f"Connection error: {e}")
+                    print(f"Connection error: {e}")
                     pass
 
             modified_body_bytes = json.dumps(data).encode("utf-8")