Timothy J. Baek 9 kuukautta sitten
vanhempi
commit
7d7a29cfb9
1 muutettua tiedostoa jossa 9 lisäystä ja 7 poistoa
  1. 9 7
      backend/main.py

+ 9 - 7
backend/main.py

@@ -618,12 +618,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     content={"detail": str(e)},
                     content={"detail": str(e)},
                 )
                 )
 
 
-            # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
-            task = None
-            if "task" in body:
-                task = body["task"]
-                del body["task"]
-
             # Extract session_id, chat_id and message_id from the request body
             # Extract session_id, chat_id and message_id from the request body
             session_id = None
             session_id = None
             if "session_id" in body:
             if "session_id" in body:
@@ -703,7 +697,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 "session_id": session_id,
                 "session_id": session_id,
                 "chat_id": chat_id,
                 "chat_id": chat_id,
                 "message_id": message_id,
                 "message_id": message_id,
-                "task": task,
             }
             }
 
 
             modified_body_bytes = json.dumps(body).encode("utf-8")
             modified_body_bytes = json.dumps(body).encode("utf-8")
@@ -1038,6 +1031,15 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         )
         )
     model = app.state.MODELS[model_id]
     model = app.state.MODELS[model_id]
 
 
+    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
+    task = None
+    if "task" in form_data:
+        task = form_data["task"]
+        del form_data["task"]
+
+    if "metadata" in form_data:
+        form_data["metadata"]['task'] = task
+
     if model.get("pipe"):
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
     if model["owned_by"] == "ollama":