Browse Source

chore: format

Timothy J. Baek 6 months ago
parent
commit
9936583477

+ 1 - 1
backend/open_webui/apps/ollama/main.py

@@ -547,7 +547,7 @@ class GenerateEmbeddingsForm(BaseModel):
 
 class GenerateEmbedForm(BaseModel):
     model: str
-    input: list[str]|str
+    input: list[str] | str
     truncate: Optional[bool] = None
     options: Optional[dict] = None
     keep_alive: Optional[Union[int, str]] = None

+ 4 - 6
backend/open_webui/apps/retrieval/vector/dbs/chroma.py

@@ -110,9 +110,8 @@ class ChromaClient:
     def insert(self, collection_name: str, items: list[VectorItem]):
         # Insert the items into the collection, if the collection does not exist, it will be created.
         collection = self.client.get_or_create_collection(
-            name=collection_name,
-            metadata={"hnsw:space": "cosine"}
-            )
+            name=collection_name, metadata={"hnsw:space": "cosine"}
+        )
 
         ids = [item["id"] for item in items]
         documents = [item["text"] for item in items]
@@ -131,9 +130,8 @@ class ChromaClient:
     def upsert(self, collection_name: str, items: list[VectorItem]):
         # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
         collection = self.client.get_or_create_collection(
-            name=collection_name,
-            metadata={"hnsw:space": "cosine"}
-            )
+            name=collection_name, metadata={"hnsw:space": "cosine"}
+        )
 
         ids = [item["id"] for item in items]
         documents = [item["text"] for item in items]

+ 24 - 21
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py

@@ -9,6 +9,7 @@ from open_webui.config import QDRANT_URI
 
 NO_LIMIT = 999999999
 
+
 class QdrantClient:
     def __init__(self):
         self.collection_prefix = "open-webui"
@@ -38,15 +39,15 @@ class QdrantClient:
         collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
         self.client.create_collection(
             collection_name=collection_name_with_prefix,
-            vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
+            vectors_config=models.VectorParams(
+                size=dimension, distance=models.Distance.COSINE
+            ),
         )
 
         print(f"collection {collection_name_with_prefix} successfully created!")
 
     def _create_collection_if_not_exists(self, collection_name, dimension):
-        if not self.has_collection(
-                collection_name=collection_name
-        ):
+        if not self.has_collection(collection_name=collection_name):
             self._create_collection(
                 collection_name=collection_name, dimension=dimension
             )
@@ -56,22 +57,23 @@ class QdrantClient:
             PointStruct(
                 id=item["id"],
                 vector=item["vector"],
-                payload={
-                    "text": item["text"],
-                    "metadata": item["metadata"]
-                },
+                payload={"text": item["text"], "metadata": item["metadata"]},
             )
             for item in items
         ]
 
     def has_collection(self, collection_name: str) -> bool:
-        return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}")
+        return self.client.collection_exists(
+            f"{self.collection_prefix}_{collection_name}"
+        )
 
     def delete_collection(self, collection_name: str):
-        return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}")
+        return self.client.delete_collection(
+            collection_name=f"{self.collection_prefix}_{collection_name}"
+        )
 
     def search(
-            self, collection_name: str, vectors: list[list[float | int]], limit: int
+        self, collection_name: str, vectors: list[list[float | int]], limit: int
     ) -> Optional[SearchResult]:
         # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
         if limit is None:
@@ -87,7 +89,7 @@ class QdrantClient:
             ids=get_result.ids,
             documents=get_result.documents,
             metadatas=get_result.metadatas,
-            distances=[[point.score for point in query_response.points]]
+            distances=[[point.score for point in query_response.points]],
         )
 
     def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
@@ -101,7 +103,10 @@ class QdrantClient:
             field_conditions = []
             for key, value in filter.items():
                 field_conditions.append(
-                    models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value)))
+                    models.FieldCondition(
+                        key=f"metadata.{key}", match=models.MatchValue(value=value)
+                    )
+                )
 
             points = self.client.query_points(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
@@ -117,7 +122,7 @@ class QdrantClient:
         # Get all the items in the collection.
         points = self.client.query_points(
             collection_name=f"{self.collection_prefix}_{collection_name}",
-            limit=NO_LIMIT  # otherwise qdrant would set limit to 10!
+            limit=NO_LIMIT,  # otherwise qdrant would set limit to 10!
         )
         return self._result_to_get_result(points.points)
 
@@ -134,10 +139,10 @@ class QdrantClient:
         return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
 
     def delete(
-            self,
-            collection_name: str,
-            ids: Optional[list[str]] = None,
-            filter: Optional[dict] = None,
+        self,
+        collection_name: str,
+        ids: Optional[list[str]] = None,
+        filter: Optional[dict] = None,
     ):
         # Delete the items from the collection based on the ids.
         field_conditions = []
@@ -162,9 +167,7 @@ class QdrantClient:
         return self.client.delete(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             points_selector=models.FilterSelector(
-                filter=models.Filter(
-                    must=field_conditions
-                )
+                filter=models.Filter(must=field_conditions)
             ),
         )
 

+ 5 - 1
backend/open_webui/config.py

@@ -409,7 +409,10 @@ OAUTH_ROLES_CLAIM = PersistentConfig(
 OAUTH_ALLOWED_ROLES = PersistentConfig(
     "OAUTH_ALLOWED_ROLES",
     "oauth.allowed_roles",
-    [role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",")],
+    [
+        role.strip()
+        for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",")
+    ],
 )
 
 OAUTH_ADMIN_ROLES = PersistentConfig(
@@ -418,6 +421,7 @@ OAUTH_ADMIN_ROLES = PersistentConfig(
     [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
 )
 
+
 def load_oauth_providers():
     OAUTH_PROVIDERS.clear()
     if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:

+ 33 - 35
backend/open_webui/main.py

@@ -208,8 +208,6 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
 app.state.MODELS = {}
 
 
-
-
 ##################################
 #
 # ChatCompletion Middleware
@@ -223,14 +221,14 @@ def get_task_model_id(default_model_id):
     # Check if the user has a custom task model and use that model
     if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
         if (
-                app.state.config.TASK_MODEL
-                and app.state.config.TASK_MODEL in app.state.MODELS
+            app.state.config.TASK_MODEL
+            and app.state.config.TASK_MODEL in app.state.MODELS
         ):
             task_model_id = app.state.config.TASK_MODEL
     else:
         if (
-                app.state.config.TASK_MODEL_EXTERNAL
-                and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
+            app.state.config.TASK_MODEL_EXTERNAL
+            and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
         ):
             task_model_id = app.state.config.TASK_MODEL_EXTERNAL
 
@@ -367,7 +365,7 @@ async def get_content_from_response(response) -> Optional[str]:
 
 
 async def chat_completion_tools_handler(
-        body: dict, user: UserModel, extra_params: dict
+    body: dict, user: UserModel, extra_params: dict
 ) -> tuple[dict, dict]:
     # If tool_ids field is present, call the functions
     metadata = body.get("metadata", {})
@@ -681,15 +679,15 @@ def get_sorted_filters(model_id):
         model
         for model in app.state.MODELS.values()
         if "pipeline" in model
-           and "type" in model["pipeline"]
-           and model["pipeline"]["type"] == "filter"
-           and (
-                   model["pipeline"]["pipelines"] == ["*"]
-                   or any(
-               model_id == target_model_id
-               for target_model_id in model["pipeline"]["pipelines"]
-           )
-           )
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
     ]
     sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
     return sorted_filters
@@ -875,8 +873,8 @@ async def update_embedding_function(request: Request, call_next):
 @app.middleware("http")
 async def inspect_websocket(request: Request, call_next):
     if (
-            "/ws/socket.io" in request.url.path
-            and request.query_params.get("transport") == "websocket"
+        "/ws/socket.io" in request.url.path
+        and request.query_params.get("transport") == "websocket"
     ):
         upgrade = (request.headers.get("Upgrade") or "").lower()
         connection = (request.headers.get("Connection") or "").lower().split(",")
@@ -945,8 +943,8 @@ async def get_all_models():
         if custom_model.base_model_id is None:
             for model in models:
                 if (
-                        custom_model.id == model["id"]
-                        or custom_model.id == model["id"].split(":")[0]
+                    custom_model.id == model["id"]
+                    or custom_model.id == model["id"].split(":")[0]
                 ):
                     model["name"] = custom_model.name
                     model["info"] = custom_model.model_dump()
@@ -963,8 +961,8 @@ async def get_all_models():
 
             for model in models:
                 if (
-                        custom_model.base_model_id == model["id"]
-                        or custom_model.base_model_id == model["id"].split(":")[0]
+                    custom_model.base_model_id == model["id"]
+                    or custom_model.base_model_id == model["id"].split(":")[0]
                 ):
                     owned_by = model["owned_by"]
                     if "pipe" in model:
@@ -1840,7 +1838,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
 
 @app.post("/api/pipelines/upload")
 async def upload_pipeline(
-        urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
+    urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
 ):
     print("upload_pipeline", urlIdx, file.filename)
     # Check if the uploaded file is a python file
@@ -2017,9 +2015,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
 
 @app.get("/api/pipelines/{pipeline_id}/valves")
 async def get_pipeline_valves(
-        urlIdx: Optional[int],
-        pipeline_id: str,
-        user=Depends(get_admin_user),
+    urlIdx: Optional[int],
+    pipeline_id: str,
+    user=Depends(get_admin_user),
 ):
     r = None
     try:
@@ -2055,9 +2053,9 @@ async def get_pipeline_valves(
 
 @app.get("/api/pipelines/{pipeline_id}/valves/spec")
 async def get_pipeline_valves_spec(
-        urlIdx: Optional[int],
-        pipeline_id: str,
-        user=Depends(get_admin_user),
+    urlIdx: Optional[int],
+    pipeline_id: str,
+    user=Depends(get_admin_user),
 ):
     r = None
     try:
@@ -2092,10 +2090,10 @@ async def get_pipeline_valves_spec(
 
 @app.post("/api/pipelines/{pipeline_id}/valves/update")
 async def update_pipeline_valves(
-        urlIdx: Optional[int],
-        pipeline_id: str,
-        form_data: dict,
-        user=Depends(get_admin_user),
+    urlIdx: Optional[int],
+    pipeline_id: str,
+    form_data: dict,
+    user=Depends(get_admin_user),
 ):
     r = None
     try:
@@ -2219,7 +2217,7 @@ class ModelFilterConfigForm(BaseModel):
 
 @app.post("/api/config/model/filter")
 async def update_model_filter_config(
-        form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
+    form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
 ):
     app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
     app.state.config.MODEL_FILTER_LIST = form_data.models
@@ -2274,7 +2272,7 @@ async def get_app_latest_release_version():
         timeout = aiohttp.ClientTimeout(total=1)
         async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
             async with session.get(
-                    "https://api.github.com/repos/open-webui/open-webui/releases/latest"
+                "https://api.github.com/repos/open-webui/open-webui/releases/latest"
             ) as response:
                 response.raise_for_status()
                 data = await response.json()

+ 18 - 8
backend/open_webui/utils/oauth.py

@@ -25,7 +25,10 @@ from open_webui.config import (
     OAUTH_PICTURE_CLAIM,
     OAUTH_USERNAME_CLAIM,
     OAUTH_ALLOWED_ROLES,
-    OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig,
+    OAUTH_ADMIN_ROLES,
+    WEBHOOK_URL,
+    JWT_EXPIRES_IN,
+    AppConfig,
 )
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE
@@ -170,7 +173,9 @@ class OAuthManager:
             # If the user does not exist, check if signups are enabled
             if auth_manager_config.ENABLE_OAUTH_SIGNUP.value:
                 # Check if an existing user with the same email already exists
-                existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
+                existing_user = Users.get_user_by_email(
+                    user_data.get("email", "").lower()
+                )
                 if existing_user:
                     raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
@@ -182,16 +187,18 @@ class OAuthManager:
                         async with aiohttp.ClientSession() as session:
                             async with session.get(picture_url) as resp:
                                 picture = await resp.read()
-                                base64_encoded_picture = base64.b64encode(picture).decode(
-                                    "utf-8"
-                                )
+                                base64_encoded_picture = base64.b64encode(
+                                    picture
+                                ).decode("utf-8")
                                 guessed_mime_type = mimetypes.guess_type(picture_url)[0]
                                 if guessed_mime_type is None:
                                     # assume JPG, browsers are tolerant enough of image formats
                                     guessed_mime_type = "image/jpeg"
                                 picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
                     except Exception as e:
-                        log.error(f"Error downloading profile image '{picture_url}': {e}")
+                        log.error(
+                            f"Error downloading profile image '{picture_url}': {e}"
+                        )
                         picture_url = ""
                 if not picture_url:
                     picture_url = "/user.png"
@@ -216,7 +223,9 @@ class OAuthManager:
                         auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                         {
                             "action": "signup",
-                            "message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+                            "message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(
+                                user.name
+                            ),
                             "user": user.model_dump_json(exclude_none=True),
                         },
                     )
@@ -243,4 +252,5 @@ class OAuthManager:
         redirect_url = f"{request.base_url}auth#token={jwt_token}"
         return RedirectResponse(url=redirect_url)
 
-oauth_manager = OAuthManager()
+
+oauth_manager = OAuthManager()