|
@@ -33,7 +33,7 @@ from starlette.middleware.sessions import SessionMiddleware
|
|
|
from starlette.responses import StreamingResponse, Response, RedirectResponse
|
|
|
|
|
|
|
|
|
-from apps.socket.main import app as socket_app
|
|
|
+from apps.socket.main import sio, app as socket_app
|
|
|
from apps.ollama.main import (
|
|
|
app as ollama_app,
|
|
|
OpenAIChatCompletionForm,
|
|
@@ -277,7 +277,14 @@ def get_filter_function_ids(model):
|
|
|
|
|
|
|
|
|
async def get_function_call_response(
|
|
|
- messages, files, tool_id, template, task_model_id, user, model
|
|
|
+ messages,
|
|
|
+ files,
|
|
|
+ tool_id,
|
|
|
+ template,
|
|
|
+ task_model_id,
|
|
|
+ user,
|
|
|
+ model,
|
|
|
+ __event_emitter__=None,
|
|
|
):
|
|
|
tool = Tools.get_tool_by_id(tool_id)
|
|
|
tools_specs = json.dumps(tool.specs, indent=2)
|
|
@@ -414,6 +421,13 @@ async def get_function_call_response(
|
|
|
"__id__": tool_id,
|
|
|
}
|
|
|
|
|
|
+ if "__event_emitter__" in sig.parameters:
|
|
|
+ # Call the function with the '__event_emitter__' parameter included
|
|
|
+ params = {
|
|
|
+ **params,
|
|
|
+ "__event_emitter__": model,
|
|
|
+ }
|
|
|
+
|
|
|
if inspect.iscoroutinefunction(function):
|
|
|
function_result = await function(**params)
|
|
|
else:
|
|
@@ -437,7 +451,7 @@ async def get_function_call_response(
|
|
|
return None, None, False
|
|
|
|
|
|
|
|
|
-async def chat_completion_functions_handler(body, model, user):
|
|
|
+async def chat_completion_functions_handler(body, model, user, __event_emitter__):
|
|
|
skip_files = None
|
|
|
|
|
|
filter_ids = get_filter_function_ids(model)
|
|
@@ -503,6 +517,11 @@ async def chat_completion_functions_handler(body, model, user):
|
|
|
**params,
|
|
|
"__model__": model,
|
|
|
}
|
|
|
+ if "__event_emitter__" in sig.parameters:
|
|
|
+ params = {
|
|
|
+ **params,
|
|
|
+ "__event_emitter__": __event_emitter__,
|
|
|
+ }
|
|
|
|
|
|
if inspect.iscoroutinefunction(inlet):
|
|
|
body = await inlet(**params)
|
|
@@ -520,7 +539,7 @@ async def chat_completion_functions_handler(body, model, user):
|
|
|
return body, {}
|
|
|
|
|
|
|
|
|
-async def chat_completion_tools_handler(body, model, user):
|
|
|
+async def chat_completion_tools_handler(body, model, user, __event_emitter__):
|
|
|
skip_files = None
|
|
|
|
|
|
contexts = []
|
|
@@ -542,6 +561,7 @@ async def chat_completion_tools_handler(body, model, user):
|
|
|
task_model_id=task_model_id,
|
|
|
user=user,
|
|
|
model=model,
|
|
|
+ __event_emitter__=__event_emitter__,
|
|
|
)
|
|
|
|
|
|
print(file_handler)
|
|
@@ -614,7 +634,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
content={"detail": str(e)},
|
|
|
)
|
|
|
|
|
|
- # Extract chat_id and message_id from the request body
|
|
|
+ # Extract session_id, chat_id and message_id from the request body
|
|
|
+ session_id = None
|
|
|
+ if "session_id" in body:
|
|
|
+ session_id = body["session_id"]
|
|
|
+ del body["session_id"]
|
|
|
chat_id = None
|
|
|
if "chat_id" in body:
|
|
|
chat_id = body["chat_id"]
|
|
@@ -624,6 +648,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
message_id = body["id"]
|
|
|
del body["id"]
|
|
|
|
|
|
+ async def __event_emitter__(data):
|
|
|
+ await sio.emit(
|
|
|
+ "chat-events",
|
|
|
+ {
|
|
|
+ "chat_id": chat_id,
|
|
|
+ "message_id": message_id,
|
|
|
+ "data": data,
|
|
|
+ },
|
|
|
+ to=session_id,
|
|
|
+ )
|
|
|
+
|
|
|
# Initialize data_items to store additional data to be sent to the client
|
|
|
data_items = []
|
|
|
|
|
@@ -631,10 +666,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
contexts = []
|
|
|
citations = []
|
|
|
|
|
|
- print(body)
|
|
|
-
|
|
|
try:
|
|
|
- body, flags = await chat_completion_functions_handler(body, model, user)
|
|
|
+ body, flags = await chat_completion_functions_handler(
|
|
|
+ body, model, user, __event_emitter__
|
|
|
+ )
|
|
|
except Exception as e:
|
|
|
return JSONResponse(
|
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
@@ -642,7 +677,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
)
|
|
|
|
|
|
try:
|
|
|
- body, flags = await chat_completion_tools_handler(body, model, user)
|
|
|
+ body, flags = await chat_completion_tools_handler(
|
|
|
+ body, model, user, __event_emitter__
|
|
|
+ )
|
|
|
|
|
|
contexts.extend(flags.get("contexts", []))
|
|
|
citations.extend(flags.get("citations", []))
|