|
@@ -241,6 +241,12 @@ async def get_function_call_response(
|
|
|
toolkit_module = load_toolkit_module_by_id(tool_id)
|
|
|
webui_app.state.TOOLS[tool_id] = toolkit_module
|
|
|
|
|
|
+ file_handler = False
|
|
|
+ # check if toolkit_module has file_handler self variable
|
|
|
+ if hasattr(toolkit_module, "file_handler"):
|
|
|
+ file_handler = True
|
|
|
+ print("file_handler: ", file_handler)
|
|
|
+
|
|
|
function = getattr(toolkit_module, result["name"])
|
|
|
function_result = None
|
|
|
try:
|
|
@@ -279,12 +285,12 @@ async def get_function_call_response(
|
|
|
print(e)
|
|
|
|
|
|
# Add the function result to the system prompt
|
|
|
- if function_result:
|
|
|
- return function_result
|
|
|
+ if function_result is not None:
|
|
|
+ return function_result, file_handler
|
|
|
except Exception as e:
|
|
|
print(f"Error: {e}")
|
|
|
|
|
|
- return None
|
|
|
+ return None, False
|
|
|
|
|
|
|
|
|
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
@@ -340,12 +346,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
context = ""
|
|
|
|
|
|
# If tool_ids field is present, call the functions
|
|
|
+
|
|
|
+ skip_files = False
|
|
|
if "tool_ids" in data:
|
|
|
print(data["tool_ids"])
|
|
|
for tool_id in data["tool_ids"]:
|
|
|
print(tool_id)
|
|
|
try:
|
|
|
- response = await get_function_call_response(
|
|
|
+ response, file_handler = await get_function_call_response(
|
|
|
messages=data["messages"],
|
|
|
files=data.get("files", []),
|
|
|
tool_id=tool_id,
|
|
@@ -354,34 +362,41 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
user=user,
|
|
|
)
|
|
|
|
|
|
+ print(file_handler)
|
|
|
if isinstance(response, str):
|
|
|
context += ("\n" if context != "" else "") + response
|
|
|
|
|
|
+ if file_handler:
|
|
|
+ skip_files = True
|
|
|
+
|
|
|
except Exception as e:
|
|
|
print(f"Error: {e}")
|
|
|
del data["tool_ids"]
|
|
|
|
|
|
print(f"tool_context: {context}")
|
|
|
|
|
|
- # TODO: Check if tools & functions have files support to skip this step to delegate file processing
|
|
|
# If files field is present, generate RAG completions
|
|
|
+ # If skip_files is True, skip the RAG completions
|
|
|
if "files" in data:
|
|
|
- data = {**data}
|
|
|
- rag_context, citations = get_rag_context(
|
|
|
- files=data["files"],
|
|
|
- messages=data["messages"],
|
|
|
- embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
|
|
- k=rag_app.state.config.TOP_K,
|
|
|
- reranking_function=rag_app.state.sentence_transformer_rf,
|
|
|
- r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
|
|
- hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
|
|
- )
|
|
|
+ if not skip_files:
|
|
|
+ data = {**data}
|
|
|
+ rag_context, citations = get_rag_context(
|
|
|
+ files=data["files"],
|
|
|
+ messages=data["messages"],
|
|
|
+ embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
|
|
+ k=rag_app.state.config.TOP_K,
|
|
|
+ reranking_function=rag_app.state.sentence_transformer_rf,
|
|
|
+ r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
|
|
+ hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
|
|
+ )
|
|
|
+ if rag_context:
|
|
|
+ context += ("\n" if context != "" else "") + rag_context
|
|
|
|
|
|
- if rag_context:
|
|
|
- context += ("\n" if context != "" else "") + rag_context
|
|
|
+ log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
|
|
+ else:
|
|
|
+ return_citations = False
|
|
|
|
|
|
del data["files"]
|
|
|
- log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
|
|
|
|
|
if context != "":
|
|
|
system_prompt = rag_template(
|
|
@@ -968,7 +983,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
|
|
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
|
|
|
|
try:
|
|
|
- context = await get_function_call_response(
|
|
|
+ context, file_handler = await get_function_call_response(
|
|
|
form_data["messages"],
|
|
|
form_data.get("files", []),
|
|
|
form_data["tool_id"],
|