瀏覽代碼

factor out get_function_calling_payload

Michael Poluektov 8 月之前
父節點
當前提交
e86688284a
共有 1 個文件被更改,包括 18 次插入21 次删除
  1. 18 21
      backend/main.py

+ 18 - 21
backend/main.py

@@ -322,31 +322,16 @@ async def call_tool_from_completion(
         return None
 
 
-async def get_function_call_response(
-    messages, files, tool_id, template, task_model_id, user, extra_params
-) -> tuple[Optional[str], Optional[dict], bool]:
-    tool = Tools.get_tool_by_id(tool_id)
-    if tool is None:
-        return None, None, False
-
-    tools_specs = json.dumps(tool.specs, indent=2)
-    content = tools_function_calling_generation_template(template, tools_specs)
-
+def get_function_calling_payload(messages, task_model_id, content):
     user_message = get_last_user_message(messages)
-    prompt = (
-        "History:\n"
-        + "\n".join(
-            [
-                f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
-                for message in messages[::-1][:4]
-            ]
-        )
-        + f"\nQuery: {user_message}"
+    history = "\n".join(
+        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
+        for message in messages[::-1][:4]
     )
 
-    print(prompt)
+    prompt = f"History:\n{history}\nQuery: {user_message}"
 
-    payload = {
+    return {
         "model": task_model_id,
         "messages": [
             {"role": "system", "content": content},
@@ -356,6 +341,18 @@ async def get_function_call_response(
         "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
     }
 
+
+async def get_function_call_response(
+    messages, files, tool_id, template, task_model_id, user, extra_params
+) -> tuple[Optional[str], Optional[dict], bool]:
+    tool = Tools.get_tool_by_id(tool_id)
+    if tool is None:
+        return None, None, False
+
+    tools_specs = json.dumps(tool.specs, indent=2)
+    content = tools_function_calling_generation_template(template, tools_specs)
+    payload = get_function_calling_payload(messages, task_model_id, content)
+
     try:
         payload = filter_pipeline(payload, user)
     except Exception as e: