Browse Source

typing and tweaks

Michael Poluektov 8 months ago
parent
commit
d598d4bb93
2 changed files with 28 additions and 26 deletions
  1. 3 10
      backend/apps/webui/models/users.py
  2. 25 16
      backend/main.py

+ 3 - 10
backend/apps/webui/models/users.py

@@ -1,12 +1,10 @@
-from pydantic import BaseModel, ConfigDict, parse_obj_as
-from typing import Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import Optional
 import time
 
 from sqlalchemy import String, Column, BigInteger, Text
 
-from utils.misc import get_gravatar_url
-
-from apps.webui.internal.db import Base, JSONField, Session, get_db
+from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.chats import Chats
 
 ####################
@@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel):
 
 
 class UsersTable:
-
     def insert_new_user(
         self,
         id: str,
@@ -122,7 +119,6 @@ class UsersTable:
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(api_key=api_key).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -131,7 +127,6 @@ class UsersTable:
     def get_user_by_email(self, email: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(email=email).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -140,7 +135,6 @@ class UsersTable:
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(oauth_sub=sub).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -195,7 +189,6 @@ class UsersTable:
     def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 db.query(User).filter_by(id=id).update(
                     {"last_active_at": int(time.time())}
                 )

+ 25 - 16
backend/main.py

@@ -57,7 +57,7 @@ from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
 from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
-from apps.webui.models.users import Users
+from apps.webui.models.users import Users, User
 
 from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
 
@@ -322,7 +322,7 @@ async def call_tool_from_completion(
         return None
 
 
-def get_tool_calling_payload(messages, task_model_id, content):
+def get_tool_call_payload(messages, task_model_id, content):
     user_message = get_last_user_message(messages)
     history = "\n".join(
         f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
@@ -345,13 +345,19 @@ def get_tool_calling_payload(messages, task_model_id, content):
 async def get_tool_call_response(
     messages, files, tool_id, template, task_model_id, user, extra_params
 ) -> tuple[Optional[str], Optional[dict], bool]:
+    """
+    return: tuple of (function_result, citation, file_handler) where
+    - function_result: Optional[str] is the result of the tool call if successful
+    - citation: Optional[dict] is the citation object if the tool has citation
+    - file_handler: bool, True if tool handles files
+    """
     tool = Tools.get_tool_by_id(tool_id)
     if tool is None:
         return None, None, False
 
     tools_specs = json.dumps(tool.specs, indent=2)
     content = tool_calling_generation_template(template, tools_specs)
-    payload = get_tool_calling_payload(messages, task_model_id, content)
+    payload = get_tool_call_payload(messages, task_model_id, content)
 
     try:
         payload = filter_pipeline(payload, user)
@@ -486,7 +492,9 @@ async def chat_completion_inlets_handler(body, model, extra_params):
     return body, {}
 
 
-async def chat_completion_tools_handler(body, user, extra_params):
+async def chat_completion_tools_handler(
+    body: dict, user: User, extra_params: dict
+) -> tuple[dict, dict]:
     skip_files = None
 
     contexts = []
@@ -498,21 +506,22 @@ async def chat_completion_tools_handler(body, user, extra_params):
     if "tool_ids" not in body:
         return body, {}
 
-    print(body["tool_ids"])
+    log.debug(f"tool_ids: {body['tool_ids']}")
+    kwargs = {
+        "messages": body["messages"],
+        "files": body.get("files", []),
+        "template": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+        "task_model_id": task_model_id,
+        "user": user,
+        "extra_params": extra_params,
+    }
     for tool_id in body["tool_ids"]:
-        print(tool_id)
+        log.debug(f"{tool_id=}")
         try:
             response, citation, file_handler = await get_tool_call_response(
-                messages=body["messages"],
-                files=body.get("files", []),
-                tool_id=tool_id,
-                template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                task_model_id=task_model_id,
-                user=user,
-                extra_params=extra_params,
+                tool_id=tool_id, **kwargs
             )
 
-            print(file_handler)
             if isinstance(response, str):
                 contexts.append(response)
 
@@ -526,10 +535,10 @@ async def chat_completion_tools_handler(body, user, extra_params):
                 skip_files = True
 
         except Exception as e:
-            print(f"Error: {e}")
+            log.exception(f"Error: {e}")
 
     del body["tool_ids"]
-    print(f"tool_contexts: {contexts}")
+    log.debug(f"tool_contexts: {contexts}")
 
     if skip_files:
         if "files" in body: