|
@@ -168,11 +168,25 @@ app.state.MODELS = {}
|
|
|
origins = ["*"]
|
|
|
|
|
|
|
|
|
-async def get_function_call_response(prompt, tool_id, template, task_model_id, user):
|
|
|
+async def get_function_call_response(messages, tool_id, template, task_model_id, user):
|
|
|
tool = Tools.get_tool_by_id(tool_id)
|
|
|
tools_specs = json.dumps(tool.specs, indent=2)
|
|
|
content = tools_function_calling_generation_template(template, tools_specs)
|
|
|
|
|
|
+ user_message = get_last_user_message(messages)
|
|
|
+ prompt = (
|
|
|
+ "History:\n"
|
|
|
+ + "\n".join(
|
|
|
+ [
|
|
|
+ f"{message['role']}: {message['content']}"
|
|
|
+ for message in messages[::-1][:4]
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ + f"\nQuery: {user_message}"
|
|
|
+ )
|
|
|
+
|
|
|
+ print(prompt)
|
|
|
+
|
|
|
payload = {
|
|
|
"model": task_model_id,
|
|
|
"messages": [
|
|
@@ -300,16 +314,16 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
):
|
|
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
|
|
+ prompt = get_last_user_message(data["messages"])
|
|
|
context = ""
|
|
|
|
|
|
# If tool_ids field is present, call the functions
|
|
|
if "tool_ids" in data:
|
|
|
print(data["tool_ids"])
|
|
|
- prompt = get_last_user_message(data["messages"])
|
|
|
for tool_id in data["tool_ids"]:
|
|
|
print(tool_id)
|
|
|
response = await get_function_call_response(
|
|
|
- prompt=prompt,
|
|
|
+ messages=data["messages"],
|
|
|
tool_id=tool_id,
|
|
|
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
task_model_id=task_model_id,
|
|
@@ -839,7 +853,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
|
|
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
|
|
|
|
return await get_function_call_response(
|
|
|
- form_data["prompt"], form_data["tool_id"], template, model_id, user
|
|
|
+ form_data["messages"], form_data["tool_id"], template, model_id, user
|
|
|
)
|
|
|
|
|
|
|