Timothy J. Baek 8 maanden geleden
bovenliggende
commit
c4946d42e0
1 gewijzigde bestanden met toevoegingen van 31 en 27 verwijderingen
  1. 31 27
      backend/main.py

+ 31 - 27
backend/main.py

@@ -283,21 +283,6 @@ def get_filter_function_ids(model):
     return filter_ids
 
 
-async def get_content_from_response(response) -> Optional[str]:
-    content = None
-    if hasattr(response, "body_iterator"):
-        async for chunk in response.body_iterator:
-            data = json.loads(chunk.decode("utf-8"))
-            content = data["choices"][0]["message"]["content"]
-
-        # Cleanup any remaining background tasks if necessary
-        if response.background is not None:
-            await response.background()
-    else:
-        content = response["choices"][0]["message"]["content"]
-    return content
-
-
 def get_tool_call_payload(messages, task_model_id, content):
     user_message = get_last_user_message(messages)
     history = "\n".join(
@@ -403,8 +388,8 @@ def get_tool_with_custom_params(
 
 
 # Mutation on extra_params
-def get_configured_tools(
-    tool_ids: list[str], extra_params: dict, user: UserModel
+def get_tools(
+    tool_ids: list[str], user: UserModel, extra_params: dict
 ) -> dict[str, dict]:
     tools = {}
     for tool_id in tool_ids:
@@ -420,6 +405,7 @@ def get_configured_tools(
         extra_params["__id__"] = tool_id
         has_citation = hasattr(module, "citation") and module.citation
         handles_files = hasattr(module, "file_handler") and module.file_handler
+
         if hasattr(module, "valves") and hasattr(module, "Valves"):
             valves = Tools.get_tool_valves_by_id(tool_id) or {}
             module.valves = module.Valves(**valves)
@@ -459,35 +445,51 @@ def get_configured_tools(
     return tools
 
 
+async def get_content_from_response(response) -> Optional[str]:
+    content = None
+    if hasattr(response, "body_iterator"):
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
+
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
+    else:
+        content = response["choices"][0]["message"]["content"]
+    return content
+
+
 async def chat_completion_tools_handler(
     body: dict, user: UserModel, extra_params: dict
 ) -> tuple[dict, dict]:
     skip_files = False
     contexts = []
     citations = []
-    task_model_id = get_task_model_id(body["model"])
 
+    task_model_id = get_task_model_id(body["model"])
     # If tool_ids field is present, call the functions
     tool_ids = body.pop("tool_ids", None)
     if not tool_ids:
         return body, {}
 
     log.debug(f"{tool_ids=}")
+
     custom_params = {
         **extra_params,
         "__model__": app.state.MODELS[task_model_id],
         "__messages__": body["messages"],
         "__files__": body.get("files", []),
     }
-    configured_tools = get_configured_tools(tool_ids, custom_params, user)
+    tools = get_tools(tool_ids, user, custom_params)
+    log.info(f"{tools=}")
 
-    log.info(f"{configured_tools=}")
-
-    specs = [tool["spec"] for tool in configured_tools.values()]
+    specs = [tool["spec"] for tool in tools.values()]
     tools_specs = json.dumps(specs)
     template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
     content = tool_calling_generation_template(template, tools_specs)
     payload = get_tool_call_payload(body["messages"], task_model_id, content)
+
     try:
         payload = filter_pipeline(payload, user)
     except Exception as e:
@@ -503,16 +505,18 @@ async def chat_completion_tools_handler(
 
         result = json.loads(content)
         tool_name = result.get("name", None)
-        if tool_name not in configured_tools:
+        if tool_name not in tools:
             return body, {}
 
         tool_params = result.get("parameters", {})
-        toolkit_id = configured_tools[tool_name]["toolkit_id"]
+        toolkit_id = tools[tool_name]["toolkit_id"]
+
         try:
-            tool_output = await configured_tools[tool_name]["callable"](**tool_params)
+            tool_output = await tools[tool_name]["callable"](**tool_params)
         except Exception as e:
             tool_output = str(e)
-        if configured_tools[tool_name]["citation"]:
+
+        if tools[tool_name]["citation"]:
             citations.append(
                 {
                     "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
@@ -520,7 +524,7 @@ async def chat_completion_tools_handler(
                     "metadata": [{"source": tool_name}],
                 }
             )
-        if configured_tools[tool_name]["file_handler"]:
+        if tools[tool_name]["file_handler"]:
             skip_files = True
 
         if isinstance(tool_output, str):