소스 검색

add __tools__ custom param

Michael Poluektov 8 달 전
부모
커밋
13c03bfd7d
2개의 변경된 파일22개의 추가작업 그리고 16개의 파일을 삭제
  1. 19 11
      backend/apps/webui/main.py
  2. 3 5
      backend/main.py

+ 19 - 11
backend/apps/webui/main.py

@@ -26,6 +26,7 @@ from utils.misc import (
     apply_model_system_prompt_to_body,
 )
 
+from utils.tools import get_tools
 
 from config import (
     SHOW_ADMIN_DETAILS,
@@ -47,6 +48,7 @@ from config import (
     OAUTH_USERNAME_CLAIM,
     OAUTH_PICTURE_CLAIM,
     OAUTH_EMAIL_CLAIM,
+    ENABLE_TOOLS_FILTER,
 )
 
 from apps.socket.main import get_event_call, get_event_emitter
@@ -271,7 +273,7 @@ def get_function_params(function_module, form_data, user, extra_params={}):
     return params
 
 
-async def generate_function_chat_completion(form_data, user):
+async def generate_function_chat_completion(form_data, user, files, tool_ids):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
     metadata = form_data.pop("metadata", None)
@@ -286,6 +288,21 @@ async def generate_function_chat_completion(form_data, user):
             __event_call__ = get_event_call(metadata)
         __task__ = metadata.get("task", None)
 
+    extra_params = {
+        "__event_emitter__": __event_emitter__,
+        "__event_call__": __event_call__,
+        "__task__": __task__,
+    }
+    if not ENABLE_TOOLS_FILTER:
+        tools_params = {
+            **extra_params,
+            "__model__": app.state.MODELS[form_data["model"]],
+            "__messages__": form_data["messages"],
+            "__files__": files,
+        }
+        configured_tools = get_tools(app, tool_ids, user, tools_params)
+
+        extra_params["__tools__"] = configured_tools
     if model_info:
         if model_info.base_model_id:
             form_data["model"] = model_info.base_model_id
@@ -298,16 +315,7 @@ async def generate_function_chat_completion(form_data, user):
     function_module = get_function_module(pipe_id)
 
     pipe = function_module.pipe
-    params = get_function_params(
-        function_module,
-        form_data,
-        user,
-        {
-            "__event_emitter__": __event_emitter__,
-            "__event_call__": __event_call__,
-            "__task__": __task__,
-        },
-    )
+    params = get_function_params(function_module, form_data, user, extra_params)
 
     if form_data["stream"]:
 

+ 3 - 5
backend/main.py

@@ -994,13 +994,11 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             detail="Model not found",
         )
     model = app.state.MODELS[model_id]
+    files = form_data.pop("files", None)
+    tool_ids = form_data.pop("tool_ids", None)
 
     if model.get("pipe"):
-        return await generate_function_chat_completion(form_data, user=user)
-
-    for key in ["tool_ids", "files"]:
-        if key in form_data:
-            del form_data[key]
+        return await generate_function_chat_completion(form_data, user, files, tool_ids)
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(form_data, user=user)
     else: