refac: middleware

Timothy Jaeryang Baek 2 months ago
commit e7da506add
3 changed files with 58 additions and 39 deletions
  1. backend/open_webui/main.py (+3 -1)
  2. backend/open_webui/utils/middleware.py (+54 -37)
  3. src/lib/apis/openai/index.ts (+1 -1)

+ 3 - 1
backend/open_webui/main.py

@@ -889,9 +889,10 @@ async def chat_completion(
         }
         form_data["metadata"] = metadata
 
-        form_data, events = await process_chat_payload(
+        form_data, metadata, events = await process_chat_payload(
             request, form_data, metadata, user, model
         )
+
     except Exception as e:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -900,6 +901,7 @@ async def chat_completion(
 
     try:
         response = await chat_completion_handler(request, form_data, user)
+
         return await process_chat_response(
             request, response, form_data, user, events, metadata, tasks
         )
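Note on this hunk: process_chat_payload now returns a three-tuple, so chat_completion unpacks the (possibly updated) metadata as well and passes it on, together with events, to process_chat_response. A minimal sketch of the new call shape, with stub coroutines standing in for the real Open WebUI functions:

    import asyncio

    async def process_chat_payload(request, form_data, metadata, user, model):
        # Stand-in: the real function enriches form_data, may attach tools to
        # metadata, and collects events to stream back to the client.
        events = [{"type": "status", "data": "payload processed"}]
        return form_data, metadata, events

    async def chat_completion(request, form_data, metadata, user, model):
        # Callers now unpack three values; metadata may have been mutated by
        # the payload step (e.g. metadata["tools"] in the native tool path).
        form_data, metadata, events = await process_chat_payload(
            request, form_data, metadata, user, model
        )
        return form_data, metadata, events

    print(asyncio.run(chat_completion(None, {"model": "gpt-4o"}, {}, None, None)))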

+ 54 - 37
backend/open_webui/utils/middleware.py

@@ -183,7 +183,7 @@ async def chat_completion_filter_functions_handler(request, body, model, extra_p
 
 
 async def chat_completion_tools_handler(
-    request: Request, body: dict, user: UserModel, models, extra_params: dict
+    request: Request, body: dict, user: UserModel, models, tools
 ) -> tuple[dict, dict]:
     async def get_content_from_response(response) -> Optional[str]:
         content = None
@@ -218,35 +218,15 @@ async def chat_completion_tools_handler(
             "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
         }
 
-    # If tool_ids field is present, call the functions
-    metadata = body.get("metadata", {})
-
-    tool_ids = metadata.get("tool_ids", None)
-    log.debug(f"{tool_ids=}")
-    if not tool_ids:
-        return body, {}
-
-    skip_files = False
-    sources = []
-
     task_model_id = get_task_model_id(
         body["model"],
         request.app.state.config.TASK_MODEL,
         request.app.state.config.TASK_MODEL_EXTERNAL,
         models,
     )
-    tools = get_tools(
-        request,
-        tool_ids,
-        user,
-        {
-            **extra_params,
-            "__model__": models[task_model_id],
-            "__messages__": body["messages"],
-            "__files__": metadata.get("files", []),
-        },
-    )
-    log.info(f"{tools=}")
+
+    skip_files = False
+    sources = []
 
     specs = [tool["spec"] for tool in tools.values()]
     tools_specs = json.dumps(specs)
@@ -281,6 +261,8 @@ async def chat_completion_tools_handler(
             result = json.loads(content)
 
             async def tool_call_handler(tool_call):
+                nonlocal skip_files
+
                 log.debug(f"{tool_call=}")
 
                 tool_function_name = tool_call.get("name", None)
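The nonlocal skip_files added above is what lets the nested tool_call_handler coroutine flip a flag owned by the enclosing handler; without it, the assignment would just bind a new local variable. A toy illustration with made-up tool data (the file_handler condition is only an assumption about when the real handler skips file retrieval):

    import asyncio

    async def tools_handler():
        skip_files = False

        async def tool_call_handler(tool_call):
            nonlocal skip_files  # without this, the assignment below creates a new local
            if tool_call.get("file_handler"):  # hypothetical condition, for illustration
                skip_files = True

        await tool_call_handler({"name": "web_search", "file_handler": True})
        return skip_files

    print(asyncio.run(tools_handler()))  # prints: True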
@@ -725,6 +707,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
     # Initialize events to store additional event to be sent to the client
     # Initialize contexts and citation
     models = request.app.state.MODELS
+    task_model_id = get_task_model_id(
+        form_data["model"],
+        request.app.state.config.TASK_MODEL,
+        request.app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
 
     events = []
     sources = []
@@ -809,15 +797,41 @@ async def process_chat_payload(request, form_data, metadata, user, model):
     }
     form_data["metadata"] = metadata
 
-    if not form_data["metadata"].get("function_calling") == "native":
-        # If the function calling is not native, then call the tools function calling handler
-        try:
-            form_data, flags = await chat_completion_tools_handler(
-                request, form_data, user, models, extra_params
-            )
-            sources.extend(flags.get("sources", []))
-        except Exception as e:
-            log.exception(e)
+    tool_ids = metadata.get("tool_ids", None)
+    log.debug(f"{tool_ids=}")
+
+    if tool_ids:
+        # If tool_ids field is present, then get the tools
+        tools = get_tools(
+            request,
+            tool_ids,
+            user,
+            {
+                **extra_params,
+                "__model__": models[task_model_id],
+                "__messages__": form_data["messages"],
+                "__files__": metadata.get("files", []),
+            },
+        )
+        log.info(f"{tools=}")
+
+        if metadata.get("function_calling") == "native":
+            # If the function calling is native, pass the tool specs directly to the model request
+            metadata["tools"] = tools
+            form_data["tools"] = [
+                {"type": "function", "function": tool.get("spec", {})}
+                for tool in tools.values()
+            ]
+        else:
+            # If the function calling is not native, then call the tools function calling handler
+            try:
+                form_data, flags = await chat_completion_tools_handler(
+                    request, form_data, user, models, tools
+                )
+                sources.extend(flags.get("sources", []))
+
+            except Exception as e:
+                log.exception(e)
 
     try:
         form_data, flags = await chat_completion_files_handler(request, form_data, user)
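This hunk moves tool loading out of chat_completion_tools_handler and into process_chat_payload, then branches on metadata["function_calling"]: the native path hands OpenAI-style tool entries straight to the model, while everything else still goes through the handler, which now receives the preloaded tools dict. A minimal sketch of that branching with a fabricated tools dict; the real one comes from get_tools():

    def route_tools(form_data: dict, metadata: dict, tools: dict) -> None:
        if metadata.get("function_calling") == "native":
            # Native path: expose the specs as OpenAI-style tool entries and keep
            # the full tool objects in metadata for the response phase.
            metadata["tools"] = tools
            form_data["tools"] = [
                {"type": "function", "function": tool.get("spec", {})}
                for tool in tools.values()
            ]
        else:
            # Non-native path: the preloaded tools dict would be passed to
            # chat_completion_tools_handler instead of being rebuilt there.
            pass

    tools = {"get_weather": {"spec": {"name": "get_weather", "parameters": {"type": "object"}}}}
    form_data, metadata = {"model": "gpt-4o"}, {"function_calling": "native"}
    route_tools(form_data, metadata, tools)
    print(form_data["tools"])  # [{'type': 'function', 'function': {'name': 'get_weather', ...}}]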
@@ -833,11 +847,11 @@ async def process_chat_payload(request, form_data, metadata, user, model):
 
             if "document" in source:
                 for doc_idx, doc_context in enumerate(source["document"]):
-                    metadata = source.get("metadata")
+                    doc_metadata = source.get("metadata")
                     doc_source_id = None
 
-                    if metadata:
-                        doc_source_id = metadata[doc_idx].get("source", source_id)
+                    if doc_metadata:
+                        doc_source_id = doc_metadata[doc_idx].get("source", source_id)
 
                     if source_id:
                         context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
@@ -894,12 +908,15 @@ async def process_chat_payload(request, form_data, metadata, user, model):
             }
         )
 
-    return form_data, events
+    return form_data, metadata, events
 
 
 async def process_chat_response(
     request, response, form_data, user, events, metadata, tasks
 ):
+
+    print("metadata", metadata)
+
     async def background_tasks_handler():
         message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
         message = message_map.get(metadata["message_id"]) if message_map else None

+ 1 - 1
src/lib/apis/openai/index.ts

@@ -322,7 +322,7 @@ export const generateOpenAIChatCompletion = async (
 			return res.json();
 		})
 		.catch((err) => {
-			error = `${err?.detail ?? 'Network Problem'}`;
+			error = `${err?.detail ?? err}`;
 			return null;
 		});