Browse Source

sort and fix backend imports

Pascal Lim 8 months ago
parent
commit
c386d0b1a5
63 changed files with 548 additions and 973 deletions
  1. 19 32
      backend/apps/audio/main.py
  2. 23 35
      backend/apps/images/main.py
  3. 7 9
      backend/apps/images/utils/comfyui.py
  4. 19 33
      backend/apps/ollama/main.py
  5. 21 29
      backend/apps/openai/main.py
  6. 87 116
      backend/apps/rag/main.py
  7. 2 2
      backend/apps/rag/search/brave.py
  8. 2 1
      backend/apps/rag/search/duckduckgo.py
  9. 2 3
      backend/apps/rag/search/google_pse.py
  10. 3 3
      backend/apps/rag/search/jina_search.py
  11. 1 0
      backend/apps/rag/search/main.py
  12. 2 3
      backend/apps/rag/search/searxng.py
  13. 2 2
      backend/apps/rag/search/serper.py
  14. 2 3
      backend/apps/rag/search/serply.py
  15. 2 3
      backend/apps/rag/search/serpstack.py
  16. 1 2
      backend/apps/rag/search/tavily.py
  17. 14 24
      backend/apps/rag/utils.py
  18. 1 2
      backend/apps/socket/main.py
  19. 9 14
      backend/apps/webui/internal/db.py
  20. 5 5
      backend/apps/webui/internal/wrappers.py
  21. 35 41
      backend/apps/webui/main.py
  22. 6 12
      backend/apps/webui/models/auths.py
  23. 6 30
      backend/apps/webui/models/chats.py
  24. 5 15
      backend/apps/webui/models/documents.py
  25. 5 17
      backend/apps/webui/models/files.py
  26. 5 23
      backend/apps/webui/models/functions.py
  27. 5 16
      backend/apps/webui/models/memories.py
  28. 4 7
      backend/apps/webui/models/models.py
  29. 4 13
      backend/apps/webui/models/prompts.py
  30. 7 17
      backend/apps/webui/models/tags.py
  31. 4 15
      backend/apps/webui/models/tools.py
  32. 6 7
      backend/apps/webui/models/users.py
  33. 16 28
      backend/apps/webui/routers/auths.py
  34. 8 31
      backend/apps/webui/routers/chats.py
  35. 3 21
      backend/apps/webui/routers/configs.py
  36. 6 11
      backend/apps/webui/routers/documents.py
  37. 11 36
      backend/apps/webui/routers/files.py
  38. 7 17
      backend/apps/webui/routers/functions.py
  39. 5 11
      backend/apps/webui/routers/memories.py
  40. 4 10
      backend/apps/webui/routers/models.py
  41. 4 10
      backend/apps/webui/routers/prompts.py
  42. 7 13
      backend/apps/webui/routers/tools.py
  43. 9 24
      backend/apps/webui/routers/users.py
  44. 10 17
      backend/apps/webui/routers/utils.py
  45. 4 5
      backend/apps/webui/utils.py
  46. 16 44
      backend/config.py
  47. 6 12
      backend/env.py
  48. 92 98
      backend/main.py
  49. 1 16
      backend/migrations/env.py
  50. 3 3
      backend/migrations/versions/7e5b5dc7342b_init.py
  51. 1 3
      backend/migrations/versions/ca81bd47c050_add_config_table.py
  52. 1 3
      backend/test/apps/webui/routers/test_auths.py
  53. 1 3
      backend/test/apps/webui/routers/test_chats.py
  54. 0 1
      backend/test/apps/webui/routers/test_documents.py
  55. 0 1
      backend/test/apps/webui/routers/test_models.py
  56. 0 1
      backend/test/apps/webui/routers/test_prompts.py
  57. 0 1
      backend/test/apps/webui/routers/test_users.py
  58. 4 4
      backend/utils/misc.py
  59. 1 1
      backend/utils/schemas.py
  60. 1 2
      backend/utils/task.py
  61. 0 1
      backend/utils/tools.py
  62. 8 9
      backend/utils/utils.py
  63. 3 2
      backend/utils/webhook.py

+ 19 - 32
backend/apps/audio/main.py

@@ -7,46 +7,33 @@ from functools import lru_cache
 from pathlib import Path
 
 import requests
-from fastapi import (
-    FastAPI,
-    Request,
-    Depends,
-    HTTPException,
-    status,
-    UploadFile,
-    File,
-)
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse
-from pydantic import BaseModel
-
 from config import (
-    SRC_LOG_LEVELS,
-    CACHE_DIR,
-    WHISPER_MODEL,
-    WHISPER_MODEL_DIR,
-    WHISPER_MODEL_AUTO_UPDATE,
-    DEVICE_TYPE,
+    AUDIO_STT_ENGINE,
+    AUDIO_STT_MODEL,
     AUDIO_STT_OPENAI_API_BASE_URL,
     AUDIO_STT_OPENAI_API_KEY,
-    AUDIO_TTS_OPENAI_API_BASE_URL,
-    AUDIO_TTS_OPENAI_API_KEY,
     AUDIO_TTS_API_KEY,
-    AUDIO_STT_ENGINE,
-    AUDIO_STT_MODEL,
     AUDIO_TTS_ENGINE,
     AUDIO_TTS_MODEL,
-    AUDIO_TTS_VOICE,
+    AUDIO_TTS_OPENAI_API_BASE_URL,
+    AUDIO_TTS_OPENAI_API_KEY,
     AUDIO_TTS_SPLIT_ON,
-    AppConfig,
+    AUDIO_TTS_VOICE,
+    CACHE_DIR,
     CORS_ALLOW_ORIGIN,
+    DEVICE_TYPE,
+    WHISPER_MODEL,
+    WHISPER_MODEL_AUTO_UPDATE,
+    WHISPER_MODEL_DIR,
+    AppConfig,
 )
 from constants import ERROR_MESSAGES
-from utils.utils import (
-    get_current_user,
-    get_verified_user,
-    get_admin_user,
-)
+from env import SRC_LOG_LEVELS
+from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_current_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])
@@ -211,7 +198,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             body = json.loads(body)
             body["model"] = app.state.config.TTS_MODEL
             body = json.dumps(body).encode("utf-8")
-        except Exception as e:
+        except Exception:
             pass
 
         r = None
@@ -488,7 +475,7 @@ def get_available_voices() -> dict:
     elif app.state.config.TTS_ENGINE == "elevenlabs":
         try:
             ret = get_elevenlabs_voices()
-        except Exception as e:
+        except Exception:
             # Avoided @lru_cache with exception
             pass
 

+ 23 - 35
backend/apps/images/main.py

@@ -1,52 +1,42 @@
-from fastapi import (
-    FastAPI,
-    Request,
-    Depends,
-    HTTPException,
-)
-from fastapi.middleware.cors import CORSMiddleware
-from typing import Optional
-from pydantic import BaseModel
-from pathlib import Path
-import mimetypes
-import uuid
+import asyncio
 import base64
 import json
 import logging
+import mimetypes
 import re
-import requests
-import asyncio
-
-from utils.utils import (
-    get_verified_user,
-    get_admin_user,
-)
+import uuid
+from pathlib import Path
+from typing import Optional
 
+import requests
 from apps.images.utils.comfyui import (
-    ComfyUIWorkflow,
     ComfyUIGenerateImageForm,
+    ComfyUIWorkflow,
     comfyui_generate_image,
 )
-
-from constants import ERROR_MESSAGES
 from config import (
-    SRC_LOG_LEVELS,
-    CACHE_DIR,
-    IMAGE_GENERATION_ENGINE,
-    ENABLE_IMAGE_GENERATION,
-    AUTOMATIC1111_BASE_URL,
     AUTOMATIC1111_API_AUTH,
+    AUTOMATIC1111_BASE_URL,
+    CACHE_DIR,
     COMFYUI_BASE_URL,
     COMFYUI_WORKFLOW,
     COMFYUI_WORKFLOW_NODES,
-    IMAGES_OPENAI_API_BASE_URL,
-    IMAGES_OPENAI_API_KEY,
+    CORS_ALLOW_ORIGIN,
+    ENABLE_IMAGE_GENERATION,
+    IMAGE_GENERATION_ENGINE,
     IMAGE_GENERATION_MODEL,
     IMAGE_SIZE,
     IMAGE_STEPS,
-    CORS_ALLOW_ORIGIN,
+    IMAGES_OPENAI_API_BASE_URL,
+    IMAGES_OPENAI_API_KEY,
     AppConfig,
 )
+from constants import ERROR_MESSAGES
+from env import SRC_LOG_LEVELS
+from fastapi import Depends, FastAPI, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -186,7 +176,7 @@ async def verify_url(user=Depends(get_admin_user)):
             )
             r.raise_for_status()
             return True
-        except Exception as e:
+        except Exception:
             app.state.config.ENABLED = False
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
     elif app.state.config.ENGINE == "comfyui":
@@ -194,7 +184,7 @@ async def verify_url(user=Depends(get_admin_user)):
             r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
             r.raise_for_status()
             return True
-        except Exception as e:
+        except Exception:
             app.state.config.ENABLED = False
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
     else:
@@ -397,7 +387,6 @@ def save_url_image(url):
         r = requests.get(url)
         r.raise_for_status()
         if r.headers["content-type"].split("/")[0] == "image":
-
             mime_type = r.headers["content-type"]
             image_format = mimetypes.guess_extension(mime_type)
 
@@ -412,7 +401,7 @@ def save_url_image(url):
                     image_file.write(chunk)
             return image_filename
         else:
-            log.error(f"Url does not point to an image.")
+            log.error("Url does not point to an image.")
             return None
 
     except Exception as e:
@@ -430,7 +419,6 @@ async def image_generations(
     r = None
     try:
         if app.state.config.ENGINE == "openai":
-
             headers = {}
             headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
             headers["Content-Type"] = "application/json"

+ 7 - 9
backend/apps/images/utils/comfyui.py

@@ -1,20 +1,18 @@
 import asyncio
-import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
 import json
-import urllib.request
-import urllib.parse
-import random
 import logging
+import random
+import urllib.parse
+import urllib.request
+from typing import Optional
 
-from config import SRC_LOG_LEVELS
+import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
+from env import SRC_LOG_LEVELS
+from pydantic import BaseModel
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
 
-from pydantic import BaseModel
-
-from typing import Optional
-
 default_headers = {"User-Agent": "Mozilla/5.0"}
 
 

+ 19 - 33
backend/apps/ollama/main.py

@@ -1,54 +1,40 @@
-from fastapi import (
-    FastAPI,
-    Request,
-    HTTPException,
-    Depends,
-    UploadFile,
-    File,
-)
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
-
-from pydantic import BaseModel, ConfigDict
-
-import os
-import re
-import random
-import requests
-import json
-import aiohttp
 import asyncio
+import json
 import logging
+import os
+import random
+import re
 import time
-from urllib.parse import urlparse
 from typing import Optional, Union
+from urllib.parse import urlparse
 
-from starlette.background import BackgroundTask
-
+import aiohttp
+import requests
 from apps.webui.models.models import Models
-from constants import ERROR_MESSAGES
-from utils.utils import (
-    get_verified_user,
-    get_admin_user,
-)
-
 from config import (
-    SRC_LOG_LEVELS,
-    OLLAMA_BASE_URLS,
-    ENABLE_OLLAMA_API,
     AIOHTTP_CLIENT_TIMEOUT,
+    CORS_ALLOW_ORIGIN,
     ENABLE_MODEL_FILTER,
+    ENABLE_OLLAMA_API,
     MODEL_FILTER_LIST,
+    OLLAMA_BASE_URLS,
     UPLOAD_DIR,
     AppConfig,
-    CORS_ALLOW_ORIGIN,
 )
+from constants import ERROR_MESSAGES
+from env import SRC_LOG_LEVELS
+from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel, ConfigDict
+from starlette.background import BackgroundTask
 from utils.misc import (
-    calculate_sha256,
     apply_model_params_to_body_ollama,
     apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
+    calculate_sha256,
 )
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])

+ 21 - 29
backend/apps/openai/main.py

@@ -1,44 +1,36 @@
-from fastapi import FastAPI, Request, HTTPException, Depends
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, FileResponse
-
-import requests
-import aiohttp
 import asyncio
+import hashlib
 import json
 import logging
+from pathlib import Path
+from typing import Literal, Optional, overload
 
-from pydantic import BaseModel
-from starlette.background import BackgroundTask
-
+import aiohttp
+import requests
 from apps.webui.models.models import Models
-from constants import ERROR_MESSAGES
-from utils.utils import (
-    get_verified_user,
-    get_admin_user,
-)
-from utils.misc import (
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
-)
-
 from config import (
-    SRC_LOG_LEVELS,
-    ENABLE_OPENAI_API,
     AIOHTTP_CLIENT_TIMEOUT,
-    OPENAI_API_BASE_URLS,
-    OPENAI_API_KEYS,
     CACHE_DIR,
+    CORS_ALLOW_ORIGIN,
     ENABLE_MODEL_FILTER,
+    ENABLE_OPENAI_API,
     MODEL_FILTER_LIST,
+    OPENAI_API_BASE_URLS,
+    OPENAI_API_KEYS,
     AppConfig,
-    CORS_ALLOW_ORIGIN,
 )
-from typing import Optional, Literal, overload
-
-
-import hashlib
-from pathlib import Path
+from constants import ERROR_MESSAGES
+from env import SRC_LOG_LEVELS
+from fastapi import Depends, FastAPI, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse, StreamingResponse
+from pydantic import BaseModel
+from starlette.background import BackgroundTask
+from utils.misc import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OPENAI"])

+ 87 - 116
backend/apps/rag/main.py

@@ -1,143 +1,118 @@
-from fastapi import (
-    FastAPI,
-    Depends,
-    HTTPException,
-    status,
-    UploadFile,
-    File,
-    Form,
-)
-from fastapi.middleware.cors import CORSMiddleware
-import requests
-import os, shutil, logging, re
+import json
+import logging
+import mimetypes
+import os
+import shutil
+import socket
+import urllib.parse
+import uuid
 from datetime import datetime
-
 from pathlib import Path
-from typing import Union, Sequence, Iterator, Any
-
-from chromadb.utils.batch_utils import create_batches
-from langchain_core.documents import Document
-
-from langchain_community.document_loaders import (
-    WebBaseLoader,
-    TextLoader,
-    PyPDFLoader,
-    CSVLoader,
-    BSHTMLLoader,
-    Docx2txtLoader,
-    UnstructuredEPubLoader,
-    UnstructuredWordDocumentLoader,
-    UnstructuredMarkdownLoader,
-    UnstructuredXMLLoader,
-    UnstructuredRSTLoader,
-    UnstructuredExcelLoader,
-    UnstructuredPowerPointLoader,
-    YoutubeLoader,
-    OutlookMessageLoader,
-)
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from typing import Iterator, Optional, Sequence, Union
 
+import requests
 import validators
-import urllib.parse
-import socket
-
-
-from pydantic import BaseModel
-from typing import Optional
-import mimetypes
-import uuid
-import json
-
-from apps.webui.models.documents import (
-    Documents,
-    DocumentForm,
-    DocumentResponse,
-)
-from apps.webui.models.files import (
-    Files,
-)
-
-from apps.rag.utils import (
-    get_model_path,
-    get_embedding_function,
-    query_doc,
-    query_doc_with_hybrid_search,
-    query_collection,
-    query_collection_with_hybrid_search,
-)
-
 from apps.rag.search.brave import search_brave
+from apps.rag.search.duckduckgo import search_duckduckgo
 from apps.rag.search.google_pse import search_google_pse
+from apps.rag.search.jina_search import search_jina
 from apps.rag.search.main import SearchResult
+from apps.rag.search.searchapi import search_searchapi
 from apps.rag.search.searxng import search_searxng
 from apps.rag.search.serper import search_serper
-from apps.rag.search.serpstack import search_serpstack
 from apps.rag.search.serply import search_serply
-from apps.rag.search.duckduckgo import search_duckduckgo
+from apps.rag.search.serpstack import search_serpstack
 from apps.rag.search.tavily import search_tavily
-from apps.rag.search.jina_search import search_jina
-from apps.rag.search.searchapi import search_searchapi
-
-from utils.misc import (
-    calculate_sha256,
-    calculate_sha256_string,
-    sanitize_filename,
-    extract_folders_after_data_docs,
+from apps.rag.utils import (
+    get_embedding_function,
+    get_model_path,
+    query_collection,
+    query_collection_with_hybrid_search,
+    query_doc,
+    query_doc_with_hybrid_search,
 )
-from utils.utils import get_verified_user, get_admin_user
-
+from apps.webui.models.documents import DocumentForm, Documents
+from apps.webui.models.files import Files
+from chromadb.utils.batch_utils import create_batches
 from config import (
-    AppConfig,
-    ENV,
-    SRC_LOG_LEVELS,
-    UPLOAD_DIR,
-    DOCS_DIR,
+    BRAVE_SEARCH_API_KEY,
+    CHROMA_CLIENT,
+    CHUNK_OVERLAP,
+    CHUNK_SIZE,
     CONTENT_EXTRACTION_ENGINE,
-    TIKA_SERVER_URL,
-    RAG_TOP_K,
-    RAG_RELEVANCE_THRESHOLD,
-    RAG_FILE_MAX_SIZE,
-    RAG_FILE_MAX_COUNT,
+    CORS_ALLOW_ORIGIN,
+    DEVICE_TYPE,
+    DOCS_DIR,
+    ENABLE_RAG_HYBRID_SEARCH,
+    ENABLE_RAG_LOCAL_WEB_FETCH,
+    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+    ENABLE_RAG_WEB_SEARCH,
+    ENV,
+    GOOGLE_PSE_API_KEY,
+    GOOGLE_PSE_ENGINE_ID,
+    PDF_EXTRACT_IMAGES,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    ENABLE_RAG_HYBRID_SEARCH,
-    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+    RAG_FILE_MAX_COUNT,
+    RAG_FILE_MAX_SIZE,
+    RAG_OPENAI_API_BASE_URL,
+    RAG_OPENAI_API_KEY,
+    RAG_RELEVANCE_THRESHOLD,
     RAG_RERANKING_MODEL,
-    PDF_EXTRACT_IMAGES,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-    RAG_OPENAI_API_BASE_URL,
-    RAG_OPENAI_API_KEY,
-    DEVICE_TYPE,
-    CHROMA_CLIENT,
-    CHUNK_SIZE,
-    CHUNK_OVERLAP,
     RAG_TEMPLATE,
-    ENABLE_RAG_LOCAL_WEB_FETCH,
-    YOUTUBE_LOADER_LANGUAGE,
-    ENABLE_RAG_WEB_SEARCH,
-    RAG_WEB_SEARCH_ENGINE,
+    RAG_TOP_K,
+    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
     RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+    RAG_WEB_SEARCH_ENGINE,
+    RAG_WEB_SEARCH_RESULT_COUNT,
+    SEARCHAPI_API_KEY,
+    SEARCHAPI_ENGINE,
     SEARXNG_QUERY_URL,
-    GOOGLE_PSE_API_KEY,
-    GOOGLE_PSE_ENGINE_ID,
-    BRAVE_SEARCH_API_KEY,
-    SERPSTACK_API_KEY,
-    SERPSTACK_HTTPS,
     SERPER_API_KEY,
     SERPLY_API_KEY,
+    SERPSTACK_API_KEY,
+    SERPSTACK_HTTPS,
     TAVILY_API_KEY,
-    SEARCHAPI_API_KEY,
-    SEARCHAPI_ENGINE,
-    RAG_WEB_SEARCH_RESULT_COUNT,
-    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
-    CORS_ALLOW_ORIGIN,
+    TIKA_SERVER_URL,
+    UPLOAD_DIR,
+    YOUTUBE_LOADER_LANGUAGE,
+    AppConfig,
 )
-
 from constants import ERROR_MESSAGES
+from env import SRC_LOG_LEVELS
+from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
+from fastapi.middleware.cors import CORSMiddleware
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import (
+    BSHTMLLoader,
+    CSVLoader,
+    Docx2txtLoader,
+    OutlookMessageLoader,
+    PyPDFLoader,
+    TextLoader,
+    UnstructuredEPubLoader,
+    UnstructuredExcelLoader,
+    UnstructuredMarkdownLoader,
+    UnstructuredPowerPointLoader,
+    UnstructuredRSTLoader,
+    UnstructuredXMLLoader,
+    WebBaseLoader,
+    YoutubeLoader,
+)
+from langchain_core.documents import Document
+from pydantic import BaseModel
+from utils.misc import (
+    calculate_sha256,
+    calculate_sha256_string,
+    extract_folders_after_data_docs,
+    sanitize_filename,
+)
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -539,9 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
         app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
         app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
         app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
-        app.state.config.SEARCHAPI_ENGINE = (
-            form_data.web.search.searchapi_engine
-        )
+        app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
         app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
         app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
             form_data.web.search.concurrent_requests
@@ -981,7 +954,6 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
 def store_data_in_vector_db(
     data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
 ) -> bool:
-
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=app.state.config.CHUNK_SIZE,
         chunk_overlap=app.state.config.CHUNK_OVERLAP,
@@ -1342,7 +1314,6 @@ def store_text(
     form_data: TextRAGForm,
     user=Depends(get_verified_user),
 ):
-
     collection_name = form_data.collection_name
     if collection_name is None:
         collection_name = calculate_sha256_string(form_data.content)

+ 2 - 2
backend/apps/rag/search/brave.py

@@ -1,9 +1,9 @@
 import logging
 from typing import Optional
-import requests
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 2 - 1
backend/apps/rag/search/duckduckgo.py

@@ -1,8 +1,9 @@
 import logging
 from typing import Optional
+
 from apps.rag.search.main import SearchResult, get_filtered_results
 from duckduckgo_search import DDGS
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 2 - 3
backend/apps/rag/search/google_pse.py

@@ -1,10 +1,9 @@
-import json
 import logging
 from typing import Optional
-import requests
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 3 - 3
backend/apps/rag/search/jina_search.py

@@ -1,9 +1,9 @@
 import logging
-import requests
-from yarl import URL
 
+import requests
 from apps.rag.search.main import SearchResult
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
+from yarl import URL
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 1 - 0
backend/apps/rag/search/main.py

@@ -1,5 +1,6 @@
 from typing import Optional
 from urllib.parse import urlparse
+
 from pydantic import BaseModel
 
 

+ 2 - 3
backend/apps/rag/search/searxng.py

@@ -1,10 +1,9 @@
 import logging
-import requests
-
 from typing import Optional
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 2 - 2
backend/apps/rag/search/serper.py

@@ -1,10 +1,10 @@
 import json
 import logging
 from typing import Optional
-import requests
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 2 - 3
backend/apps/rag/search/serply.py

@@ -1,11 +1,10 @@
-import json
 import logging
 from typing import Optional
-import requests
 from urllib.parse import urlencode
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 2 - 3
backend/apps/rag/search/serpstack.py

@@ -1,10 +1,9 @@
-import json
 import logging
 from typing import Optional
-import requests
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 1 - 2
backend/apps/rag/search/tavily.py

@@ -1,9 +1,8 @@
 import logging
 
 import requests
-
 from apps.rag.search.main import SearchResult
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 14 - 24
backend/apps/rag/utils.py

@@ -1,27 +1,16 @@
-import os
 import logging
-import requests
-
-from typing import Union
-
-from apps.ollama.main import (
-    generate_ollama_embeddings,
-    GenerateEmbeddingsForm,
-)
+import os
+from typing import Optional, Union
 
+import requests
+from apps.ollama.main import GenerateEmbeddingsForm, generate_ollama_embeddings
+from config import CHROMA_CLIENT
+from env import SRC_LOG_LEVELS
 from huggingface_hub import snapshot_download
-
-from langchain_core.documents import Document
+from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
 from langchain_community.retrievers import BM25Retriever
-from langchain.retrievers import (
-    ContextualCompressionRetriever,
-    EnsembleRetriever,
-)
-
-from typing import Optional
-
-from utils.misc import get_last_user_message, add_or_update_system_message
-from config import SRC_LOG_LEVELS, CHROMA_CLIENT
+from langchain_core.documents import Document
+from utils.misc import get_last_user_message
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -261,7 +250,9 @@ def get_rag_context(
         collection_names = (
             file["collection_names"]
             if file["type"] == "collection"
-            else [file["collection_name"]] if file["collection_name"] else []
+            else [file["collection_name"]]
+            if file["collection_name"]
+            else []
         )
 
         collection_names = set(collection_names).difference(extracted_collections)
@@ -401,8 +392,8 @@ def generate_openai_batch_embeddings(
 
 from typing import Any
 
-from langchain_core.retrievers import BaseRetriever
 from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.retrievers import BaseRetriever
 
 
 class ChromaRetriever(BaseRetriever):
@@ -439,11 +430,10 @@ class ChromaRetriever(BaseRetriever):
 
 
 import operator
-
 from typing import Optional, Sequence
 
-from langchain_core.documents import BaseDocumentCompressor, Document
 from langchain_core.callbacks import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
 from langchain_core.pydantic_v1 import Extra
 
 

+ 1 - 2
backend/apps/socket/main.py

@@ -1,7 +1,6 @@
-import socketio
 import asyncio
 
-
+import socketio
 from apps.webui.models.users import Users
 from utils.utils import decode_token
 

+ 9 - 14
backend/apps/webui/internal/db.py

@@ -1,21 +1,16 @@
-import os
-import logging
 import json
+import logging
 from contextlib import contextmanager
+from typing import Any, Optional
 
-
-from typing import Optional, Any
-from typing_extensions import Self
-
-from sqlalchemy import create_engine, types, Dialect
-from sqlalchemy.sql.type_api import _T
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import sessionmaker, scoped_session
-
-
-from peewee_migrate import Router
 from apps.webui.internal.wrappers import register_connection
-from env import SRC_LOG_LEVELS, BACKEND_DIR, DATABASE_URL
+from env import BACKEND_DIR, DATABASE_URL, SRC_LOG_LEVELS
+from peewee_migrate import Router
+from sqlalchemy import Dialect, create_engine, types
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import scoped_session, sessionmaker
+from sqlalchemy.sql.type_api import _T
+from typing_extensions import Self
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])

+ 5 - 5
backend/apps/webui/internal/wrappers.py

@@ -1,13 +1,13 @@
+import logging
 from contextvars import ContextVar
-from peewee import *
-from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
 
-import logging
+from env import SRC_LOG_LEVELS
+from peewee import *
+from peewee import InterfaceError as PeeWeeInterfaceError
+from peewee import PostgresqlDatabase
 from playhouse.db_url import connect, parse
 from playhouse.shortcuts import ReconnectMixin
 
-from env import SRC_LOG_LEVELS
-
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])
 

+ 35 - 41
backend/apps/webui/main.py

@@ -1,65 +1,59 @@
-from fastapi import FastAPI
-from fastapi.responses import StreamingResponse
-from fastapi.middleware.cors import CORSMiddleware
+import inspect
+import json
+import logging
+from typing import AsyncGenerator, Generator, Iterator
+
+from apps.socket.main import get_event_call, get_event_emitter
+from apps.webui.models.functions import Functions
+from apps.webui.models.models import Models
 from apps.webui.routers import (
     auths,
-    users,
     chats,
+    configs,
     documents,
-    tools,
+    files,
+    functions,
+    memories,
     models,
     prompts,
-    configs,
-    memories,
+    tools,
+    users,
     utils,
-    files,
-    functions,
 )
-from apps.webui.models.functions import Functions
-from apps.webui.models.models import Models
 from apps.webui.utils import load_function_module_by_id
-
-from utils.misc import (
-    openai_chat_chunk_message_template,
-    openai_chat_completion_message_template,
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
-)
-
-from utils.tools import get_tools
-
 from config import (
-    SHOW_ADMIN_DETAILS,
     ADMIN_EMAIL,
-    WEBUI_AUTH,
+    CORS_ALLOW_ORIGIN,
     DEFAULT_MODELS,
     DEFAULT_PROMPT_SUGGESTIONS,
     DEFAULT_USER_ROLE,
-    ENABLE_SIGNUP,
+    ENABLE_COMMUNITY_SHARING,
     ENABLE_LOGIN_FORM,
+    ENABLE_MESSAGE_RATING,
+    ENABLE_SIGNUP,
+    JWT_EXPIRES_IN,
+    OAUTH_EMAIL_CLAIM,
+    OAUTH_PICTURE_CLAIM,
+    OAUTH_USERNAME_CLAIM,
+    SHOW_ADMIN_DETAILS,
     USER_PERMISSIONS,
     WEBHOOK_URL,
-    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
-    WEBUI_AUTH_TRUSTED_NAME_HEADER,
-    JWT_EXPIRES_IN,
+    WEBUI_AUTH,
     WEBUI_BANNERS,
-    ENABLE_COMMUNITY_SHARING,
-    ENABLE_MESSAGE_RATING,
     AppConfig,
-    OAUTH_USERNAME_CLAIM,
-    OAUTH_PICTURE_CLAIM,
-    OAUTH_EMAIL_CLAIM,
-    CORS_ALLOW_ORIGIN,
 )
-
-from apps.socket.main import get_event_call, get_event_emitter
-
-import inspect
-import json
-import logging
-
-from typing import Iterator, Generator, AsyncGenerator
+from env import WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
+from utils.misc import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+    openai_chat_chunk_message_template,
+    openai_chat_completion_message_template,
+)
+from utils.tools import get_tools
 
 app = FastAPI()
 

+ 6 - 12
backend/apps/webui/models/auths.py

