|
@@ -229,6 +229,83 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
app.add_middleware(RAGMiddleware)
|
|
|
|
|
|
|
|
|
+class PipelineMiddleware(BaseHTTPMiddleware):
|
|
|
+ async def dispatch(self, request: Request, call_next):
|
|
|
+ if request.method == "POST" and (
|
|
|
+ "/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
|
|
+ ):
|
|
|
+ log.debug(f"request.url.path: {request.url.path}")
|
|
|
+
|
|
|
+ # Read the original request body
|
|
|
+ body = await request.body()
|
|
|
+ # Decode body to string
|
|
|
+ body_str = body.decode("utf-8")
|
|
|
+ # Parse string to JSON
|
|
|
+ data = json.loads(body_str) if body_str else {}
|
|
|
+
|
|
|
+ model_id = data["model"]
|
|
|
+
|
|
|
+ valves = [
|
|
|
+ model
|
|
|
+ for model in app.state.MODELS.values()
|
|
|
+ if "pipeline" in model
|
|
|
+ and model["pipeline"]["type"] == "valve"
|
|
|
+ and model_id
|
|
|
+ in [
|
|
|
+ target_model["id"]
|
|
|
+ for target_model in model["pipeline"]["pipelines"]
|
|
|
+ ]
|
|
|
+ ]
|
|
|
+ sorted_valves = sorted(valves, key=lambda x: x["pipeline"]["priority"])
|
|
|
+
|
|
|
+ for valve in sorted_valves:
|
|
|
+ try:
|
|
|
+ urlIdx = valve["urlIdx"]
|
|
|
+
|
|
|
+ url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
|
|
+ key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
|
|
+
|
|
|
+ if key != "":
|
|
|
+ headers = {"Authorization": f"Bearer {key}"}
|
|
|
+ r = requests.post(
|
|
|
+ f"{url}/valve",
|
|
|
+ headers=headers,
|
|
|
+ json={
|
|
|
+ "model": valve["id"],
|
|
|
+ "body": data,
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ r.raise_for_status()
|
|
|
+ data = r.json()
|
|
|
+ except Exception as e:
|
|
|
+ # Handle connection error here
|
|
|
+ log.error(f"Connection error: {e}")
|
|
|
+ pass
|
|
|
+
|
|
|
+ modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
|
+ # Replace the request body with the modified one
|
|
|
+ request._body = modified_body_bytes
|
|
|
+ # Set custom header to ensure content-length matches new body length
|
|
|
+ request.headers.__dict__["_list"] = [
|
|
|
+ (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
|
|
+ *[
|
|
|
+ (k, v)
|
|
|
+ for k, v in request.headers.raw
|
|
|
+ if k.lower() != b"content-length"
|
|
|
+ ],
|
|
|
+ ]
|
|
|
+
|
|
|
+ response = await call_next(request)
|
|
|
+ return response
|
|
|
+
|
|
|
+ async def _receive(self, body: bytes):
|
|
|
+ return {"type": "http.request", "body": body, "more_body": False}
|
|
|
+
|
|
|
+
|
|
|
+app.add_middleware(PipelineMiddleware)
|
|
|
+
|
|
|
+
|
|
|
@app.middleware("http")
|
|
|
async def check_url(request: Request, call_next):
|
|
|
if len(app.state.MODELS) == 0:
|
|
@@ -332,6 +409,14 @@ async def get_all_models():
|
|
|
@app.get("/api/models")
|
|
|
async def get_models(user=Depends(get_verified_user)):
|
|
|
models = await get_all_models()
|
|
|
+
|
|
|
+ # Filter out valve models
|
|
|
+ models = [
|
|
|
+ model
|
|
|
+ for model in models
|
|
|
+ if "pipeline" not in model or model["pipeline"]["type"] != "valve"
|
|
|
+ ]
|
|
|
+
|
|
|
if app.state.config.ENABLE_MODEL_FILTER:
|
|
|
if user.role == "user":
|
|
|
models = list(
|