Explorar el Código

feat: tool citation

Timothy J. Baek hace 10 meses
padre
commit
6bb2f41812
Se ha modificado 1 fichero con 37 adiciones y 17 borrados
  1. 37 17
      backend/main.py

+ 37 - 17
backend/main.py

@@ -247,6 +247,7 @@ async def get_function_call_response(
             result = json.loads(content)
             result = json.loads(content)
             print(result)
             print(result)
 
 
+            citation = None
             # Call the function
             # Call the function
             if "name" in result:
             if "name" in result:
                 if tool_id in webui_app.state.TOOLS:
                 if tool_id in webui_app.state.TOOLS:
@@ -309,22 +310,32 @@ async def get_function_call_response(
                         }
                         }
 
 
                     function_result = function(**params)
                     function_result = function(**params)
+
+                    if hasattr(toolkit_module, "citation") and toolkit_module.citation:
+                        citation = {
+                            "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
+                            "document": [function_result],
+                            "metadata": [{"source": result["name"]}],
+                        }
                 except Exception as e:
                 except Exception as e:
                     print(e)
                     print(e)
 
 
                 # Add the function result to the system prompt
                 # Add the function result to the system prompt
                 if function_result is not None:
                 if function_result is not None:
-                    return function_result, file_handler
+                    return function_result, citation, file_handler
     except Exception as e:
     except Exception as e:
         print(f"Error: {e}")
         print(f"Error: {e}")
 
 
-    return None, False
+    return None, None, False
 
 
 
 
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
     async def dispatch(self, request: Request, call_next):
         data_items = []
         data_items = []
 
 
+        show_citations = False
+        citations = []
+
         if request.method == "POST" and any(
         if request.method == "POST" and any(
             endpoint in request.url.path
             endpoint in request.url.path
             for endpoint in ["/ollama/api/chat", "/chat/completions"]
             for endpoint in ["/ollama/api/chat", "/chat/completions"]
@@ -342,6 +353,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             )
             )
             # Flag to skip RAG completions if file_handler is present in tools/functions
             # Flag to skip RAG completions if file_handler is present in tools/functions
             skip_files = False
             skip_files = False
+            if data.get("citations"):
+                show_citations = True
+                del data["citations"]
 
 
             model_id = data["model"]
             model_id = data["model"]
             if model_id not in app.state.MODELS:
             if model_id not in app.state.MODELS:
@@ -365,8 +379,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                             webui_app.state.FUNCTIONS[filter_id] = function_module
                             webui_app.state.FUNCTIONS[filter_id] = function_module
 
 
                         # Check if the function has a file_handler variable
                         # Check if the function has a file_handler variable
-                        if getattr(function_module, "file_handler"):
-                            skip_files = True
+                        if hasattr(function_module, "file_handler"):
+                            skip_files = function_module.file_handler
 
 
                         try:
                         try:
                             if hasattr(function_module, "inlet"):
                             if hasattr(function_module, "inlet"):
@@ -411,19 +425,25 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 for tool_id in data["tool_ids"]:
                 for tool_id in data["tool_ids"]:
                     print(tool_id)
                     print(tool_id)
                     try:
                     try:
-                        response, file_handler = await get_function_call_response(
-                            messages=data["messages"],
-                            files=data.get("files", []),
-                            tool_id=tool_id,
-                            template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                            task_model_id=task_model_id,
-                            user=user,
+                        response, citation, file_handler = (
+                            await get_function_call_response(
+                                messages=data["messages"],
+                                files=data.get("files", []),
+                                tool_id=tool_id,
+                                template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+                                task_model_id=task_model_id,
+                                user=user,
+                            )
                         )
                         )
 
 
                         print(file_handler)
                         print(file_handler)
                         if isinstance(response, str):
                         if isinstance(response, str):
                             context += ("\n" if context != "" else "") + response
                             context += ("\n" if context != "" else "") + response
 
 
+                        if citation:
+                            citations.append(citation)
+                            show_citations = True
+
                         if file_handler:
                         if file_handler:
                             skip_files = True
                             skip_files = True
 
 
@@ -438,7 +458,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             if "files" in data:
             if "files" in data:
                 if not skip_files:
                 if not skip_files:
                     data = {**data}
                     data = {**data}
-                    rag_context, citations = get_rag_context(
+                    rag_context, rag_citations = get_rag_context(
                         files=data["files"],
                         files=data["files"],
                         messages=data["messages"],
                         messages=data["messages"],
                         embedding_function=rag_app.state.EMBEDDING_FUNCTION,
                         embedding_function=rag_app.state.EMBEDDING_FUNCTION,
@@ -452,13 +472,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
 
                     log.debug(f"rag_context: {rag_context}, citations: {citations}")
                     log.debug(f"rag_context: {rag_context}, citations: {citations}")
 
 
-                    if citations and data.get("citations"):
-                        data_items.append({"citations": citations})
+                    if rag_citations:
+                        citations.extend(rag_citations)
 
 
                 del data["files"]
                 del data["files"]
 
 
-            if data.get("citations"):
-                del data["citations"]
+            if show_citations and len(citations) > 0:
+                data_items.append({"citations": citations})
 
 
             if context != "":
             if context != "":
                 system_prompt = rag_template(
                 system_prompt = rag_template(
@@ -1285,7 +1305,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
     template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
     template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 
 
     try:
     try:
-        context, file_handler = await get_function_call_response(
+        context, citation, file_handler = await get_function_call_response(
             form_data["messages"],
             form_data["messages"],
             form_data.get("files", []),
             form_data.get("files", []),
             form_data["tool_id"],
             form_data["tool_id"],