@@ -1,15 +1,13 @@
-from pydantic import BaseModel
-from typing import Optional
-import uuid
 import logging
-from sqlalchemy import String, Column, Boolean, Text
-
-from utils.utils import verify_password
+import uuid
+from typing import Optional
 
-from apps.webui.models.users import UserModel, Users
 from apps.webui.internal.db import Base, get_db
-
+from apps.webui.models.users import UserModel, Users
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel
+from sqlalchemy import Boolean, Column, String, Text
+from utils.utils import verify_password
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -92,7 +90,6 @@ class AddUserForm(SignupForm):
 
 
 class AuthsTable:
-
     def insert_new_auth(
         self,
         email: str,
@@ -103,7 +100,6 @@ class AuthsTable:
         oauth_sub: Optional[str] = None,
     ) -> Optional[UserModel]:
         with get_db() as db:
-
             log.info("insert_new_auth")
 
             id = str(uuid.uuid4())
@@ -130,7 +126,6 @@ class AuthsTable:
         log.info(f"authenticate_user: {email}")
         try:
             with get_db() as db:
-
                 auth = db.query(Auth).filter_by(email=email, active=True).first()
                 if auth:
                     if verify_password(password, auth.password):
@@ -189,7 +184,6 @@ class AuthsTable:
     def delete_auth_by_id(self, id: str) -> bool:
         try:
             with get_db() as db:
-
                 # Delete User
                 result = Users.delete_user_by_id(id)
 

+ 6 - 30
backend/apps/webui/models/chats.py

@@ -1,14 +1,11 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Union, Optional
-
 import json
-import uuid
 import time
-
-from sqlalchemy import Column, String, BigInteger, Boolean, Text
+import uuid
+from typing import Optional
 
 from apps.webui.internal.db import Base, get_db
-
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Boolean, Column, String, Text
 
 ####################
 # Chat DB Schema
@@ -77,10 +74,8 @@ class ChatTitleIdResponse(BaseModel):
 
 
 class ChatTable:
-
     def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
         with get_db() as db:
-
             id = str(uuid.uuid4())
             chat = ChatModel(
                 **{
@@ -106,7 +101,6 @@ class ChatTable:
     def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat_obj = db.get(Chat, id)
                 chat_obj.chat = json.dumps(chat)
                 chat_obj.title = chat["title"] if "title" in chat else "New Chat"
@@ -115,12 +109,11 @@ class ChatTable:
                 db.refresh(chat_obj)
 
                 return ChatModel.model_validate(chat_obj)
-        except Exception as e:
+        except Exception:
             return None
 
     def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         with get_db() as db:
-
             # Get the existing chat to share
             chat = db.get(Chat, chat_id)
             # Check if the chat is already shared
@@ -154,7 +147,6 @@ class ChatTable:
     def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 print("update_shared_chat_by_id")
                 chat = db.get(Chat, chat_id)
                 print(chat)
@@ -170,7 +162,6 @@ class ChatTable:
     def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
                 db.commit()
 
@@ -183,7 +174,6 @@ class ChatTable:
     ) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat = db.get(Chat, id)
                 chat.share_id = share_id
                 db.commit()
@@ -195,7 +185,6 @@ class ChatTable:
     def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat = db.get(Chat, id)
                 chat.archived = not chat.archived
                 db.commit()
@@ -217,7 +206,6 @@ class ChatTable:
         self, user_id: str, skip: int = 0, limit: int = 50
     ) -> list[ChatModel]:
         with get_db() as db:
