
feat: pipeline valve support

Timothy J. Baek 11 months ago
commit cc6d9bb8c0
1 changed file with 85 additions and 0 deletions

backend/main.py  +85  -0

@@ -229,6 +229,83 @@ class RAGMiddleware(BaseHTTPMiddleware):
 app.add_middleware(RAGMiddleware)
 
 
+class PipelineMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if request.method == "POST" and (
+            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
+        ):
+            log.debug(f"request.url.path: {request.url.path}")
+
+            # Read the original request body
+            body = await request.body()
+            # Decode body to string
+            body_str = body.decode("utf-8")
+            # Parse string to JSON
+            data = json.loads(body_str) if body_str else {}
+
+            model_id = data["model"]
+
+            valves = [
+                model
+                for model in app.state.MODELS.values()
+                if "pipeline" in model
+                and model["pipeline"]["type"] == "valve"
+                and model_id
+                in [
+                    target_model["id"]
+                    for target_model in model["pipeline"]["pipelines"]
+                ]
+            ]
+            sorted_valves = sorted(valves, key=lambda x: x["pipeline"]["priority"])
+
+            for valve in sorted_valves:
+                try:
+                    urlIdx = valve["urlIdx"]
+
+                    url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+                    key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
+
+                    if key != "":
+                        headers = {"Authorization": f"Bearer {key}"}
+                        r = requests.post(
+                            f"{url}/valve",
+                            headers=headers,
+                            json={
+                                "model": valve["id"],
+                                "body": data,
+                            },
+                        )
+
+                        r.raise_for_status()
+                        data = r.json()
+                except Exception as e:
+                    # Handle connection error here
+                    log.error(f"Connection error: {e}")
+                    pass
+
+            modified_body_bytes = json.dumps(data).encode("utf-8")
+            # Replace the request body with the modified one
+            request._body = modified_body_bytes
+            # Set custom header to ensure content-length matches new body length
+            request.headers.__dict__["_list"] = [
+                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+                *[
+                    (k, v)
+                    for k, v in request.headers.raw
+                    if k.lower() != b"content-length"
+                ],
+            ]
+
+        response = await call_next(request)
+        return response
+
+    async def _receive(self, body: bytes):
+        return {"type": "http.request", "body": body, "more_body": False}
+
+
+app.add_middleware(PipelineMiddleware)
+
+
 @app.middleware("http")
 async def check_url(request: Request, call_next):
     if len(app.state.MODELS) == 0:
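For reference, the middleware added above POSTs {"model": <valve id>, "body": <chat payload>} to {url}/valve with an optional Bearer key and expects the (possibly modified) chat payload back as JSON. Below is a minimal sketch of what a compatible /valve handler on the pipelines side could look like under those assumptions; the endpoint path and payload shape come from the diff, while the FastAPI app, the ValveRequest model, the API_KEY placeholder, and the example transformation are purely illustrative and not part of this commit.

# Hypothetical valve server sketch (not part of this commit).
# Assumes the contract used by PipelineMiddleware above: POST /valve with
# {"model": "<valve id>", "body": {...chat payload...}} and an optional
# "Authorization: Bearer <key>" header; the JSON response replaces the body.
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel

valve_app = FastAPI()

API_KEY = "my-secret-key"  # placeholder; would match OPENAI_API_KEYS[urlIdx] on the caller


class ValveRequest(BaseModel):
    model: str
    body: dict


@valve_app.post("/valve")
async def valve(form: ValveRequest, authorization: str | None = Header(None)):
    # Reject calls that do not carry the expected Bearer key.
    if API_KEY and authorization != f"Bearer {API_KEY}":
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Example transformation: prepend a system message before the request
    # reaches the actual model. PipelineMiddleware replaces the original
    # request body with whatever JSON is returned here.
    body = form.body
    body["messages"] = [
        {"role": "system", "content": "Filtered by valve."},
        *body.get("messages", []),
    ]
    return body

Whatever JSON the handler returns is what downstream handlers in main.py will see as the request body, since the middleware rewrites request._body and the content-length header with the returned payload.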
@@ -332,6 +409,14 @@ async def get_all_models():
 @app.get("/api/models")
 async def get_models(user=Depends(get_verified_user)):
     models = await get_all_models()
+
+    # Filter out valve models
+    models = [
+        model
+        for model in models
+        if "pipeline" not in model or model["pipeline"]["type"] != "valve"
+    ]
+
     if app.state.config.ENABLE_MODEL_FILTER:
         if user.role == "user":
             models = list(