Timothy J. Baek 10 months ago
parent
commit
9d16dd997a
1 changed files with 18 additions and 4 deletions
  1. 18 4
      backend/main.py

+ 18 - 4
backend/main.py

@@ -168,11 +168,25 @@ app.state.MODELS = {}
 origins = ["*"]
 origins = ["*"]
 
 
 
 
-async def get_function_call_response(prompt, tool_id, template, task_model_id, user):
+async def get_function_call_response(messages, tool_id, template, task_model_id, user):
     tool = Tools.get_tool_by_id(tool_id)
     tool = Tools.get_tool_by_id(tool_id)
     tools_specs = json.dumps(tool.specs, indent=2)
     tools_specs = json.dumps(tool.specs, indent=2)
     content = tools_function_calling_generation_template(template, tools_specs)
     content = tools_function_calling_generation_template(template, tools_specs)
 
 
+    user_message = get_last_user_message(messages)
+    prompt = (
+        "History:\n"
+        + "\n".join(
+            [
+                f"{message['role']}: {message['content']}"
+                for message in messages[::-1][:4]
+            ]
+        )
+        + f"\nQuery: {user_message}"
+    )
+
+    print(prompt)
+
     payload = {
     payload = {
         "model": task_model_id,
         "model": task_model_id,
         "messages": [
         "messages": [
@@ -300,16 +314,16 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 ):
                 ):
                     task_model_id = app.state.config.TASK_MODEL_EXTERNAL
                     task_model_id = app.state.config.TASK_MODEL_EXTERNAL
 
 
+            prompt = get_last_user_message(data["messages"])
             context = ""
             context = ""
 
 
             # If tool_ids field is present, call the functions
             # If tool_ids field is present, call the functions
             if "tool_ids" in data:
             if "tool_ids" in data:
                 print(data["tool_ids"])
                 print(data["tool_ids"])
-                prompt = get_last_user_message(data["messages"])
                 for tool_id in data["tool_ids"]:
                 for tool_id in data["tool_ids"]:
                     print(tool_id)
                     print(tool_id)
                     response = await get_function_call_response(
                     response = await get_function_call_response(
-                        prompt=prompt,
+                        messages=data["messages"],
                         tool_id=tool_id,
                         tool_id=tool_id,
                         template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
                         template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
                         task_model_id=task_model_id,
                         task_model_id=task_model_id,
@@ -839,7 +853,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
     template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
     template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 
 
     return await get_function_call_response(
     return await get_function_call_response(
-        form_data["prompt"], form_data["tool_id"], template, model_id, user
+        form_data["messages"], form_data["tool_id"], template, model_id, user
     )
     )