Browse Source

feat: tools file handler support

Timothy J. Baek 10 months ago
parent
commit
a2e1ea103c
1 changed files with 34 additions and 19 deletions
  1. 34 19
      backend/main.py

+ 34 - 19
backend/main.py

@@ -241,6 +241,12 @@ async def get_function_call_response(
                     toolkit_module = load_toolkit_module_by_id(tool_id)
                     webui_app.state.TOOLS[tool_id] = toolkit_module
 
+                file_handler = False
+                # check if toolkit_module has file_handler self variable
+                if hasattr(toolkit_module, "file_handler"):
+                    file_handler = True
+                    print("file_handler: ", file_handler)
+
                 function = getattr(toolkit_module, result["name"])
                 function_result = None
                 try:
@@ -279,12 +285,12 @@ async def get_function_call_response(
                     print(e)
 
                 # Add the function result to the system prompt
-                if function_result:
-                    return function_result
+                if function_result is not None:
+                    return function_result, file_handler
     except Exception as e:
         print(f"Error: {e}")
 
-    return None
+    return None, False
 
 
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
@@ -340,12 +346,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             context = ""
 
             # If tool_ids field is present, call the functions
+
+            skip_files = False
             if "tool_ids" in data:
                 print(data["tool_ids"])
                 for tool_id in data["tool_ids"]:
                     print(tool_id)
                     try:
-                        response = await get_function_call_response(
+                        response, file_handler = await get_function_call_response(
                             messages=data["messages"],
                             files=data.get("files", []),
                             tool_id=tool_id,
@@ -354,34 +362,41 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                             user=user,
                         )
 
+                        print(file_handler)
                         if isinstance(response, str):
                             context += ("\n" if context != "" else "") + response
 
+                        if file_handler:
+                            skip_files = True
+
                     except Exception as e:
                         print(f"Error: {e}")
                 del data["tool_ids"]
 
                 print(f"tool_context: {context}")
 
-            # TODO: Check if tools & functions have files support to skip this step to delegate file processing
             # If files field is present, generate RAG completions
+            # If skip_files is True, skip the RAG completions
             if "files" in data:
-                data = {**data}
-                rag_context, citations = get_rag_context(
-                    files=data["files"],
-                    messages=data["messages"],
-                    embedding_function=rag_app.state.EMBEDDING_FUNCTION,
-                    k=rag_app.state.config.TOP_K,
-                    reranking_function=rag_app.state.sentence_transformer_rf,
-                    r=rag_app.state.config.RELEVANCE_THRESHOLD,
-                    hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                )
+                if not skip_files:
+                    data = {**data}
+                    rag_context, citations = get_rag_context(
+                        files=data["files"],
+                        messages=data["messages"],
+                        embedding_function=rag_app.state.EMBEDDING_FUNCTION,
+                        k=rag_app.state.config.TOP_K,
+                        reranking_function=rag_app.state.sentence_transformer_rf,
+                        r=rag_app.state.config.RELEVANCE_THRESHOLD,
+                        hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+                    )
+                    if rag_context:
+                        context += ("\n" if context != "" else "") + rag_context
 
-                if rag_context:
-                    context += ("\n" if context != "" else "") + rag_context
+                    log.debug(f"rag_context: {rag_context}, citations: {citations}")
+                else:
+                    return_citations = False
 
                 del data["files"]
-                log.debug(f"rag_context: {rag_context}, citations: {citations}")
 
             if context != "":
                 system_prompt = rag_template(
@@ -968,7 +983,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
     template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 
     try:
-        context = await get_function_call_response(
+        context, file_handler = await get_function_call_response(
             form_data["messages"],
             form_data.get("files", []),
             form_data["tool_id"],