|
@@ -302,6 +302,7 @@ async def get_function_call_response(
|
|
user,
|
|
user,
|
|
model,
|
|
model,
|
|
__event_emitter__=None,
|
|
__event_emitter__=None,
|
|
|
|
+ __event_call__=None,
|
|
):
|
|
):
|
|
tool = Tools.get_tool_by_id(tool_id)
|
|
tool = Tools.get_tool_by_id(tool_id)
|
|
tools_specs = json.dumps(tool.specs, indent=2)
|
|
tools_specs = json.dumps(tool.specs, indent=2)
|
|
@@ -445,6 +446,13 @@ async def get_function_call_response(
|
|
"__event_emitter__": __event_emitter__,
|
|
"__event_emitter__": __event_emitter__,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ if "__event_call__" in sig.parameters:
|
|
|
|
+ # Call the function with the '__event_call__' parameter included
|
|
|
|
+ params = {
|
|
|
|
+ **params,
|
|
|
|
+ "__event_call__": __event_call__,
|
|
|
|
+ }
|
|
|
|
+
|
|
if inspect.iscoroutinefunction(function):
|
|
if inspect.iscoroutinefunction(function):
|
|
function_result = await function(**params)
|
|
function_result = await function(**params)
|
|
else:
|
|
else:
|
|
@@ -468,7 +476,9 @@ async def get_function_call_response(
|
|
return None, None, False
|
|
return None, None, False
|
|
|
|
|
|
|
|
|
|
-async def chat_completion_functions_handler(body, model, user, __event_emitter__):
|
|
|
|
|
|
+async def chat_completion_functions_handler(
|
|
|
|
+ body, model, user, __event_emitter__, __event_call__
|
|
|
|
+):
|
|
skip_files = None
|
|
skip_files = None
|
|
|
|
|
|
filter_ids = get_filter_function_ids(model)
|
|
filter_ids = get_filter_function_ids(model)
|
|
@@ -534,12 +544,19 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
|
|
**params,
|
|
**params,
|
|
"__model__": model,
|
|
"__model__": model,
|
|
}
|
|
}
|
|
|
|
+
|
|
if "__event_emitter__" in sig.parameters:
|
|
if "__event_emitter__" in sig.parameters:
|
|
params = {
|
|
params = {
|
|
**params,
|
|
**params,
|
|
"__event_emitter__": __event_emitter__,
|
|
"__event_emitter__": __event_emitter__,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ if "__event_call__" in sig.parameters:
|
|
|
|
+ params = {
|
|
|
|
+ **params,
|
|
|
|
+ "__event_call__": __event_call__,
|
|
|
|
+ }
|
|
|
|
+
|
|
if inspect.iscoroutinefunction(inlet):
|
|
if inspect.iscoroutinefunction(inlet):
|
|
body = await inlet(**params)
|
|
body = await inlet(**params)
|
|
else:
|
|
else:
|
|
@@ -556,7 +573,9 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
|
|
return body, {}
|
|
return body, {}
|
|
|
|
|
|
|
|
|
|
-async def chat_completion_tools_handler(body, model, user, __event_emitter__):
|
|
|
|
|
|
+async def chat_completion_tools_handler(
|
|
|
|
+ body, model, user, __event_emitter__, __event_call__
|
|
|
|
+):
|
|
skip_files = None
|
|
skip_files = None
|
|
|
|
|
|
contexts = []
|
|
contexts = []
|
|
@@ -579,6 +598,7 @@ async def chat_completion_tools_handler(body, model, user, __event_emitter__):
|
|
user=user,
|
|
user=user,
|
|
model=model,
|
|
model=model,
|
|
__event_emitter__=__event_emitter__,
|
|
__event_emitter__=__event_emitter__,
|
|
|
|
+ __event_call__=__event_call__,
|
|
)
|
|
)
|
|
|
|
|
|
print(file_handler)
|
|
print(file_handler)
|
|
@@ -676,6 +696,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
to=session_id,
|
|
to=session_id,
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ async def __event_call__(data):
|
|
|
|
+ response = await sio.call(
|
|
|
|
+ "chat-events",
|
|
|
|
+ {"chat_id": chat_id, "message_id": message_id, "data": data},
|
|
|
|
+ to=session_id,
|
|
|
|
+ )
|
|
|
|
+ return response
|
|
|
|
+
|
|
# Initialize data_items to store additional data to be sent to the client
|
|
# Initialize data_items to store additional data to be sent to the client
|
|
data_items = []
|
|
data_items = []
|
|
|
|
|
|
@@ -685,7 +713,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
|
try:
|
|
try:
|
|
body, flags = await chat_completion_functions_handler(
|
|
body, flags = await chat_completion_functions_handler(
|
|
- body, model, user, __event_emitter__
|
|
|
|
|
|
+ body, model, user, __event_emitter__, __event_call__
|
|
)
|
|
)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
return JSONResponse(
|
|
return JSONResponse(
|
|
@@ -695,7 +723,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
|
try:
|
|
try:
|
|
body, flags = await chat_completion_tools_handler(
|
|
body, flags = await chat_completion_tools_handler(
|
|
- body, model, user, __event_emitter__
|
|
|
|
|
|
+ body, model, user, __event_emitter__, __event_call__
|
|
)
|
|
)
|
|
|
|
|
|
contexts.extend(flags.get("contexts", []))
|
|
contexts.extend(flags.get("contexts", []))
|