|
@@ -170,7 +170,9 @@ app.state.MODELS = {}
|
|
|
origins = ["*"]
|
|
|
|
|
|
|
|
|
-async def get_function_call_response(messages, tool_id, template, task_model_id, user):
|
|
|
+async def get_function_call_response(
|
|
|
+ messages, files, tool_id, template, task_model_id, user
|
|
|
+):
|
|
|
tool = Tools.get_tool_by_id(tool_id)
|
|
|
tools_specs = json.dumps(tool.specs, indent=2)
|
|
|
content = tools_function_calling_generation_template(template, tools_specs)
|
|
@@ -265,6 +267,13 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
|
|
|
"__messages__": messages,
|
|
|
}
|
|
|
|
|
|
+ if "__files__" in sig.parameters:
|
|
|
+ # Call the function with the '__files__' parameter included
|
|
|
+ params = {
|
|
|
+ **params,
|
|
|
+ "__files__": files,
|
|
|
+ }
|
|
|
+
|
|
|
function_result = function(**params)
|
|
|
except Exception as e:
|
|
|
print(e)
|
|
@@ -338,6 +347,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
try:
|
|
|
response = await get_function_call_response(
|
|
|
messages=data["messages"],
|
|
|
+ files=data.get("files", []),
|
|
|
tool_id=tool_id,
|
|
|
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
task_model_id=task_model_id,
|
|
@@ -353,7 +363,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
|
print(f"tool_context: {context}")
|
|
|
|
|
|
- # If docs field is present, generate RAG completions
|
|
|
+ # TODO: Check if tools & functions have files support to skip this step to delegate file processing
|
|
|
+ # If files field is present, generate RAG completions
|
|
|
if "files" in data:
|
|
|
data = {**data}
|
|
|
rag_context, citations = get_rag_context(
|
|
@@ -376,15 +387,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
system_prompt = rag_template(
|
|
|
rag_app.state.config.RAG_TEMPLATE, context, prompt
|
|
|
)
|
|
|
-
|
|
|
print(system_prompt)
|
|
|
-
|
|
|
data["messages"] = add_or_update_system_message(
|
|
|
f"\n{system_prompt}", data["messages"]
|
|
|
)
|
|
|
|
|
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
|
-
|
|
|
# Replace the request body with the modified one
|
|
|
request._body = modified_body_bytes
|
|
|
# Set custom header to ensure content-length matches new body length
|
|
@@ -961,7 +969,12 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
|
|
|
|
|
try:
|
|
|
context = await get_function_call_response(
|
|
|
- form_data["messages"], form_data["tool_id"], template, model_id, user
|
|
|
+ form_data["messages"],
|
|
|
+ form_data.get("files", []),
|
|
|
+ form_data["tool_id"],
|
|
|
+ template,
|
|
|
+ model_id,
|
|
|
+ user,
|
|
|
)
|
|
|
return context
|
|
|
except Exception as e:
|