Timothy J. Baek hai 9 meses
pai
achega
8dcb3d78dc
Modificáronse 2 ficheiros con 30 adicións e 6 borrados
  1. 25 1
      backend/apps/webui/main.py
  2. 5 5
      backend/constants.py

+ 25 - 1
backend/apps/webui/main.py

@@ -47,6 +47,8 @@ from config import (
     OAUTH_PICTURE_CLAIM,
     OAUTH_PICTURE_CLAIM,
 )
 )
 
 
+from apps.socket.main import get_event_call, get_event_emitter
+
 import inspect
 import inspect
 import uuid
 import uuid
 import time
 import time
@@ -197,8 +199,21 @@ async def generate_function_chat_completion(form_data, user):
         metadata = form_data["metadata"]
         metadata = form_data["metadata"]
         del form_data["metadata"]
         del form_data["metadata"]
 
 
+    __event_emitter__ = None
+    __event_call__ = None
+    __task__ = None
+
     if metadata:
     if metadata:
-        print(metadata)
+        if (
+            metadata.get("session_id")
+            and metadata.get("chat_id")
+            and metadata.get("message_id")
+        ):
+            __event_emitter__ = await get_event_emitter(metadata)
+            __event_call__ = await get_event_call(metadata)
+
+        if metadata.get("task"):
+            __task__ = metadata.get("task")
 
 
     if model_info:
     if model_info:
         if model_info.base_model_id:
         if model_info.base_model_id:
@@ -314,6 +329,15 @@ async def generate_function_chat_completion(form_data, user):
 
 
             params = {**params, "__user__": __user__}
             params = {**params, "__user__": __user__}
 
 
+        if "__event_emitter__" in sig.parameters:
+            params = {**params, "__event_emitter__": __event_emitter__}
+
+        if "__event_call__" in sig.parameters:
+            params = {**params, "__event_call__": __event_call__}
+
+        if "__task__" in sig.parameters:
+            params = {**params, "__task__": __task__}
+
         if form_data["stream"]:
         if form_data["stream"]:
 
 
             async def stream_content():
             async def stream_content():

+ 5 - 5
backend/constants.py

@@ -95,8 +95,8 @@ class TASKS(str, Enum):
     def __str__(self) -> str:
     def __str__(self) -> str:
         return super().__str__()
         return super().__str__()
 
 
-    DEFAULT = lambda task="": f"{task if task else 'default'}"
-    TITLE_GENERATION = "Title Generation"
-    EMOJI_GENERATION = "Emoji Generation"
-    QUERY_GENERATION = "Query Generation"
-    FUNCTION_CALLING = "Function Calling"
+    DEFAULT = lambda task="": f"{task if task else 'generation'}"
+    TITLE_GENERATION = "title_generation"
+    EMOJI_GENERATION = "emoji_generation"
+    QUERY_GENERATION = "query_generation"
+    FUNCTION_CALLING = "function_calling"