Browse Source

enh: __messages__ support for tools

Timothy J. Baek 10 tháng trước cách đây
mục cha
commit
55dfc2013a
2 tập tin đã thay đổi với 38 bổ sung17 xóa
  1. 22 16
      backend/main.py
  2. 16 1
      backend/utils/misc.py

+ 22 - 16
backend/main.py

@@ -244,23 +244,28 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
                 try:
                     # Get the signature of the function
                     sig = inspect.signature(function)
-                    # Check if '__user__' is a parameter of the function
+                    params = result["parameters"]
+
                     if "__user__" in sig.parameters:
                         # Call the function with the '__user__' parameter included
-                        function_result = function(
-                            **{
-                                **result["parameters"],
-                                "__user__": {
-                                    "id": user.id,
-                                    "email": user.email,
-                                    "name": user.name,
-                                    "role": user.role,
-                                },
-                            }
-                        )
-                    else:
-                        # Call the function without modifying the parameters
-                        function_result = function(**result["parameters"])
+                        params = {
+                            **params,
+                            "__user__": {
+                                "id": user.id,
+                                "email": user.email,
+                                "name": user.name,
+                                "role": user.role,
+                            },
+                        }
+
+                    if "__messages__" in sig.parameters:
+                        # Call the function with the '__messages__' parameter included
+                        params = {
+                            **params,
+                            "__messages__": messages,
+                        }
+
+                    function_result = function(**params)
                 except Exception as e:
                     print(e)
 
@@ -339,8 +344,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                             user=user,
                         )
 
-                        if response:
+                        if isinstance(response, str):
                             context += ("\n" if context != "" else "") + response
+
                     except Exception as e:
                         print(f"Error: {e}")
                 del data["tool_ids"]

+ 16 - 1
backend/utils/misc.py

@@ -3,7 +3,7 @@ import hashlib
 import json
 import re
 from datetime import timedelta
-from typing import Optional, List
+from typing import Optional, List, Tuple
 
 
 def get_last_user_message(messages: List[dict]) -> str:
@@ -28,6 +28,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
     return None
 
 
+def get_system_message(messages: List[dict]) -> dict:
+    for message in messages:
+        if message["role"] == "system":
+            return message
+    return None
+
+
+def remove_system_message(messages: List[dict]) -> List[dict]:
+    return [message for message in messages if message["role"] != "system"]
+
+
+def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
+    return get_system_message(messages), remove_system_message(messages)
+
+
 def add_or_update_system_message(content: str, messages: List[dict]):
     """
     Adds a new system message at the beginning of the messages list