Ver código fonte

refac: get_sorted_pipelines()

Michael Poluektov 10 meses atrás
pai
commit
144581a7df
1 arquivos alterados com 12 adições e 29 exclusões
  1. 12 29
      backend/main.py

+ 12 - 29
backend/main.py

@@ -764,9 +764,7 @@ app.add_middleware(ChatCompletionMiddleware)
 ##################################
 
 
-def filter_pipeline(payload, user):
-    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
-    model_id = payload["model"]
+def get_sorted_filters(model_id):
     filters = [
         model
         for model in app.state.MODELS.values()
@@ -782,6 +780,13 @@ def filter_pipeline(payload, user):
         )
     ]
     sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+    return sorted_filters
+
+
+def filter_pipeline(payload, user):
+    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
+    model_id = payload["model"]
+    sorted_filters = get_sorted_filters(model_id)
 
     model = app.state.MODELS[model_id]
 
@@ -814,19 +819,12 @@ def filter_pipeline(payload, user):
             print(f"Connection error: {e}")
 
             if r is not None:
-                try:
-                    res = r.json()
-                except:
-                    pass
+                res = r.json()
                 if "detail" in res:
                     raise Exception(r.status_code, res["detail"])
 
-            else:
-                pass
-
-    if "pipeline" not in app.state.MODELS[model_id]:
-        if "task" in payload:
-            del payload["task"]
+    if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
+        del payload["task"]
 
     return payload
 
@@ -1061,22 +1059,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
         )
     model = app.state.MODELS[model_id]
 
-    filters = [
-        model
-        for model in app.state.MODELS.values()
-        if "pipeline" in model
-        and "type" in model["pipeline"]
-        and model["pipeline"]["type"] == "filter"
-        and (
-            model["pipeline"]["pipelines"] == ["*"]
-            or any(
-                model_id == target_model_id
-                for target_model_id in model["pipeline"]["pipelines"]
-            )
-        )
-    ]
-
-    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+    sorted_filters = get_sorted_filters(model_id)
     if "pipeline" in model:
         sorted_filters = [model] + sorted_filters