-
             all_chats = (
                 db.query(Chat)
                 .filter_by(user_id=user_id, archived=True)
@@ -297,7 +285,6 @@ class ChatTable:
     def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat = db.get(Chat, id)
                 return ChatModel.model_validate(chat)
         except Exception:
@@ -306,20 +293,18 @@ class ChatTable:
     def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat = db.query(Chat).filter_by(share_id=id).first()
 
                 if chat:
                     return self.get_chat_by_id(id)
                 else:
                     return None
-        except Exception as e:
+        except Exception:
             return None
 
     def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
                 return ChatModel.model_validate(chat)
         except Exception:
@@ -327,7 +312,6 @@ class ChatTable:
 
     def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
         with get_db() as db:
-
             all_chats = (
                 db.query(Chat)
                 # .limit(limit).offset(skip)
@@ -337,7 +321,6 @@ class ChatTable:
 
     def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
         with get_db() as db:
-
             all_chats = (
                 db.query(Chat)
                 .filter_by(user_id=user_id)
@@ -347,7 +330,6 @@ class ChatTable:
 
     def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
         with get_db() as db:
-
             all_chats = (
                 db.query(Chat)
                 .filter_by(user_id=user_id, archived=True)
@@ -358,7 +340,6 @@ class ChatTable:
     def delete_chat_by_id(self, id: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Chat).filter_by(id=id).delete()
                 db.commit()
 
@@ -369,7 +350,6 @@ class ChatTable:
     def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Chat).filter_by(id=id, user_id=user_id).delete()
                 db.commit()
 
@@ -379,9 +359,7 @@ class ChatTable:
 
     def delete_chats_by_user_id(self, user_id: str) -> bool:
         try:
-
             with get_db() as db:
-
                 self.delete_shared_chats_by_user_id(user_id)
 
                 db.query(Chat).filter_by(user_id=user_id).delete()
@@ -393,9 +371,7 @@ class ChatTable:
 
     def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
         try:
-
             with get_db() as db:
-
                 chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
                 shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
 

+ 5 - 15
backend/apps/webui/models/documents.py

@@ -1,15 +1,12 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Optional
-import time
+import json
 import logging
-
-from sqlalchemy import String, Column, BigInteger, Text
+import time
+from typing import Optional
 
 from apps.webui.internal.db import Base, get_db
-
-import json
-
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -70,12 +67,10 @@ class DocumentForm(DocumentUpdateForm):
 
 
 class DocumentsTable:
-
     def insert_new_doc(
         self, user_id: str, form_data: DocumentForm
     ) -> Optional[DocumentModel]:
         with get_db() as db:
-
             document = DocumentModel(
                 **{
                     **form_data.model_dump(),
@@ -99,7 +94,6 @@ class DocumentsTable:
     def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
         try:
             with get_db() as db:
-
                 document = db.query(Document).filter_by(name=name).first()
                 return DocumentModel.model_validate(document) if document else None
         except Exception:
@@ -107,7 +101,6 @@ class DocumentsTable:
 
     def get_docs(self) -> list[DocumentModel]:
         with get_db() as db:
-
             return [
                 DocumentModel.model_validate(doc) for doc in db.query(Document).all()
             ]
@@ -117,7 +110,6 @@ class DocumentsTable:
     ) -> Optional[DocumentModel]:
         try:
             with get_db() as db:
-
                 db.query(Document).filter_by(name=name).update(
                     {
                         "title": form_data.title,
@@ -140,7 +132,6 @@ class DocumentsTable:
             doc_content = {**doc_content, **updated}
 
             with get_db() as db:
-
                 db.query(Document).filter_by(name=name).update(
                     {
                         "content": json.dumps(doc_content),
@@ -156,7 +147,6 @@ class DocumentsTable:
     def delete_doc_by_name(self, name: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Document).filter_by(name=name).delete()
                 db.commit()
                 return True

+ 5 - 17
backend/apps/webui/models/files.py

@@ -1,15 +1,11 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Union, Optional
-import time
 import logging
+import time
+from typing import Optional
 
-from sqlalchemy import Column, String, BigInteger, Text
-
-from apps.webui.internal.db import JSONField, Base, get_db
-
-import json
-
+from apps.webui.internal.db import Base, JSONField, get_db
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -59,10 +55,8 @@ class FileForm(BaseModel):
 
 
 class FilesTable:
-
     def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
         with get_db() as db:
-
             file = FileModel(
                 **{
                     **form_data.model_dump(),
@@ -86,7 +80,6 @@ class FilesTable:
 
     def get_file_by_id(self, id: str) -> Optional[FileModel]:
         with get_db() as db:
-
             try:
                 file = db.get(File, id)
                 return FileModel.model_validate(file)
@@ -95,7 +88,6 @@ class FilesTable:
 
     def get_files(self) -> list[FileModel]:
         with get_db() as db:
-
             return [FileModel.model_validate(file) for file in db.query(File).all()]
 
     def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
@@ -106,9 +98,7 @@ class FilesTable:
             ]
 
     def delete_file_by_id(self, id: str) -> bool:
-
         with get_db() as db:
-
             try:
                 db.query(File).filter_by(id=id).delete()
                 db.commit()
@@ -118,9 +108,7 @@ class FilesTable:
                 return False
 
     def delete_all_files(self) -> bool:
-
         with get_db() as db:
-
             try:
                 db.query(File).delete()
                 db.commit()

+ 5 - 23
backend/apps/webui/models/functions.py

@@ -1,18 +1,12 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Union, Optional
-import time
 import logging
+import time
+from typing import Optional
 
-from sqlalchemy import Column, String, Text, BigInteger, Boolean
-
-from apps.webui.internal.db import JSONField, Base, get_db
+from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.users import Users
-
-import json
-import copy
-
-
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Boolean, Column, String, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -87,11 +81,9 @@ class FunctionValves(BaseModel):
 
 
 class FunctionsTable:
-
     def insert_new_function(
         self, user_id: str, type: str, form_data: FunctionForm
     ) -> Optional[FunctionModel]:
-
         function = FunctionModel(
             **{
                 **form_data.model_dump(),
@@ -119,7 +111,6 @@ class FunctionsTable:
     def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
         try:
             with get_db() as db:
-
                 function = db.get(Function, id)
                 return FunctionModel.model_validate(function)
         except Exception:
@@ -127,7 +118,6 @@ class FunctionsTable:
 
     def get_functions(self, active_only=False) -> list[FunctionModel]:
         with get_db() as db:
-
             if active_only:
                 return [
                     FunctionModel.model_validate(function)
@@ -143,7 +133,6 @@ class FunctionsTable:
         self, type: str, active_only=False
     ) -> list[FunctionModel]:
         with get_db() as db:
-
             if active_only:
                 return [
                     FunctionModel.model_validate(function)
@@ -159,7 +148,6 @@ class FunctionsTable:
 
     def get_global_filter_functions(self) -> list[FunctionModel]:
         with get_db() as db:
-
             return [
                 FunctionModel.model_validate(function)
                 for function in db.query(Function)
@@ -178,7 +166,6 @@ class FunctionsTable:
 
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
         with get_db() as db:
-
             try:
                 function = db.get(Function, id)
                 return function.valves if function.valves else {}
@@ -190,7 +177,6 @@ class FunctionsTable:
         self, id: str, valves: dict
     ) -> Optional[FunctionValves]:
         with get_db() as db:
-
             try:
                 function = db.get(Function, id)
                 function.valves = valves
@@ -204,7 +190,6 @@ class FunctionsTable:
     def get_user_valves_by_id_and_user_id(
         self, id: str, user_id: str
     ) -> Optional[dict]:
-
         try:
             user = Users.get_user_by_id(user_id)
             user_settings = user.settings.model_dump() if user.settings else {}
@@ -223,7 +208,6 @@ class FunctionsTable:
     def update_user_valves_by_id_and_user_id(
         self, id: str, user_id: str, valves: dict
     ) -> Optional[dict]:
-
         try:
             user = Users.get_user_by_id(user_id)
             user_settings = user.settings.model_dump() if user.settings else {}
@@ -246,7 +230,6 @@ class FunctionsTable:
 
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
         with get_db() as db:
-
             try:
                 db.query(Function).filter_by(id=id).update(
                     {
@@ -261,7 +244,6 @@ class FunctionsTable:
 
     def deactivate_all_functions(self) -> Optional[bool]:
         with get_db() as db:
-
             try:
                 db.query(Function).update(
                     {

+ 5 - 16
backend/apps/webui/models/memories.py

@@ -1,12 +1,10 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Union, Optional
-
-from sqlalchemy import Column, String, BigInteger, Text
-
-from apps.webui.internal.db import Base, get_db
-
 import time
 import uuid
+from typing import Optional
+
+from apps.webui.internal.db import Base, get_db
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 ####################
 # Memory DB Schema
@@ -39,13 +37,11 @@ class MemoryModel(BaseModel):
 
 
 class MemoriesTable:
-
     def insert_new_memory(
         self,
         user_id: str,
         content: str,
     ) -> Optional[MemoryModel]:
-
         with get_db() as db:
             id = str(uuid.uuid4())
 
@@ -73,7 +69,6 @@ class MemoriesTable:
         content: str,
     ) -> Optional[MemoryModel]:
         with get_db() as db:
-
             try:
                 db.query(Memory).filter_by(id=id).update(
                     {"content": content, "updated_at": int(time.time())}
@@ -85,7 +80,6 @@ class MemoriesTable:
 
     def get_memories(self) -> list[MemoryModel]:
         with get_db() as db:
-
             try:
                 memories = db.query(Memory).all()
                 return [MemoryModel.model_validate(memory) for memory in memories]
@@ -94,7 +88,6 @@ class MemoriesTable:
 
     def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
         with get_db() as db:
-
             try:
                 memories = db.query(Memory).filter_by(user_id=user_id).all()
                 return [MemoryModel.model_validate(memory) for memory in memories]
@@ -103,7 +96,6 @@ class MemoriesTable:
 
     def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
         with get_db() as db:
-
             try:
                 memory = db.get(Memory, id)
                 return MemoryModel.model_validate(memory)
@@ -112,7 +104,6 @@ class MemoriesTable:
 
     def delete_memory_by_id(self, id: str) -> bool:
         with get_db() as db:
-
             try:
                 db.query(Memory).filter_by(id=id).delete()
                 db.commit()
@@ -124,7 +115,6 @@ class MemoriesTable:
 
     def delete_memories_by_user_id(self, user_id: str) -> bool:
         with get_db() as db:
-
             try:
                 db.query(Memory).filter_by(user_id=user_id).delete()
                 db.commit()
@@ -135,7 +125,6 @@ class MemoriesTable:
 
     def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         with get_db() as db:
-
             try:
                 db.query(Memory).filter_by(id=id, user_id=user_id).delete()
                 db.commit()

+ 4 - 7
backend/apps/webui/models/models.py

@@ -1,14 +1,11 @@
 import logging
-from typing import Optional, List
-
-from pydantic import BaseModel, ConfigDict
-from sqlalchemy import Column, BigInteger, Text
+import time
+from typing import Optional
 
 from apps.webui.internal.db import Base, JSONField, get_db
-
 from env import SRC_LOG_LEVELS
-
-import time
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

+ 4 - 13
backend/apps/webui/models/prompts.py

@@ -1,12 +1,9 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Optional
 import time
-
-from sqlalchemy import String, Column, BigInteger, Text
+from typing import Optional
 
 from apps.webui.internal.db import Base, get_db
-
-import json
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 ####################
 # Prompts DB Schema
@@ -45,7 +42,6 @@ class PromptForm(BaseModel):
 
 
 class PromptsTable:
-
     def insert_new_prompt(
         self, user_id: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
@@ -61,7 +57,6 @@ class PromptsTable:
 
         try:
             with get_db() as db:
-
                 result = Prompt(**prompt.dict())
                 db.add(result)
                 db.commit()
@@ -70,13 +65,12 @@ class PromptsTable:
                     return PromptModel.model_validate(result)
                 else:
                     return None
-        except Exception as e:
+        except Exception:
             return None
 
     def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
         try:
             with get_db() as db:
-
                 prompt = db.query(Prompt).filter_by(command=command).first()
                 return PromptModel.model_validate(prompt)
         except Exception:
@@ -84,7 +78,6 @@ class PromptsTable:
 
     def get_prompts(self) -> list[PromptModel]:
         with get_db() as db:
-
             return [
                 PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
             ]
@@ -94,7 +87,6 @@ class PromptsTable:
     ) -> Optional[PromptModel]:
         try:
             with get_db() as db:
-
                 prompt = db.query(Prompt).filter_by(command=command).first()
                 prompt.title = form_data.title
                 prompt.content = form_data.content
@@ -107,7 +99,6 @@ class PromptsTable:
     def delete_prompt_by_command(self, command: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Prompt).filter_by(command=command).delete()
                 db.commit()
 

+ 7 - 17
backend/apps/webui/models/tags.py

@@ -1,16 +1,12 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Optional
-
-import json
-import uuid
-import time
 import logging
-
-from sqlalchemy import String, Column, BigInteger, Text
+import time
+import uuid
+from typing import Optional
 
 from apps.webui.internal.db import Base, get_db
-
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -77,10 +73,8 @@ class ChatTagsResponse(BaseModel):
 
 
 class TagTable:
-
     def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
         with get_db() as db:
-
             id = str(uuid.uuid4())
             tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
             try:
@@ -92,7 +86,7 @@ class TagTable:
                     return TagModel.model_validate(result)
                 else:
                     return None
-            except Exception as e:
+            except Exception:
                 return None
 
     def get_tag_by_name_and_user_id(
@@ -102,7 +96,7 @@ class TagTable:
             with get_db() as db:
                 tag = db.query(Tag).filter_by(name=name, user_id=user_id).first()
                 return TagModel.model_validate(tag)
-        except Exception as e:
+        except Exception:
             return None
 
     def add_tag_to_chat(
@@ -161,7 +155,6 @@ class TagTable:
         self, chat_id: str, user_id: str
     ) -> list[TagModel]:
         with get_db() as db:
-
             tag_names = [
                 chat_id_tag.tag_name
                 for chat_id_tag in (
@@ -186,7 +179,6 @@ class TagTable:
         self, tag_name: str, user_id: str
     ) -> list[ChatIdTagModel]:
         with get_db() as db:
-
             return [
                 ChatIdTagModel.model_validate(chat_id_tag)
                 for chat_id_tag in (
@@ -201,7 +193,6 @@ class TagTable:
         self, tag_name: str, user_id: str
     ) -> int:
         with get_db() as db:
-
             return (
                 db.query(ChatIdTag)
                 .filter_by(tag_name=tag_name, user_id=user_id)
@@ -236,7 +227,6 @@ class TagTable:
     ) -> bool:
         try:
             with get_db() as db:
-
                 res = (
                     db.query(ChatIdTag)
                     .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)

+ 4 - 15
backend/apps/webui/models/tools.py

@@ -1,17 +1,12 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Optional
-import time
 import logging
-from sqlalchemy import String, Column, BigInteger, Text
+import time
+from typing import Optional
 
 from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.users import Users
-
-import json
-import copy
-
-
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -79,13 +74,10 @@ class ToolValves(BaseModel):
 
 
 class ToolsTable:
-
     def insert_new_tool(
         self, user_id: str, form_data: ToolForm, specs: list[dict]
     ) -> Optional[ToolModel]:
-
         with get_db() as db:
-
             tool = ToolModel(
                 **{
                     **form_data.model_dump(),
@@ -112,7 +104,6 @@ class ToolsTable:
     def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
         try:
             with get_db() as db:
-
                 tool = db.get(Tool, id)
                 return ToolModel.model_validate(tool)
         except Exception:
@@ -125,7 +116,6 @@ class ToolsTable:
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
         try:
             with get_db() as db:
-
                 tool = db.get(Tool, id)
                 return tool.valves if tool.valves else {}
         except Exception as e:
@@ -135,7 +125,6 @@ class ToolsTable:
     def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
         try:
             with get_db() as db:
-
                 db.query(Tool).filter_by(id=id).update(
                     {"valves": valves, "updated_at": int(time.time())}
                 )

+ 6 - 7
backend/apps/webui/models/users.py

@@ -1,11 +1,10 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Optional
 import time
-
-from sqlalchemy import String, Column, BigInteger, Text
+from typing import Optional
 
 from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.chats import Chats
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 ####################
 # User DB Schema
@@ -113,7 +112,7 @@ class UsersTable:
             with get_db() as db:
                 user = db.query(User).filter_by(id=id).first()
                 return UserModel.model_validate(user)
-        except Exception as e:
+        except Exception:
             return None
 
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
@@ -221,7 +220,7 @@ class UsersTable:
                 user = db.query(User).filter_by(id=id).first()
                 return UserModel.model_validate(user)
                 # return UserModel(**user.dict())
-        except Exception as e:
+        except Exception:
             return None
 
     def delete_user_by_id(self, id: str) -> bool:
@@ -255,7 +254,7 @@ class UsersTable:
             with get_db() as db:
                 user = db.query(User).filter_by(id=id).first()
                 return user.api_key
-        except Exception as e:
+        except Exception:
             return None
 
 

+ 16 - 28
backend/apps/webui/routers/auths.py

@@ -1,43 +1,33 @@
-import logging
-
-from fastapi import Request, UploadFile, File
-from fastapi import Depends, HTTPException, status
-from fastapi.responses import Response
-
-from fastapi import APIRouter
-from pydantic import BaseModel
 import re
 import uuid
-import csv
 
 from apps.webui.models.auths import (
+    AddUserForm,
+    ApiKey,
+    Auths,
     SigninForm,
+    SigninResponse,
     SignupForm,
-    AddUserForm,
-    UpdateProfileForm,
     UpdatePasswordForm,
+    UpdateProfileForm,
     UserResponse,
-    SigninResponse,
-    Auths,
-    ApiKey,
 )
 from apps.webui.models.users import Users
-
+from config import WEBUI_AUTH
+from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
+from env import WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from fastapi.responses import Response
+from pydantic import BaseModel
+from utils.misc import parse_duration, validate_email_format
 from utils.utils import (
-    get_password_hash,
-    get_current_user,
-    get_admin_user,
-    create_token,
     create_api_key,
+    create_token,
+    get_admin_user,
+    get_current_user,
+    get_password_hash,
 )
-from utils.misc import parse_duration, validate_email_format
 from utils.webhook import post_webhook
-from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
-from config import (
-    WEBUI_AUTH,
-    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
-    WEBUI_AUTH_TRUSTED_NAME_HEADER,
-)
 
 router = APIRouter()
 
@@ -273,7 +263,6 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
 
 @router.post("/add", response_model=SigninResponse)
 async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
-
     if not validate_email_format(form_data.email.lower()):
         raise HTTPException(
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@@ -283,7 +272,6 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
     try:
-
         print(form_data)
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(

+ 8 - 31
backend/apps/webui/routers/chats.py

@@ -1,34 +1,15 @@
-from fastapi import Depends, Request, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union, Optional
-from utils.utils import get_verified_user, get_admin_user
-from fastapi import APIRouter
-from pydantic import BaseModel
 import json
 import logging
+from typing import Optional
 
-from apps.webui.models.users import Users
-from apps.webui.models.chats import (
-    ChatModel,
-    ChatResponse,
-    ChatTitleForm,
-    ChatForm,
-    ChatTitleIdResponse,
-    Chats,
-)
-
-
-from apps.webui.models.tags import (
-    TagModel,
-    ChatIdTagModel,
-    ChatIdTagForm,
-    ChatTagsResponse,
-    Tags,
-)
-
+from apps.webui.models.chats import ChatForm, ChatResponse, Chats, ChatTitleIdResponse
+from apps.webui.models.tags import ChatIdTagForm, ChatIdTagModel, TagModel, Tags
+from config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
 from constants import ERROR_MESSAGES
-
-from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_CHAT_ACCESS
+from env import SRC_LOG_LEVELS
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -61,7 +42,6 @@ async def get_session_user_chat_list(
 
 @router.delete("/", response_model=bool)
 async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
-
     if (
         user.role == "user"
         and not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]
@@ -220,7 +200,6 @@ class TagNameForm(BaseModel):
 async def get_user_chat_list_by_tag_name(
     form_data: TagNameForm, user=Depends(get_verified_user)
 ):
-
     chat_ids = [
         chat_id_tag.chat_id
         for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
@@ -299,7 +278,6 @@ async def update_chat_by_id(
 
 @router.delete("/{id}", response_model=bool)
 async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
-
     if user.role == "admin":
         result = Chats.delete_chat_by_id(id)
         return result
@@ -323,7 +301,6 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
 async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
-
         chat_body = json.loads(chat.chat)
         updated_chat = {
             **chat_body,

+ 3 - 21
backend/apps/webui/routers/configs.py

@@ -1,25 +1,7 @@
-from fastapi import Response, Request
-from fastapi import Depends, FastAPI, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union
-
-from fastapi import APIRouter
-from pydantic import BaseModel
-import time
-import uuid
-
 from config import BannerModel
-
-from apps.webui.models.users import Users
-
-from utils.utils import (
-    get_password_hash,
-    get_verified_user,
-    get_admin_user,
-    create_token,
-)
-from utils.misc import get_gravatar_url, validate_email_format
-from constants import ERROR_MESSAGES
+from fastapi import APIRouter, Depends, Request
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_verified_user
 
 router = APIRouter()
 

+ 6 - 11
backend/apps/webui/routers/documents.py

@@ -1,21 +1,16 @@
-from fastapi import Depends, FastAPI, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union, Optional
-
-from fastapi import APIRouter
-from pydantic import BaseModel
 import json
+from typing import Optional
 
 from apps.webui.models.documents import (
-    Documents,
     DocumentForm,
-    DocumentUpdateForm,
-    DocumentModel,
     DocumentResponse,
+    Documents,
+    DocumentUpdateForm,
 )
-
-from utils.utils import get_verified_user, get_admin_user
 from constants import ERROR_MESSAGES
+from fastapi import APIRouter, Depends, HTTPException, status
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_verified_user
 
 router = APIRouter()
 

+ 11 - 36
backend/apps/webui/routers/files.py

@@ -1,42 +1,17 @@
-from fastapi import (
-    Depends,
-    FastAPI,
-    HTTPException,
-    status,
-    Request,
-    UploadFile,
-    File,
-    Form,
-)
-
-
-from datetime import datetime, timedelta
-from typing import Union, Optional
-from pathlib import Path
-
-from fastapi import APIRouter
-from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
-
-from pydantic import BaseModel
-import json
-
-from apps.webui.models.files import (
-    Files,
-    FileForm,
-    FileModel,
-    FileModelResponse,
-)
-from utils.utils import get_verified_user, get_admin_user
-from constants import ERROR_MESSAGES
-
-from importlib import util
+import logging
 import os
+import shutil
 import uuid
-import os, shutil, logging, re
-
-
-from config import SRC_LOG_LEVELS, UPLOAD_DIR
+from pathlib import Path
+from typing import Optional
 
+from apps.webui.models.files import FileForm, FileModel, Files
+from config import UPLOAD_DIR
+from constants import ERROR_MESSAGES
+from env import SRC_LOG_LEVELS
+from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
+from fastapi.responses import FileResponse
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

+ 7 - 17
backend/apps/webui/routers/functions.py

@@ -1,27 +1,18 @@
-from fastapi import Depends, FastAPI, HTTPException, status, Request
-from datetime import datetime, timedelta
-from typing import Union, Optional
-
-from fastapi import APIRouter
-from pydantic import BaseModel
-import json
+import os
+from pathlib import Path
+from typing import Optional
 
 from apps.webui.models.functions import (
-    Functions,
     FunctionForm,
     FunctionModel,
     FunctionResponse,
+    Functions,
 )
 from apps.webui.utils import load_function_module_by_id
-from utils.utils import get_verified_user, get_admin_user
+from config import CACHE_DIR, FUNCTIONS_DIR
 from constants import ERROR_MESSAGES
-
-from importlib import util
-import os
-from pathlib import Path
-
-from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
-
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from utils.utils import get_admin_user, get_verified_user
 
 router = APIRouter()
 
@@ -304,7 +295,6 @@ async def update_function_valves_by_id(
 ):
     function = Functions.get_function_by_id(id)
     if function:
-
         if id in request.app.state.FUNCTIONS:
             function_module = request.app.state.FUNCTIONS[id]
         else:

+ 5 - 11
backend/apps/webui/routers/memories.py

@@ -1,18 +1,12 @@
-from fastapi import Response, Request
-from fastapi import Depends, FastAPI, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union, Optional
-
-from fastapi import APIRouter
-from pydantic import BaseModel
 import logging
+from typing import Optional
 
 from apps.webui.models.memories import Memories, MemoryModel
-
+from config import CHROMA_CLIENT
+from env import SRC_LOG_LEVELS
+from fastapi import APIRouter, Depends, HTTPException, Request
+from pydantic import BaseModel
 from utils.utils import get_verified_user
-from constants import ERROR_MESSAGES
-
-from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

+ 4 - 10
backend/apps/webui/routers/models.py

@@ -1,15 +1,9 @@
-from fastapi import Depends, FastAPI, HTTPException, status, Request
-from datetime import datetime, timedelta
-from typing import Union, Optional
+from typing import Optional
 
-from fastapi import APIRouter
-from pydantic import BaseModel
-import json
-
-from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
-
-from utils.utils import get_verified_user, get_admin_user
+from apps.webui.models.models import ModelForm, ModelModel, ModelResponse, Models
 from constants import ERROR_MESSAGES
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from utils.utils import get_admin_user, get_verified_user
 
 router = APIRouter()
 

+ 4 - 10
backend/apps/webui/routers/prompts.py

@@ -1,15 +1,9 @@
-from fastapi import Depends, FastAPI, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union, Optional
+from typing import Optional
 
-from fastapi import APIRouter
-from pydantic import BaseModel
-import json
-
-from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
-
-from utils.utils import get_verified_user, get_admin_user
+from apps.webui.models.prompts import PromptForm, PromptModel, Prompts
 from constants import ERROR_MESSAGES
+from fastapi import APIRouter, Depends, HTTPException, status
+from utils.utils import get_admin_user, get_verified_user
 
 router = APIRouter()
 

+ 7 - 13
backend/apps/webui/routers/tools.py

@@ -1,20 +1,14 @@
-from fastapi import Depends, HTTPException, status, Request
+import os
+from pathlib import Path
 from typing import Optional
 
-from fastapi import APIRouter
-
-from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
+from apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools
 from apps.webui.utils import load_toolkit_module_by_id
-
-from utils.utils import get_admin_user, get_verified_user
-from utils.tools import get_tools_specs
+from config import CACHE_DIR, DATA_DIR
 from constants import ERROR_MESSAGES
-
-import os
-from pathlib import Path
-
-from config import DATA_DIR, CACHE_DIR
-
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from utils.tools import get_tools_specs
+from utils.utils import get_admin_user, get_verified_user
 
 TOOLS_DIR = f"{DATA_DIR}/tools"
 os.makedirs(TOOLS_DIR, exist_ok=True)

+ 9 - 24
backend/apps/webui/routers/users.py

@@ -1,33 +1,20 @@
-from fastapi import Response, Request
-from fastapi import Depends, FastAPI, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union, Optional
-
-from fastapi import APIRouter
-from pydantic import BaseModel
-import time
-import uuid
 import logging
+from typing import Optional
 
+from apps.webui.models.auths import Auths
+from apps.webui.models.chats import Chats
 from apps.webui.models.users import (
     UserModel,
-    UserUpdateForm,
     UserRoleUpdateForm,
-    UserSettings,
     Users,
-)
-from apps.webui.models.auths import Auths
-from apps.webui.models.chats import Chats
-
-from utils.utils import (
-    get_verified_user,
-    get_password_hash,
-    get_current_user,
-    get_admin_user,
+    UserSettings,
+    UserUpdateForm,
 )
 from constants import ERROR_MESSAGES
-
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_password_hash, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -69,7 +56,6 @@ async def update_user_permissions(
 
 @router.post("/update/role", response_model=Optional[UserModel])
 async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
-
     if user.id != form_data.id and form_data.id != Users.get_first_user().id:
         return Users.update_user_role_by_id(form_data.id, form_data.role)
 
@@ -173,7 +159,6 @@ class UserResponse(BaseModel):
 
 @router.get("/{user_id}", response_model=UserResponse)
 async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
-
     # Check if user_id is a shared chat
     # If it is, get the user_id from the chat
     if user_id.startswith("shared-"):

+ 10 - 17
backend/apps/webui/routers/utils.py

@@ -1,23 +1,16 @@
-from pathlib import Path
 import site
+from pathlib import Path
 
-from fastapi import APIRouter, UploadFile, File, Response
-from fastapi import Depends, HTTPException, status
-from starlette.responses import StreamingResponse, FileResponse
-from pydantic import BaseModel
-
-
-from fpdf import FPDF
-import markdown
 import black
-
-
-from utils.utils import get_admin_user
-from utils.misc import calculate_sha256, get_gravatar_url
-
-from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR, ENABLE_ADMIN_EXPORT
+import markdown
+from config import DATA_DIR, ENABLE_ADMIN_EXPORT
 from constants import ERROR_MESSAGES
-
+from fastapi import APIRouter, Depends, HTTPException, Response, status
+from fpdf import FPDF
+from pydantic import BaseModel
+from starlette.responses import FileResponse
+from utils.misc import get_gravatar_url
+from utils.utils import get_admin_user
 
 router = APIRouter()
 
@@ -115,7 +108,7 @@ async def download_chat_as_pdf(
     return Response(
         content=bytes(pdf_bytes),
         media_type="application/pdf",
-        headers={"Content-Disposition": f"attachment;filename=chat.pdf"},
+        headers={"Content-Disposition": "attachment;filename=chat.pdf"},
     )
 
 

+ 4 - 5
backend/apps/webui/utils.py

@@ -1,13 +1,12 @@
-from importlib import util
 import os
 import re
-import sys
 import subprocess
+import sys
+from importlib import util
 
-
-from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
-from config import TOOLS_DIR, FUNCTIONS_DIR
+from apps.webui.models.tools import Tools
+from config import FUNCTIONS_DIR, TOOLS_DIR
 
 
 def extract_frontmatter(file_path):

+ 16 - 44
backend/config.py

@@ -1,58 +1,30 @@
-from sqlalchemy import create_engine, Column, Integer, DateTime, JSON, func
-from contextlib import contextmanager
-
-
-import os
-import sys
+import json
 import logging
-import importlib.metadata
-import pkgutil
-from urllib.parse import urlparse
+import os
+import shutil
 from datetime import datetime
-
-import chromadb
-from chromadb import Settings
-from typing import TypeVar, Generic
-from pydantic import BaseModel
-from typing import Optional
-
 from pathlib import Path
-import json
-import yaml
+from typing import Generic, Optional, TypeVar
+from urllib.parse import urlparse
 
+import chromadb
 import requests
-import shutil
-
-
+import yaml
 from apps.webui.internal.db import Base, get_db
-
-from constants import ERROR_MESSAGES
-
+from chromadb import Settings
 from env import (
-    ENV,
-    VERSION,
-    SAFE_MODE,
-    GLOBAL_LOG_LEVEL,
-    SRC_LOG_LEVELS,
-    BASE_DIR,
-    DATA_DIR,
     BACKEND_DIR,
-    FRONTEND_BUILD_DIR,
-    WEBUI_NAME,
-    WEBUI_URL,
-    WEBUI_FAVICON_URL,
-    WEBUI_BUILD_HASH,
     CONFIG_DATA,
-    DATABASE_URL,
-    CHANGELOG,
+    DATA_DIR,
+    ENV,
+    FRONTEND_BUILD_DIR,
     WEBUI_AUTH,
-    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
-    WEBUI_AUTH_TRUSTED_NAME_HEADER,
-    WEBUI_SECRET_KEY,
-    WEBUI_SESSION_COOKIE_SAME_SITE,
-    WEBUI_SESSION_COOKIE_SECURE,
+    WEBUI_FAVICON_URL,
+    WEBUI_NAME,
     log,
 )
+from pydantic import BaseModel
+from sqlalchemy import JSON, Column, DateTime, Integer, func
 
 
 class EndpointFilter(logging.Filter):
@@ -72,8 +44,8 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
 def run_migrations():
     print("Running migrations")
     try:
-        from alembic.config import Config
         from alembic import command
+        from alembic.config import Config
 
         alembic_cfg = Config(BACKEND_DIR / "alembic.ini")
         command.upgrade(alembic_cfg, "head")

+ 6 - 12
backend/env.py

@@ -1,19 +1,13 @@
-from pathlib import Path
-import os
-import logging
-import sys
-import json
-
-
 import importlib.metadata
+import json
+import logging
+import os
 import pkgutil
-from urllib.parse import urlparse
-from datetime import datetime
-
+import sys
+from pathlib import Path
 
 import markdown
 from bs4 import BeautifulSoup
-
 from constants import ERROR_MESSAGES
 
 ####################################
@@ -26,7 +20,7 @@ BASE_DIR = BACKEND_DIR.parent  # the path containing the backend/
 print(BASE_DIR)
 
 try:
-    from dotenv import load_dotenv, find_dotenv
+    from dotenv import find_dotenv, load_dotenv
 
     load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
 except ImportError:

+ 92 - 98
backend/main.py

@@ -1,130 +1,124 @@
 import base64
-import uuid
-from contextlib import asynccontextmanager
-from authlib.integrations.starlette_client import OAuth
-from authlib.oidc.core import UserInfo
+import inspect
 import json
-import time
-import os
-import sys
 import logging
-import aiohttp
-import requests
 import mimetypes
+import os
 import shutil
-import inspect
+import sys
+import time
+import uuid
+from contextlib import asynccontextmanager
 from typing import Optional
 
-from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
-from fastapi.staticfiles import StaticFiles
-from fastapi.responses import JSONResponse
-from fastapi import HTTPException
-from fastapi.middleware.cors import CORSMiddleware
-from sqlalchemy import text
-from starlette.exceptions import HTTPException as StarletteHTTPException
-from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.middleware.sessions import SessionMiddleware
-from starlette.responses import StreamingResponse, Response, RedirectResponse
-
-
-from apps.socket.main import app as socket_app, get_event_emitter, get_event_call
+import aiohttp
+import requests
+from apps.audio.main import app as audio_app
+from apps.images.main import app as images_app
+from apps.ollama.main import app as ollama_app
 from apps.ollama.main import (
-    app as ollama_app,
-    get_all_models as get_ollama_models,
     generate_openai_chat_completion as generate_ollama_chat_completion,
 )
-from apps.openai.main import (
-    app as openai_app,
-    get_all_models as get_openai_models,
-    generate_chat_completion as generate_openai_chat_completion,
-)
-
-from apps.audio.main import app as audio_app
-from apps.images.main import app as images_app
+from apps.ollama.main import get_all_models as get_ollama_models
+from apps.openai.main import app as openai_app
+from apps.openai.main import generate_chat_completion as generate_openai_chat_completion
+from apps.openai.main import get_all_models as get_openai_models
 from apps.rag.main import app as rag_app
-from apps.webui.main import (
-    app as webui_app,
-    get_pipe_models,
-    generate_function_chat_completion,
-)
+from apps.rag.utils import get_rag_context, rag_template
+from apps.socket.main import app as socket_app
+from apps.socket.main import get_event_call, get_event_emitter
 from apps.webui.internal.db import Session
-
-
-from pydantic import BaseModel
-
+from apps.webui.main import app as webui_app
+from apps.webui.main import generate_function_chat_completion, get_pipe_models
 from apps.webui.models.auths import Auths
-from apps.webui.models.models import Models
 from apps.webui.models.functions import Functions
-from apps.webui.models.users import Users, UserModel
-
+from apps.webui.models.models import Models
+from apps.webui.models.users import UserModel, Users
 from apps.webui.utils import load_function_module_by_id
-
-from utils.utils import (
-    get_admin_user,
-    get_verified_user,
-    get_current_user,
-    get_http_authorization_cred,
-    get_password_hash,
-    create_token,
-    decode_token,
-)
-from utils.task import (
-    title_generation_template,
-    search_query_generation_template,
-    tools_function_calling_generation_template,
-    moa_response_generation_template,
-)
-
-from utils.tools import get_tools
-from utils.misc import (
-    get_last_user_message,
-    add_or_update_system_message,
-    prepend_to_first_user_message_content,
-    parse_duration,
-)
-
-from apps.rag.utils import get_rag_context, rag_template
-
+from authlib.integrations.starlette_client import OAuth
+from authlib.oidc.core import UserInfo
 from config import (
-    run_migrations,
-    WEBUI_NAME,
-    WEBUI_URL,
-    WEBUI_AUTH,
-    ENV,
-    VERSION,
-    CHANGELOG,
-    FRONTEND_BUILD_DIR,
     CACHE_DIR,
-    STATIC_DIR,
+    CORS_ALLOW_ORIGIN,
     DEFAULT_LOCALE,
-    ENABLE_OPENAI_API,
-    ENABLE_OLLAMA_API,
+    ENABLE_ADMIN_CHAT_ACCESS,
+    ENABLE_ADMIN_EXPORT,
     ENABLE_MODEL_FILTER,
+    ENABLE_OAUTH_SIGNUP,
+    ENABLE_OLLAMA_API,
+    ENABLE_OPENAI_API,
+    ENV,
+    FRONTEND_BUILD_DIR,
     MODEL_FILTER_LIST,
-    GLOBAL_LOG_LEVEL,
-    SRC_LOG_LEVELS,
-    WEBHOOK_URL,
-    ENABLE_ADMIN_EXPORT,
-    WEBUI_BUILD_HASH,
+    OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
+    OAUTH_PROVIDERS,
+    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
+    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
+    STATIC_DIR,
     TASK_MODEL,
     TASK_MODEL_EXTERNAL,
     TITLE_GENERATION_PROMPT_TEMPLATE,
-    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
-    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+    WEBHOOK_URL,
+    WEBUI_AUTH,
+    WEBUI_NAME,
+    AppConfig,
+    run_migrations,
+)
+from constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
+from env import (
+    CHANGELOG,
+    GLOBAL_LOG_LEVEL,
     SAFE_MODE,
-    OAUTH_PROVIDERS,
-    ENABLE_OAUTH_SIGNUP,
-    OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
+    SRC_LOG_LEVELS,
+    VERSION,
+    WEBUI_BUILD_HASH,
     WEBUI_SECRET_KEY,
     WEBUI_SESSION_COOKIE_SAME_SITE,
     WEBUI_SESSION_COOKIE_SECURE,
-    ENABLE_ADMIN_CHAT_ACCESS,
-    AppConfig,
-    CORS_ALLOW_ORIGIN,
+    WEBUI_URL,
+)
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    Form,
+    HTTPException,
+    Request,
+    UploadFile,
+    status,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi.staticfiles import StaticFiles
+from pydantic import BaseModel
+from sqlalchemy import text
+from starlette.exceptions import HTTPException as StarletteHTTPException
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.middleware.sessions import SessionMiddleware
+from starlette.responses import RedirectResponse, Response, StreamingResponse
+from utils.misc import (
+    add_or_update_system_message,
+    get_last_user_message,
+    parse_duration,
+    prepend_to_first_user_message_content,
+)
+from utils.task import (
+    moa_response_generation_template,
+    search_query_generation_template,
+    title_generation_template,
+    tools_function_calling_generation_template,
+)
+from utils.tools import get_tools
+from utils.utils import (
+    create_token,
+    decode_token,
+    get_admin_user,
+    get_current_user,
+    get_http_authorization_cred,
+    get_password_hash,
+    get_verified_user,
 )
-
-from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
 from utils.webhook import post_webhook
 
 if SAFE_MODE:

+ 1 - 16
backend/migrations/env.py

@@ -1,24 +1,9 @@
-import os
 from logging.config import fileConfig
 
-from sqlalchemy import engine_from_config
-from sqlalchemy import pool
-
 from alembic import context
-
 from apps.webui.models.auths import Auth
-from apps.webui.models.chats import Chat
-from apps.webui.models.documents import Document
-from apps.webui.models.memories import Memory
-from apps.webui.models.models import Model
-from apps.webui.models.prompts import Prompt
-from apps.webui.models.tags import Tag, ChatIdTag
-from apps.webui.models.tools import Tool
-from apps.webui.models.users import User
-from apps.webui.models.files import File
-from apps.webui.models.functions import Function
-
 from env import DATABASE_URL
+from sqlalchemy import engine_from_config, pool
 
 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.

+ 3 - 3
backend/migrations/versions/7e5b5dc7342b_init.py

@@ -1,16 +1,16 @@
 """init
 
 Revision ID: 7e5b5dc7342b
-Revises: 
+Revises:
 Create Date: 2024-06-24 13:15:33.808998
 
 """
 
 from typing import Sequence, Union
 
-from alembic import op
-import sqlalchemy as sa
 import apps.webui.internal.db
+import sqlalchemy as sa
+from alembic import op
 from migrations.util import get_existing_tables
 
 # revision identifiers, used by Alembic.

+ 1 - 3
backend/migrations/versions/ca81bd47c050_add_config_table.py

@@ -8,10 +8,8 @@ Create Date: 2024-08-25 15:26:35.241684
 
 from typing import Sequence, Union
 
-from alembic import op
 import sqlalchemy as sa
-import apps.webui.internal.db
-
+from alembic import op
 
 # revision identifiers, used by Alembic.
 revision: str = "ca81bd47c050"

+ 1 - 3
backend/test/apps/webui/routers/test_auths.py

@@ -1,5 +1,3 @@
-import pytest
-
 from test.util.abstract_integration_test import AbstractPostgresTest
 from test.util.mock_user import mock_webui_user
 
@@ -9,8 +7,8 @@ class TestAuths(AbstractPostgresTest):
 
     def setup_class(cls):
         super().setup_class()
-        from apps.webui.models.users import Users
         from apps.webui.models.auths import Auths
+        from apps.webui.models.users import Users
 
         cls.users = Users
         cls.auths = Auths

+ 1 - 3
backend/test/apps/webui/routers/test_chats.py

@@ -5,7 +5,6 @@ from test.util.mock_user import mock_webui_user
 
 
 class TestChats(AbstractPostgresTest):
-
     BASE_PATH = "/api/v1/chats"
 
     def setup_class(cls):
@@ -13,8 +12,7 @@ class TestChats(AbstractPostgresTest):
 
     def setup_method(self):
         super().setup_method()
-        from apps.webui.models.chats import ChatForm
-        from apps.webui.models.chats import Chats
+        from apps.webui.models.chats import ChatForm, Chats
 
         self.chats = Chats
         self.chats.insert_new_chat(

+ 0 - 1
backend/test/apps/webui/routers/test_documents.py

@@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
 
 
 class TestDocuments(AbstractPostgresTest):
-
     BASE_PATH = "/api/v1/documents"
 
     def setup_class(cls):

+ 0 - 1
backend/test/apps/webui/routers/test_models.py

@@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
 
 
 class TestModels(AbstractPostgresTest):
-
     BASE_PATH = "/api/v1/models"
 
     def setup_class(cls):

+ 0 - 1
backend/test/apps/webui/routers/test_prompts.py

@@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
 
 
 class TestPrompts(AbstractPostgresTest):
-
     BASE_PATH = "/api/v1/prompts"
 
     def test_prompts(self):

+ 0 - 1
backend/test/apps/webui/routers/test_users.py

@@ -21,7 +21,6 @@ def _assert_user(data, id, **kwargs):
 
 
 class TestUsers(AbstractPostgresTest):
-
     BASE_PATH = "/api/v1/users"
 
     def setup_class(cls):

+ 4 - 4
backend/utils/misc.py

@@ -1,10 +1,10 @@
-from pathlib import Path
 import hashlib
 import re
-from datetime import timedelta
-from typing import Optional, Callable
-import uuid
 import time
+import uuid
+from datetime import timedelta
+from pathlib import Path
+from typing import Callable, Optional
 
 from utils.task import prompt_template
 

+ 1 - 1
backend/utils/schemas.py

@@ -1,7 +1,7 @@
 from ast import literal_eval
+from typing import Any, Literal, Optional, Type
 
 from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional, Type, Literal
 
 
 def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:

+ 1 - 2
backend/utils/task.py

@@ -1,6 +1,5 @@
-import re
 import math
-
+import re
 from datetime import datetime
 from typing import Optional
 

+ 0 - 1
backend/utils/tools.py

@@ -5,7 +5,6 @@ from typing import Awaitable, Callable, get_type_hints
 from apps.webui.models.tools import Tools
 from apps.webui.models.users import UserModel
 from apps.webui.utils import load_toolkit_module_by_id
-
 from utils.schemas import json_schema_to_model
 
 log = logging.getLogger(__name__)

+ 8 - 9
backend/utils/utils.py

@@ -1,16 +1,15 @@
-from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-from fastapi import HTTPException, status, Depends, Request
+import logging
+import uuid
+from datetime import UTC, datetime, timedelta
+from typing import Optional, Union
 
+import jwt
 from apps.webui.models.users import Users
-
-from typing import Union, Optional
 from constants import ERROR_MESSAGES
-from passlib.context import CryptContext
-from datetime import datetime, timedelta, UTC
-import jwt
-import uuid
-import logging
 from env import WEBUI_SECRET_KEY
+from fastapi import Depends, HTTPException, Request, status
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from passlib.context import CryptContext
 
 logging.getLogger("passlib").setLevel(logging.ERROR)
 

+ 3 - 2
backend/utils/webhook.py

@@ -1,8 +1,9 @@
 import json
-import requests
 import logging
 
-from config import SRC_LOG_LEVELS, VERSION, WEBUI_FAVICON_URL, WEBUI_NAME
+import requests
+from config import WEBUI_FAVICON_URL, WEBUI_NAME
+from env import SRC_LOG_LEVELS, VERSION
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])