Timothy Jaeryang Baek преди 1 месец
родител
ревизия
3b74431ea3
променени са 3 файла, в които са добавени 24 реда и са изтрити 10 реда
  1. 6 1
      backend/open_webui/utils/chat.py
  2. 4 3
      backend/open_webui/utils/filter.py
  3. 14 6
      backend/open_webui/utils/middleware.py

+ 6 - 1
backend/open_webui/utils/chat.py

@@ -328,9 +328,14 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
     }
 
     try:
+        filter_functions = [
+            Functions.get_function_by_id(filter_id)
+            for filter_id in get_sorted_filter_ids(model)
+        ]
+
         result, _ = await process_filter_functions(
             request=request,
-            filter_ids=get_sorted_filter_ids(model),
+            filter_functions=filter_functions,
             filter_type="outlet",
             form_data=data,
             extra_params=extra_params,

+ 4 - 3
backend/open_webui/utils/filter.py

@@ -33,12 +33,13 @@ def get_sorted_filter_ids(model: dict):
 
 
 async def process_filter_functions(
-    request, filter_ids, filter_type, form_data, extra_params
+    request, filter_functions, filter_type, form_data, extra_params
 ):
     skip_files = None
 
-    for filter_id in filter_ids:
-        filter = Functions.get_function_by_id(filter_id)
+    for function in filter_functions:
+        filter = function
+        filter_id = function.id
         if not filter:
             continue
 

+ 14 - 6
backend/open_webui/utils/middleware.py

@@ -715,9 +715,14 @@ async def process_chat_payload(request, form_data, user, metadata, model):
         raise e
 
     try:
+        filter_functions = [
+            Functions.get_function_by_id(filter_id)
+            for filter_id in get_sorted_filter_ids(model)
+        ]
+
         form_data, flags = await process_filter_functions(
             request=request,
-            filter_ids=get_sorted_filter_ids(model),
+            filter_functions=filter_functions,
             filter_type="inlet",
             form_data=form_data,
             extra_params=extra_params,
@@ -1071,9 +1076,12 @@ async def process_chat_response(
         "__request__": request,
         "__model__": model,
     }
-    filter_ids = get_sorted_filter_ids(model)
+    filter_functions = [
+        Functions.get_function_by_id(filter_id)
+        for filter_id in get_sorted_filter_ids(model)
+    ]
 
-    print(f"{filter_ids=}")
+    print(f"{filter_functions=}")
 
     # Streaming response
     if event_emitter and event_caller:
@@ -1480,7 +1488,7 @@ async def process_chat_response(
 
                             data, _ = await process_filter_functions(
                                 request=request,
-                                filter_ids=filter_ids,
+                                filter_functions=filter_functions,
                                 filter_type="stream",
                                 form_data=data,
                                 extra_params=extra_params,
@@ -2077,7 +2085,7 @@ async def process_chat_response(
             for event in events:
                 event, _ = await process_filter_functions(
                     request=request,
-                    filter_ids=filter_ids,
+                    filter_functions=filter_functions,
                     filter_type="stream",
                     form_data=event,
                     extra_params=extra_params,
@@ -2089,7 +2097,7 @@ async def process_chat_response(
             async for data in original_generator:
                 data, _ = await process_filter_functions(
                     request=request,
-                    filter_ids=filter_ids,
+                    filter_functions=filter_functions,
                     filter_type="stream",
                     form_data=data,
                     extra_params=extra_params,