|
@@ -141,7 +141,8 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
return_citations = False
|
|
|
|
|
|
if request.method == "POST" and (
|
|
|
- "/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
|
|
+ "/ollama/api/chat" in request.url.path
|
|
|
+ or "/chat/completions" in request.url.path
|
|
|
):
|
|
|
log.debug(f"request.url.path: {request.url.path}")
|
|
|
|
|
@@ -229,7 +230,8 @@ app.add_middleware(RAGMiddleware)
|
|
|
class PipelineMiddleware(BaseHTTPMiddleware):
|
|
|
async def dispatch(self, request: Request, call_next):
|
|
|
if request.method == "POST" and (
|
|
|
- "/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
|
|
+ "/ollama/api/chat" in request.url.path
|
|
|
+ or "/chat/completions" in request.url.path
|
|
|
):
|
|
|
log.debug(f"request.url.path: {request.url.path}")
|
|
|
|
|
@@ -308,6 +310,9 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|
|
else:
|
|
|
pass
|
|
|
|
|
|
+ if "chat_id" in data:
|
|
|
+ del data["chat_id"]
|
|
|
+
|
|
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
|
# Replace the request body with the modified one
|
|
|
request._body = modified_body_bytes
|
|
@@ -464,6 +469,69 @@ async def get_models(user=Depends(get_verified_user)):
|
|
|
return {"data": models}
|
|
|
|
|
|
|
|
|
+@app.post("/api/chat/completed")
|
|
|
+async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|
|
+ data = form_data
|
|
|
+ model_id = data["model"]
|
|
|
+
|
|
|
+ filters = [
|
|
|
+ model
|
|
|
+ for model in app.state.MODELS.values()
|
|
|
+ if "pipeline" in model
|
|
|
+ and "type" in model["pipeline"]
|
|
|
+ and model["pipeline"]["type"] == "filter"
|
|
|
+ and (
|
|
|
+ model["pipeline"]["pipelines"] == ["*"]
|
|
|
+ or any(
|
|
|
+ model_id == target_model_id
|
|
|
+ for target_model_id in model["pipeline"]["pipelines"]
|
|
|
+ )
|
|
|
+ )
|
|
|
+ ]
|
|
|
+ sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
|
+
|
|
|
+ for filter in sorted_filters:
|
|
|
+ r = None
|
|
|
+ try:
|
|
|
+ urlIdx = filter["urlIdx"]
|
|
|
+
|
|
|
+ url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
|
|
+ key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
|
|
+
|
|
|
+ if key != "":
|
|
|
+ headers = {"Authorization": f"Bearer {key}"}
|
|
|
+ r = requests.post(
|
|
|
+ f"{url}/{filter['id']}/filter/outlet",
|
|
|
+ headers=headers,
|
|
|
+ json={
|
|
|
+ "user": {"id": user.id, "name": user.name, "role": user.role},
|
|
|
+ "body": data,
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ r.raise_for_status()
|
|
|
+ data = r.json()
|
|
|
+ except Exception as e:
|
|
|
+ # Handle connection error here
|
|
|
+ print(f"Connection error: {e}")
|
|
|
+
|
|
|
+ if r is not None:
|
|
|
+ try:
|
|
|
+ res = r.json()
|
|
|
+ if "detail" in res:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=r.status_code,
|
|
|
+ content=res,
|
|
|
+ )
|
|
|
+ except:
|
|
|
+ pass
|
|
|
+
|
|
|
+ else:
|
|
|
+ pass
|
|
|
+
|
|
|
+ return data
|
|
|
+
|
|
|
+
|
|
|
@app.get("/api/pipelines/list")
|
|
|
async def get_pipelines_list(user=Depends(get_admin_user)):
|
|
|
responses = await get_openai_models(raw=True)
|