Michael Poluektov преди 8 месеца
родител
ревизия
0c9119d619
променени са 1 файла, в които са добавени 18 реда и са изтрити 23 реда
  1. 18 23
      backend/main.py

+ 18 - 23
backend/main.py

@@ -317,7 +317,7 @@ async def get_function_call_response(
             {"role": "user", "content": f"Query: {prompt}"},
         ],
         "stream": False,
-        "task": str(TASKS.FUNCTION_CALLING),
+        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
     }
 
     try:
@@ -788,19 +788,21 @@ def filter_pipeline(payload, user):
             url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
             key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
 
-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/inlet",
-                    headers=headers,
-                    json={
-                        "user": user,
-                        "body": payload,
-                    },
-                )
+            if key == "":
+                continue
+
+            headers = {"Authorization": f"Bearer {key}"}
+            r = requests.post(
+                f"{url}/{filter['id']}/filter/inlet",
+                headers=headers,
+                json={
+                    "user": user,
+                    "body": payload,
+                },
+            )
 
-                r.raise_for_status()
-                payload = r.json()
+            r.raise_for_status()
+            payload = r.json()
         except Exception as e:
             # Handle connection error here
             print(f"Connection error: {e}")
@@ -1086,13 +1088,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         )
     model = app.state.MODELS[model_id]
 
-    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
-    if task := form_data.pop("task", None):
-        if "metadata" in form_data:
-            form_data["metadata"]["task"] = task
-        else:
-            form_data["metadata"] = {"task": task}
-
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
@@ -1469,7 +1464,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.TITLE_GENERATION),
+        "metadata": {"task": str(TASKS.TITLE_GENERATION)},
     }
 
     log.debug(payload)
@@ -1522,7 +1517,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "max_tokens": 30,
-        "task": str(TASKS.QUERY_GENERATION),
+        "metadata": {"task": str(TASKS.QUERY_GENERATION)},
     }
 
     print(payload)
@@ -1579,7 +1574,7 @@ Message: """{{prompt}}"""
         "stream": False,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.EMOJI_GENERATION),
+        "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
     }
 
     log.debug(payload)