|
@@ -249,21 +249,23 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|
data = json.loads(body_str) if body_str else {}
|
|
data = json.loads(body_str) if body_str else {}
|
|
|
|
|
|
model_id = data["model"]
|
|
model_id = data["model"]
|
|
- valves = [
|
|
|
|
|
|
+ filters = [
|
|
model
|
|
model
|
|
for model in app.state.MODELS.values()
|
|
for model in app.state.MODELS.values()
|
|
if "pipeline" in model
|
|
if "pipeline" in model
|
|
and model["pipeline"]["type"] == "filter"
|
|
and model["pipeline"]["type"] == "filter"
|
|
- and model_id
|
|
|
|
- in [
|
|
|
|
- target_model["id"]
|
|
|
|
- for target_model in model["pipeline"]["pipelines"]
|
|
|
|
- ]
|
|
|
|
|
|
+ and (
|
|
|
|
+ model["pipeline"]["pipelines"] == ["*"]
|
|
|
|
+ or any(
|
|
|
|
+ model_id == target_model["id"]
|
|
|
|
+ for target_model in model["pipeline"]["pipelines"]
|
|
|
|
+ )
|
|
|
|
+ )
|
|
]
|
|
]
|
|
- sorted_valves = sorted(valves, key=lambda x: x["pipeline"]["priority"])
|
|
|
|
|
|
+ sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
|
|
|
|
user = None
|
|
user = None
|
|
- if len(sorted_valves) > 0:
|
|
|
|
|
|
+ if len(sorted_filters) > 0:
|
|
try:
|
|
try:
|
|
user = get_current_user(
|
|
user = get_current_user(
|
|
get_http_authorization_cred(
|
|
get_http_authorization_cred(
|
|
@@ -274,10 +276,12 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|
except:
|
|
except:
|
|
pass
|
|
pass
|
|
|
|
|
|
- for valve in sorted_valves:
|
|
|
|
|
|
+ print(sorted_filters)
|
|
|
|
+
|
|
|
|
+ for filter in sorted_filters:
|
|
|
|
|
|
try:
|
|
try:
|
|
- urlIdx = valve["urlIdx"]
|
|
|
|
|
|
+ urlIdx = filter["urlIdx"]
|
|
|
|
|
|
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
|
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
|
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
|
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
|
@@ -289,7 +293,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|
headers=headers,
|
|
headers=headers,
|
|
json={
|
|
json={
|
|
"user": user,
|
|
"user": user,
|
|
- "model": valve["id"],
|
|
|
|
|
|
+ "model": filter["id"],
|
|
"body": data,
|
|
"body": data,
|
|
},
|
|
},
|
|
)
|
|
)
|
|
@@ -298,7 +302,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|
data = r.json()
|
|
data = r.json()
|
|
except Exception as e:
|
|
except Exception as e:
|
|
# Handle connection error here
|
|
# Handle connection error here
|
|
- log.error(f"Connection error: {e}")
|
|
|
|
|
|
+ print(f"Connection error: {e}")
|
|
pass
|
|
pass
|
|
|
|
|
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|