Browse Source

feat: user hook

Timothy J. Baek 10 tháng trước cách đây
mục cha
commit
8a86f32700
1 tập tin đã thay đổi với 23 bổ sung2 xóa
  1. 23 2
      backend/main.py

+ 23 - 2
backend/main.py

@@ -11,6 +11,7 @@ import requests
 import mimetypes
 import shutil
 import os
+import inspect
 import asyncio
 
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
@@ -204,6 +205,8 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
 
         # Parse the function response
         if content is not None:
+
+            print(content)
             result = json.loads(content)
             print(result)
 
@@ -218,7 +221,24 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
                 function = getattr(toolkit_module, result["name"])
                 function_result = None
                 try:
-                    function_result = function(**result["parameters"])
+                    # Get the signature of the function
+                    sig = inspect.signature(function)
+                    # Check if 'user' is a parameter of the function
+                    if "user" in sig.parameters:
+                        # Call the function with the 'user' parameter included
+                        function_result = function(
+                            **{
+                                **result["parameters"],
+                                "user": {
+                                    "id": user.id,
+                                    "name": user.name,
+                                    "role": user.role,
+                                },
+                            }
+                        )
+                    else:
+                        # Call the function without modifying the parameters
+                        function_result = function(**result["parameters"])
                 except Exception as e:
                     print(e)
 
@@ -284,6 +304,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
             # If tool_ids field is present, call the functions
             if "tool_ids" in data:
+                print(data["tool_ids"])
                 prompt = get_last_user_message(data["messages"])
                 for tool_id in data["tool_ids"]:
                     print(tool_id)
@@ -299,7 +320,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                         context += ("\n" if context != "" else "") + response
                 del data["tool_ids"]
 
-                print(context)
+                print(f"tool_context: {context}")
 
             # If docs field is present, generate RAG completions
             if "docs" in data: