Parcourir la source

Merge pull request #7896 from open-webui/dev

0.5.0
Timothy Jaeryang Baek il y a 4 mois
Parent
commit
22132e155a
100 fichiers modifiés avec 3896 ajouts et 5929 suppressions
  1. 29 0
      CHANGELOG.md
  2. 0 703
      backend/open_webui/apps/audio/main.py
  3. 0 1494
      backend/open_webui/apps/retrieval/main.py
  4. 0 22
      backend/open_webui/apps/retrieval/vector/connector.py
  5. 0 506
      backend/open_webui/apps/webui/main.py
  6. 121 4
      backend/open_webui/config.py
  7. 1 3
      backend/open_webui/env.py
  8. 316 0
      backend/open_webui/functions.py
  9. 2 2
      backend/open_webui/internal/db.py
  10. 0 0
      backend/open_webui/internal/migrations/001_initial_schema.py
  11. 0 0
      backend/open_webui/internal/migrations/002_add_local_sharing.py
  12. 0 0
      backend/open_webui/internal/migrations/003_add_auth_api_key.py
  13. 0 0
      backend/open_webui/internal/migrations/004_add_archived.py
  14. 0 0
      backend/open_webui/internal/migrations/005_add_updated_at.py
  15. 0 0
      backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py
  16. 0 0
      backend/open_webui/internal/migrations/007_add_user_last_active_at.py
  17. 0 0
      backend/open_webui/internal/migrations/008_add_memory.py
  18. 0 0
      backend/open_webui/internal/migrations/009_add_models.py
  19. 0 0
      backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py
  20. 0 0
      backend/open_webui/internal/migrations/011_add_user_settings.py
  21. 0 0
      backend/open_webui/internal/migrations/012_add_tools.py
  22. 0 0
      backend/open_webui/internal/migrations/013_add_user_info.py
  23. 0 0
      backend/open_webui/internal/migrations/014_add_files.py
  24. 0 0
      backend/open_webui/internal/migrations/015_add_functions.py
  25. 0 0
      backend/open_webui/internal/migrations/016_add_valves_and_is_active.py
  26. 0 0
      backend/open_webui/internal/migrations/017_add_user_oauth_sub.py
  27. 0 0
      backend/open_webui/internal/migrations/018_add_function_is_global.py
  28. 0 0
      backend/open_webui/internal/wrappers.py
  29. 655 2288
      backend/open_webui/main.py
  30. 1 1
      backend/open_webui/migrations/env.py
  31. 1 1
      backend/open_webui/migrations/script.py.mako
  32. 48 0
      backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py
  33. 26 0
      backend/open_webui/migrations/versions/7826ab40b532_update_file_table.py
  34. 2 2
      backend/open_webui/migrations/versions/7e5b5dc7342b_init.py
  35. 3 3
      backend/open_webui/models/auths.py
  36. 132 0
      backend/open_webui/models/channels.py
  37. 87 2
      backend/open_webui/models/chats.py
  38. 2 2
      backend/open_webui/models/feedbacks.py
  39. 6 1
      backend/open_webui/models/files.py
  40. 2 2
      backend/open_webui/models/folders.py
  41. 2 2
      backend/open_webui/models/functions.py
  42. 9 2
      backend/open_webui/models/groups.py
  43. 3 3
      backend/open_webui/models/knowledge.py
  44. 1 1
      backend/open_webui/models/memories.py
  45. 141 0
      backend/open_webui/models/messages.py
  46. 2 2
      backend/open_webui/models/models.py
  47. 2 2
      backend/open_webui/models/prompts.py
  48. 1 1
      backend/open_webui/models/tags.py
  49. 2 2
      backend/open_webui/models/tools.py
  50. 43 8
      backend/open_webui/models/users.py
  51. 4 2
      backend/open_webui/retrieval/loaders/main.py
  52. 0 0
      backend/open_webui/retrieval/loaders/youtube.py
  53. 0 0
      backend/open_webui/retrieval/models/colbert.py
  54. 4 2
      backend/open_webui/retrieval/utils.py
  55. 22 0
      backend/open_webui/retrieval/vector/connector.py
  56. 1 1
      backend/open_webui/retrieval/vector/dbs/chroma.py
  57. 1 1
      backend/open_webui/retrieval/vector/dbs/milvus.py
  58. 1 1
      backend/open_webui/retrieval/vector/dbs/opensearch.py
  59. 2 2
      backend/open_webui/retrieval/vector/dbs/pgvector.py
  60. 1 1
      backend/open_webui/retrieval/vector/dbs/qdrant.py
  61. 0 0
      backend/open_webui/retrieval/vector/main.py
  62. 1 1
      backend/open_webui/retrieval/web/bing.py
  63. 1 1
      backend/open_webui/retrieval/web/brave.py
  64. 1 1
      backend/open_webui/retrieval/web/duckduckgo.py
  65. 1 1
      backend/open_webui/retrieval/web/google_pse.py
  66. 1 1
      backend/open_webui/retrieval/web/jina_search.py
  67. 48 0
      backend/open_webui/retrieval/web/kagi.py
  68. 0 0
      backend/open_webui/retrieval/web/main.py
  69. 1 1
      backend/open_webui/retrieval/web/mojeek.py
  70. 1 1
      backend/open_webui/retrieval/web/searchapi.py
  71. 1 1
      backend/open_webui/retrieval/web/searxng.py
  72. 1 1
      backend/open_webui/retrieval/web/serper.py
  73. 1 1
      backend/open_webui/retrieval/web/serply.py
  74. 1 1
      backend/open_webui/retrieval/web/serpstack.py
  75. 1 1
      backend/open_webui/retrieval/web/tavily.py
  76. 0 0
      backend/open_webui/retrieval/web/testdata/bing.json
  77. 0 0
      backend/open_webui/retrieval/web/testdata/brave.json
  78. 0 0
      backend/open_webui/retrieval/web/testdata/google_pse.json
  79. 0 0
      backend/open_webui/retrieval/web/testdata/searchapi.json
  80. 0 0
      backend/open_webui/retrieval/web/testdata/searxng.json
  81. 0 0
      backend/open_webui/retrieval/web/testdata/serper.json
  82. 0 0
      backend/open_webui/retrieval/web/testdata/serply.json
  83. 0 0
      backend/open_webui/retrieval/web/testdata/serpstack.json
  84. 3 3
      backend/open_webui/retrieval/web/utils.py
  85. 713 0
      backend/open_webui/routers/audio.py
  86. 41 6
      backend/open_webui/routers/auths.py
  87. 394 0
      backend/open_webui/routers/channels.py
  88. 4 5
      backend/open_webui/routers/chats.py
  89. 1 1
      backend/open_webui/routers/configs.py
  90. 3 3
      backend/open_webui/routers/evaluations.py
  91. 37 17
      backend/open_webui/routers/files.py
  92. 3 3
      backend/open_webui/routers/folders.py
  93. 3 3
      backend/open_webui/routers/functions.py
  94. 2 2
      backend/open_webui/routers/groups.py
  95. 170 175
      backend/open_webui/routers/images.py
  96. 99 9
      backend/open_webui/routers/knowledge.py
  97. 3 3
      backend/open_webui/routers/memories.py
  98. 2 2
      backend/open_webui/routers/models.py
  99. 353 344
      backend/open_webui/routers/ollama.py
  100. 329 272
      backend/open_webui/routers/openai.py

+ 29 - 0
CHANGELOG.md

@@ -5,6 +5,35 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 
+## [0.5.0] - 2024-12-25
+
+### Added
+
+- **💬 True Asynchronous Chat Support**: Create chats, navigate away, and return anytime with responses ready. Ideal for reasoning models and multi-agent workflows, enhancing multitasking like never before.
+- **🔔 Chat Completion Notifications**: Never miss a completed response. Receive instant in-UI notifications when a chat finishes in a non-active tab, keeping you updated while you work elsewhere.
+- **🌐 Notification Webhook Integration**: Get alerts via webhooks even when your tab is closed! Configure your webhook URL in Settings > Account and receive timely updates for long-running chats or external integration needs.
+- **📚 Channels (Beta)**: Explore Discord/Slack-style chat rooms designed for real-time collaboration between users and AIs. Build bots for channels and unlock asynchronous communication for proactive multi-agent workflows. Opt-in via Admin Settings > General. A Comprehensive Bot SDK tutorial (https://github.com/open-webui/bot) is incoming, so stay tuned!
+- **🖼️ Client-Side Image Compression**: Now compress images before upload (Settings > Interface), saving bandwidth and improving performance seamlessly.
+- **🛠️ OAuth Management for User Groups**: Enable group-level management via OAuth integration for enhanced control and scalability in collaborative environments.
+- **✅ Structured Output for Ollama**: Pass structured data output directly to Ollama, unlocking new possibilities for streamlined automation and precise data handling.
+- **📜 Offline Swagger Documentation**: Developer-friendly Swagger API docs are now available offline, ensuring full accessibility wherever you are.
+- **📸 Quick Screen Capture Button**: Effortlessly capture your screen with a single click from the message input menu.
+- **🌍 i18n Updates**: Improved and refined translations across several languages, including Ukrainian, German, Brazilian Portuguese, Catalan, and more, ensuring a seamless global user experience.
+
+### Fixed
+
+- **📋 Table Export to CSV**: Resolved issues with CSV export where headers were missing or errors occurred due to values with commas, ensuring smooth and reliable data handling.
+- **🔓 BYPASS_MODEL_ACCESS_CONTROL**: Fixed an issue where users could see models but couldn’t use them with 'BYPASS_MODEL_ACCESS_CONTROL=True', restoring proper functionality for environments leveraging this setting.
+
+### Changed
+
+- **💡 API Key Authentication Restriction**: Narrowed API key auth permissions to '/api/models' and '/api/chat/completions' for enhanced security and better API governance.
+- **⚙️ Backend Overhaul for Performance**: Major backend restructuring; a heads-up that some "Functions" using internal variables may face compatibility issues. Moving forward, websocket support is mandatory to ensure Open WebUI operates seamlessly.
+
+### Removed
+
+- **⚠️ Legacy Functionality Clean-Up**: Deprecated outdated backend systems that were non-essential or overlapped with newer implementations, allowing for a leaner, more efficient platform.
+
 ## [0.4.8] - 2024-12-07
 ## [0.4.8] - 2024-12-07
 
 
 ### Added
 ### Added

+ 0 - 703
backend/open_webui/apps/audio/main.py

@@ -1,703 +0,0 @@
-import hashlib
-import json
-import logging
-import os
-import uuid
-from functools import lru_cache
-from pathlib import Path
-from pydub import AudioSegment
-from pydub.silence import split_on_silence
-
-import aiohttp
-import aiofiles
-import requests
-from open_webui.config import (
-    AUDIO_STT_ENGINE,
-    AUDIO_STT_MODEL,
-    AUDIO_STT_OPENAI_API_BASE_URL,
-    AUDIO_STT_OPENAI_API_KEY,
-    AUDIO_TTS_API_KEY,
-    AUDIO_TTS_ENGINE,
-    AUDIO_TTS_MODEL,
-    AUDIO_TTS_OPENAI_API_BASE_URL,
-    AUDIO_TTS_OPENAI_API_KEY,
-    AUDIO_TTS_SPLIT_ON,
-    AUDIO_TTS_VOICE,
-    AUDIO_TTS_AZURE_SPEECH_REGION,
-    AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
-    CACHE_DIR,
-    CORS_ALLOW_ORIGIN,
-    WHISPER_MODEL,
-    WHISPER_MODEL_AUTO_UPDATE,
-    WHISPER_MODEL_DIR,
-    AppConfig,
-)
-
-from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import (
-    ENV,
-    SRC_LOG_LEVELS,
-    DEVICE_TYPE,
-    ENABLE_FORWARD_USER_INFO_HEADERS,
-)
-
-from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse
-from pydantic import BaseModel
-from open_webui.utils.utils import get_admin_user, get_verified_user
-
-# Constants
-MAX_FILE_SIZE_MB = 25
-MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
-
-
-log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS["AUDIO"])
-
-app = FastAPI(
-    docs_url="/docs" if ENV == "dev" else None,
-    openapi_url="/openapi.json" if ENV == "dev" else None,
-    redoc_url=None,
-)
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=CORS_ALLOW_ORIGIN,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-app.state.config = AppConfig()
-
-app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
-app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
-app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
-app.state.config.STT_MODEL = AUDIO_STT_MODEL
-
-app.state.config.WHISPER_MODEL = WHISPER_MODEL
-app.state.faster_whisper_model = None
-
-app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
-app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
-app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
-app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
-app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
-app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
-app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
-
-
-app.state.speech_synthesiser = None
-app.state.speech_speaker_embeddings_dataset = None
-
-app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
-app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
-
-# setting device type for whisper model
-whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
-log.info(f"whisper_device_type: {whisper_device_type}")
-
-SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
-SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
-
-
-def set_faster_whisper_model(model: str, auto_update: bool = False):
-    if model and app.state.config.STT_ENGINE == "":
-        from faster_whisper import WhisperModel
-
-        faster_whisper_kwargs = {
-            "model_size_or_path": model,
-            "device": whisper_device_type,
-            "compute_type": "int8",
-            "download_root": WHISPER_MODEL_DIR,
-            "local_files_only": not auto_update,
-        }
-
-        try:
-            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
-        except Exception:
-            log.warning(
-                "WhisperModel initialization failed, attempting download with local_files_only=False"
-            )
-            faster_whisper_kwargs["local_files_only"] = False
-            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
-
-    else:
-        app.state.faster_whisper_model = None
-
-
-class TTSConfigForm(BaseModel):
-    OPENAI_API_BASE_URL: str
-    OPENAI_API_KEY: str
-    API_KEY: str
-    ENGINE: str
-    MODEL: str
-    VOICE: str
-    SPLIT_ON: str
-    AZURE_SPEECH_REGION: str
-    AZURE_SPEECH_OUTPUT_FORMAT: str
-
-
-class STTConfigForm(BaseModel):
-    OPENAI_API_BASE_URL: str
-    OPENAI_API_KEY: str
-    ENGINE: str
-    MODEL: str
-    WHISPER_MODEL: str
-
-
-class AudioConfigUpdateForm(BaseModel):
-    tts: TTSConfigForm
-    stt: STTConfigForm
-
-
-from pydub import AudioSegment
-from pydub.utils import mediainfo
-
-
-def is_mp4_audio(file_path):
-    """Check if the given file is an MP4 audio file."""
-    if not os.path.isfile(file_path):
-        print(f"File not found: {file_path}")
-        return False
-
-    info = mediainfo(file_path)
-    if (
-        info.get("codec_name") == "aac"
-        and info.get("codec_type") == "audio"
-        and info.get("codec_tag_string") == "mp4a"
-    ):
-        return True
-    return False
-
-
-def convert_mp4_to_wav(file_path, output_path):
-    """Convert MP4 audio file to WAV format."""
-    audio = AudioSegment.from_file(file_path, format="mp4")
-    audio.export(output_path, format="wav")
-    print(f"Converted {file_path} to {output_path}")
-
-
-@app.get("/config")
-async def get_audio_config(user=Depends(get_admin_user)):
-    return {
-        "tts": {
-            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
-            "API_KEY": app.state.config.TTS_API_KEY,
-            "ENGINE": app.state.config.TTS_ENGINE,
-            "MODEL": app.state.config.TTS_MODEL,
-            "VOICE": app.state.config.TTS_VOICE,
-            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
-            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
-            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
-        },
-        "stt": {
-            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
-            "ENGINE": app.state.config.STT_ENGINE,
-            "MODEL": app.state.config.STT_MODEL,
-            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
-        },
-    }
-
-
-@app.post("/config/update")
-async def update_audio_config(
-    form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
-):
-    app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
-    app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
-    app.state.config.TTS_API_KEY = form_data.tts.API_KEY
-    app.state.config.TTS_ENGINE = form_data.tts.ENGINE
-    app.state.config.TTS_MODEL = form_data.tts.MODEL
-    app.state.config.TTS_VOICE = form_data.tts.VOICE
-    app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
-    app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
-    app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
-        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
-    )
-
-    app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
-    app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
-    app.state.config.STT_ENGINE = form_data.stt.ENGINE
-    app.state.config.STT_MODEL = form_data.stt.MODEL
-    app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
-    set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
-
-    return {
-        "tts": {
-            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
-            "API_KEY": app.state.config.TTS_API_KEY,
-            "ENGINE": app.state.config.TTS_ENGINE,
-            "MODEL": app.state.config.TTS_MODEL,
-            "VOICE": app.state.config.TTS_VOICE,
-            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
-            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
-            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
-        },
-        "stt": {
-            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
-            "ENGINE": app.state.config.STT_ENGINE,
-            "MODEL": app.state.config.STT_MODEL,
-            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
-        },
-    }
-
-
-def load_speech_pipeline():
-    from transformers import pipeline
-    from datasets import load_dataset
-
-    if app.state.speech_synthesiser is None:
-        app.state.speech_synthesiser = pipeline(
-            "text-to-speech", "microsoft/speecht5_tts"
-        )
-
-    if app.state.speech_speaker_embeddings_dataset is None:
-        app.state.speech_speaker_embeddings_dataset = load_dataset(
-            "Matthijs/cmu-arctic-xvectors", split="validation"
-        )
-
-
-@app.post("/speech")
-async def speech(request: Request, user=Depends(get_verified_user)):
-    body = await request.body()
-    name = hashlib.sha256(body).hexdigest()
-
-    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
-    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
-
-    # Check if the file already exists in the cache
-    if file_path.is_file():
-        return FileResponse(file_path)
-
-    if app.state.config.TTS_ENGINE == "openai":
-        headers = {}
-        headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
-        headers["Content-Type"] = "application/json"
-
-        if ENABLE_FORWARD_USER_INFO_HEADERS:
-            headers["X-OpenWebUI-User-Name"] = user.name
-            headers["X-OpenWebUI-User-Id"] = user.id
-            headers["X-OpenWebUI-User-Email"] = user.email
-            headers["X-OpenWebUI-User-Role"] = user.role
-
-        try:
-            body = body.decode("utf-8")
-            body = json.loads(body)
-            body["model"] = app.state.config.TTS_MODEL
-            body = json.dumps(body).encode("utf-8")
-        except Exception:
-            pass
-
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
-                    data=body,
-                    headers=headers,
-                ) as r:
-                    r.raise_for_status()
-                    async with aiofiles.open(file_path, "wb") as f:
-                        await f.write(await r.read())
-
-                    async with aiofiles.open(file_body_path, "w") as f:
-                        await f.write(json.dumps(json.loads(body.decode("utf-8"))))
-
-            return FileResponse(file_path)
-
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            try:
-                if r.status != 200:
-                    res = await r.json()
-                    if "error" in res:
-                        error_detail = f"External: {res['error']['message']}"
-            except Exception:
-                error_detail = f"External: {e}"
-
-            raise HTTPException(
-                status_code=getattr(r, "status", 500),
-                detail=error_detail,
-            )
-
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
-        try:
-            payload = json.loads(body.decode("utf-8"))
-        except Exception as e:
-            log.exception(e)
-            raise HTTPException(status_code=400, detail="Invalid JSON payload")
-
-        voice_id = payload.get("voice", "")
-        if voice_id not in get_available_voices():
-            raise HTTPException(
-                status_code=400,
-                detail="Invalid voice id",
-            )
-
-        url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
-        headers = {
-            "Accept": "audio/mpeg",
-            "Content-Type": "application/json",
-            "xi-api-key": app.state.config.TTS_API_KEY,
-        }
-        data = {
-            "text": payload["input"],
-            "model_id": app.state.config.TTS_MODEL,
-            "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
-        }
-
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.post(url, json=data, headers=headers) as r:
-                    r.raise_for_status()
-                    async with aiofiles.open(file_path, "wb") as f:
-                        await f.write(await r.read())
-
-                    async with aiofiles.open(file_body_path, "w") as f:
-                        await f.write(json.dumps(json.loads(body.decode("utf-8"))))
-
-            return FileResponse(file_path)
-
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            try:
-                if r.status != 200:
-                    res = await r.json()
-                    if "error" in res:
-                        error_detail = f"External: {res['error']['message']}"
-            except Exception:
-                error_detail = f"External: {e}"
-
-            raise HTTPException(
-                status_code=getattr(r, "status", 500),
-                detail=error_detail,
-            )
-
-    elif app.state.config.TTS_ENGINE == "azure":
-        try:
-            payload = json.loads(body.decode("utf-8"))
-        except Exception as e:
-            log.exception(e)
-            raise HTTPException(status_code=400, detail="Invalid JSON payload")
-
-        region = app.state.config.TTS_AZURE_SPEECH_REGION
-        language = app.state.config.TTS_VOICE
-        locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
-        output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
-        url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
-
-        headers = {
-            "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
-            "Content-Type": "application/ssml+xml",
-            "X-Microsoft-OutputFormat": output_format,
-        }
-
-        data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
-                <voice name="{language}">{payload["input"]}</voice>
-            </speak>"""
-
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.post(url, headers=headers, data=data) as response:
-                    if response.status == 200:
-                        async with aiofiles.open(file_path, "wb") as f:
-                            await f.write(await response.read())
-                        return FileResponse(file_path)
-                    else:
-                        error_msg = f"Error synthesizing speech - {response.reason}"
-                        log.error(error_msg)
-                        raise HTTPException(status_code=500, detail=error_msg)
-        except Exception as e:
-            log.exception(e)
-            raise HTTPException(status_code=500, detail=str(e))
-    elif app.state.config.TTS_ENGINE == "transformers":
-        payload = None
-        try:
-            payload = json.loads(body.decode("utf-8"))
-        except Exception as e:
-            log.exception(e)
-            raise HTTPException(status_code=400, detail="Invalid JSON payload")
-
-        import torch
-        import soundfile as sf
-
-        load_speech_pipeline()
-
-        embeddings_dataset = app.state.speech_speaker_embeddings_dataset
-
-        speaker_index = 6799
-        try:
-            speaker_index = embeddings_dataset["filename"].index(
-                app.state.config.TTS_MODEL
-            )
-        except Exception:
-            pass
-
-        speaker_embedding = torch.tensor(
-            embeddings_dataset[speaker_index]["xvector"]
-        ).unsqueeze(0)
-
-        speech = app.state.speech_synthesiser(
-            payload["input"],
-            forward_params={"speaker_embeddings": speaker_embedding},
-        )
-
-        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
-        with open(file_body_path, "w") as f:
-            json.dump(json.loads(body.decode("utf-8")), f)
-
-        return FileResponse(file_path)
-
-
-def transcribe(file_path):
-    print("transcribe", file_path)
-    filename = os.path.basename(file_path)
-    file_dir = os.path.dirname(file_path)
-    id = filename.split(".")[0]
-
-    if app.state.config.STT_ENGINE == "":
-        if app.state.faster_whisper_model is None:
-            set_faster_whisper_model(app.state.config.WHISPER_MODEL)
-
-        model = app.state.faster_whisper_model
-        segments, info = model.transcribe(file_path, beam_size=5)
-        log.info(
-            "Detected language '%s' with probability %f"
-            % (info.language, info.language_probability)
-        )
-
-        transcript = "".join([segment.text for segment in list(segments)])
-        data = {"text": transcript.strip()}
-
-        # save the transcript to a json file
-        transcript_file = f"{file_dir}/{id}.json"
-        with open(transcript_file, "w") as f:
-            json.dump(data, f)
-
-        log.debug(data)
-        return data
-    elif app.state.config.STT_ENGINE == "openai":
-        if is_mp4_audio(file_path):
-            print("is_mp4_audio")
-            os.rename(file_path, file_path.replace(".wav", ".mp4"))
-            # Convert MP4 audio file to WAV format
-            convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
-
-        headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
-
-        files = {"file": (filename, open(file_path, "rb"))}
-        data = {"model": app.state.config.STT_MODEL}
-
-        log.debug(files, data)
-
-        r = None
-        try:
-            r = requests.post(
-                url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
-                headers=headers,
-                files=files,
-                data=data,
-            )
-
-            r.raise_for_status()
-
-            data = r.json()
-
-            # save the transcript to a json file
-            transcript_file = f"{file_dir}/{id}.json"
-            with open(transcript_file, "w") as f:
-                json.dump(data, f)
-
-            print(data)
-            return data
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "error" in res:
-                        error_detail = f"External: {res['error']['message']}"
-                except Exception:
-                    error_detail = f"External: {e}"
-
-            raise Exception(error_detail)
-
-
-@app.post("/transcriptions")
-def transcription(
-    file: UploadFile = File(...),
-    user=Depends(get_verified_user),
-):
-    log.info(f"file.content_type: {file.content_type}")
-
-    if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
-        )
-
-    try:
-        ext = file.filename.split(".")[-1]
-        id = uuid.uuid4()
-
-        filename = f"{id}.{ext}"
-        contents = file.file.read()
-
-        file_dir = f"{CACHE_DIR}/audio/transcriptions"
-        os.makedirs(file_dir, exist_ok=True)
-        file_path = f"{file_dir}/{filename}"
-
-        with open(file_path, "wb") as f:
-            f.write(contents)
-
-        try:
-            if os.path.getsize(file_path) > MAX_FILE_SIZE:  # file is bigger than 25MB
-                log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
-                audio = AudioSegment.from_file(file_path)
-                audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
-                compressed_path = f"{file_dir}/{id}_compressed.opus"
-                audio.export(compressed_path, format="opus", bitrate="32k")
-                log.debug(f"Compressed audio to {compressed_path}")
-                file_path = compressed_path
-
-                if (
-                    os.path.getsize(file_path) > MAX_FILE_SIZE
-                ):  # Still larger than 25MB after compression
-                    log.debug(
-                        f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
-                    )
-                    raise HTTPException(
-                        status_code=status.HTTP_400_BAD_REQUEST,
-                        detail=ERROR_MESSAGES.FILE_TOO_LARGE(
-                            size=f"{MAX_FILE_SIZE_MB}MB"
-                        ),
-                    )
-
-                data = transcribe(file_path)
-            else:
-                data = transcribe(file_path)
-
-            file_path = file_path.split("/")[-1]
-            return {**data, "filename": file_path}
-        except Exception as e:
-            log.exception(e)
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
-            )
-
-    except Exception as e:
-        log.exception(e)
-
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-def get_available_models() -> list[dict]:
-    if app.state.config.TTS_ENGINE == "openai":
-        return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
-        headers = {
-            "xi-api-key": app.state.config.TTS_API_KEY,
-            "Content-Type": "application/json",
-        }
-
-        try:
-            response = requests.get(
-                "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
-            )
-            response.raise_for_status()
-            models = response.json()
-            return [
-                {"name": model["name"], "id": model["model_id"]} for model in models
-            ]
-        except requests.RequestException as e:
-            log.error(f"Error fetching voices: {str(e)}")
-    return []
-
-
-@app.get("/models")
-async def get_models(user=Depends(get_verified_user)):
-    return {"models": get_available_models()}
-
-
-def get_available_voices() -> dict:
-    """Returns {voice_id: voice_name} dict"""
-    ret = {}
-    if app.state.config.TTS_ENGINE == "openai":
-        ret = {
-            "alloy": "alloy",
-            "echo": "echo",
-            "fable": "fable",
-            "onyx": "onyx",
-            "nova": "nova",
-            "shimmer": "shimmer",
-        }
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
-        try:
-            ret = get_elevenlabs_voices()
-        except Exception:
-            # Avoided @lru_cache with exception
-            pass
-    elif app.state.config.TTS_ENGINE == "azure":
-        try:
-            region = app.state.config.TTS_AZURE_SPEECH_REGION
-            url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
-            headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
-
-            response = requests.get(url, headers=headers)
-            response.raise_for_status()
-            voices = response.json()
-            for voice in voices:
-                ret[voice["ShortName"]] = (
-                    f"{voice['DisplayName']} ({voice['ShortName']})"
-                )
-        except requests.RequestException as e:
-            log.error(f"Error fetching voices: {str(e)}")
-
-    return ret
-
-
-@lru_cache
-def get_elevenlabs_voices() -> dict:
-    """
-    Note, set the following in your .env file to use Elevenlabs:
-    AUDIO_TTS_ENGINE=elevenlabs
-    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
-    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
-    AUDIO_TTS_MODEL=eleven_multilingual_v2
-    """
-    headers = {
-        "xi-api-key": app.state.config.TTS_API_KEY,
-        "Content-Type": "application/json",
-    }
-    try:
-        # TODO: Add retries
-        response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
-        response.raise_for_status()
-        voices_data = response.json()
-
-        voices = {}
-        for voice in voices_data.get("voices", []):
-            voices[voice["voice_id"]] = voice["name"]
-    except requests.RequestException as e:
-        # Avoid @lru_cache with exception
-        log.error(f"Error fetching voices: {str(e)}")
-        raise RuntimeError(f"Error fetching voices: {str(e)}")
-
-    return voices
-
-
-@app.get("/voices")
-async def get_voices(user=Depends(get_verified_user)):
-    return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}

+ 0 - 1494
backend/open_webui/apps/retrieval/main.py

@@ -1,1494 +0,0 @@
-# TODO: Merge this with the webui_app and make it a single app
-
-import json
-import logging
-import mimetypes
-import os
-import shutil
-
-import uuid
-from datetime import datetime
-from pathlib import Path
-from typing import Iterator, Optional, Sequence, Union
-
-from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-import tiktoken
-
-
-from open_webui.storage.provider import Storage
-from open_webui.apps.webui.models.knowledge import Knowledges
-from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
-
-# Document loaders
-from open_webui.apps.retrieval.loaders.main import Loader
-from open_webui.apps.retrieval.loaders.youtube import YoutubeLoader
-
-# Web search engines
-from open_webui.apps.retrieval.web.main import SearchResult
-from open_webui.apps.retrieval.web.utils import get_web_loader
-from open_webui.apps.retrieval.web.brave import search_brave
-from open_webui.apps.retrieval.web.mojeek import search_mojeek
-from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo
-from open_webui.apps.retrieval.web.google_pse import search_google_pse
-from open_webui.apps.retrieval.web.jina_search import search_jina
-from open_webui.apps.retrieval.web.searchapi import search_searchapi
-from open_webui.apps.retrieval.web.searxng import search_searxng
-from open_webui.apps.retrieval.web.serper import search_serper
-from open_webui.apps.retrieval.web.serply import search_serply
-from open_webui.apps.retrieval.web.serpstack import search_serpstack
-from open_webui.apps.retrieval.web.tavily import search_tavily
-from open_webui.apps.retrieval.web.bing import search_bing
-
-
-from open_webui.apps.retrieval.utils import (
-    get_embedding_function,
-    get_model_path,
-    query_collection,
-    query_collection_with_hybrid_search,
-    query_doc,
-    query_doc_with_hybrid_search,
-)
-
-from open_webui.apps.webui.models.files import Files
-from open_webui.config import (
-    BRAVE_SEARCH_API_KEY,
-    MOJEEK_SEARCH_API_KEY,
-    TIKTOKEN_ENCODING_NAME,
-    RAG_TEXT_SPLITTER,
-    CHUNK_OVERLAP,
-    CHUNK_SIZE,
-    CONTENT_EXTRACTION_ENGINE,
-    CORS_ALLOW_ORIGIN,
-    ENABLE_RAG_HYBRID_SEARCH,
-    ENABLE_RAG_LOCAL_WEB_FETCH,
-    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
-    ENABLE_RAG_WEB_SEARCH,
-    ENV,
-    GOOGLE_PSE_API_KEY,
-    GOOGLE_PSE_ENGINE_ID,
-    PDF_EXTRACT_IMAGES,
-    RAG_EMBEDDING_ENGINE,
-    RAG_EMBEDDING_MODEL,
-    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
-    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    RAG_EMBEDDING_BATCH_SIZE,
-    RAG_FILE_MAX_COUNT,
-    RAG_FILE_MAX_SIZE,
-    RAG_OPENAI_API_BASE_URL,
-    RAG_OPENAI_API_KEY,
-    RAG_OLLAMA_BASE_URL,
-    RAG_OLLAMA_API_KEY,
-    RAG_RELEVANCE_THRESHOLD,
-    RAG_RERANKING_MODEL,
-    RAG_RERANKING_MODEL_AUTO_UPDATE,
-    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-    DEFAULT_RAG_TEMPLATE,
-    RAG_TEMPLATE,
-    RAG_TOP_K,
-    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-    RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-    RAG_WEB_SEARCH_ENGINE,
-    RAG_WEB_SEARCH_RESULT_COUNT,
-    JINA_API_KEY,
-    SEARCHAPI_API_KEY,
-    SEARCHAPI_ENGINE,
-    SEARXNG_QUERY_URL,
-    SERPER_API_KEY,
-    SERPLY_API_KEY,
-    SERPSTACK_API_KEY,
-    SERPSTACK_HTTPS,
-    TAVILY_API_KEY,
-    BING_SEARCH_V7_ENDPOINT,
-    BING_SEARCH_V7_SUBSCRIPTION_KEY,
-    TIKA_SERVER_URL,
-    UPLOAD_DIR,
-    YOUTUBE_LOADER_LANGUAGE,
-    YOUTUBE_LOADER_PROXY_URL,
-    DEFAULT_LOCALE,
-    AppConfig,
-)
-from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import (
-    SRC_LOG_LEVELS,
-    DEVICE_TYPE,
-    DOCKER,
-)
-from open_webui.utils.misc import (
-    calculate_sha256,
-    calculate_sha256_string,
-    extract_folders_after_data_docs,
-    sanitize_filename,
-)
-from open_webui.utils.utils import get_admin_user, get_verified_user
-
-from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
-from langchain_core.documents import Document
-
-
-log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS["RAG"])
-
-app = FastAPI(
-    docs_url="/docs" if ENV == "dev" else None,
-    openapi_url="/openapi.json" if ENV == "dev" else None,
-    redoc_url=None,
-)
-
-app.state.config = AppConfig()
-
-app.state.config.TOP_K = RAG_TOP_K
-app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
-app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
-app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
-
-app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
-app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
-    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
-)
-
-app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
-app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
-
-app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
-app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
-
-app.state.config.CHUNK_SIZE = CHUNK_SIZE
-app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
-
-app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
-app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
-app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
-app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
-app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
-
-app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
-app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
-
-app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
-app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
-
-app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
-
-app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
-app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
-app.state.YOUTUBE_LOADER_TRANSLATION = None
-
-
-app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
-app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
-app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
-
-app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
-app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
-app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
-app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
-app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
-app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
-app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
-app.state.config.SERPER_API_KEY = SERPER_API_KEY
-app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
-app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
-app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
-app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
-app.state.config.JINA_API_KEY = JINA_API_KEY
-app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
-app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
-
-app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
-app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
-
-
-def update_embedding_model(
-    embedding_model: str,
-    auto_update: bool = False,
-):
-    if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
-        from sentence_transformers import SentenceTransformer
-
-        try:
-            app.state.sentence_transformer_ef = SentenceTransformer(
-                get_model_path(embedding_model, auto_update),
-                device=DEVICE_TYPE,
-                trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-            )
-        except Exception as e:
-            log.debug(f"Error loading SentenceTransformer: {e}")
-            app.state.sentence_transformer_ef = None
-    else:
-        app.state.sentence_transformer_ef = None
-
-
-def update_reranking_model(
-    reranking_model: str,
-    auto_update: bool = False,
-):
-    if reranking_model:
-        if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
-            try:
-                from open_webui.apps.retrieval.models.colbert import ColBERT
-
-                app.state.sentence_transformer_rf = ColBERT(
-                    get_model_path(reranking_model, auto_update),
-                    env="docker" if DOCKER else None,
-                )
-            except Exception as e:
-                log.error(f"ColBERT: {e}")
-                app.state.sentence_transformer_rf = None
-                app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
-        else:
-            import sentence_transformers
-
-            try:
-                app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-                    get_model_path(reranking_model, auto_update),
-                    device=DEVICE_TYPE,
-                    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-                )
-            except:
-                log.error("CrossEncoder error")
-                app.state.sentence_transformer_rf = None
-                app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
-    else:
-        app.state.sentence_transformer_rf = None
-
-
-update_embedding_model(
-    app.state.config.RAG_EMBEDDING_MODEL,
-    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
-)
-
-update_reranking_model(
-    app.state.config.RAG_RERANKING_MODEL,
-    RAG_RERANKING_MODEL_AUTO_UPDATE,
-)
-
-
-app.state.EMBEDDING_FUNCTION = get_embedding_function(
-    app.state.config.RAG_EMBEDDING_ENGINE,
-    app.state.config.RAG_EMBEDDING_MODEL,
-    app.state.sentence_transformer_ef,
-    (
-        app.state.config.OPENAI_API_BASE_URL
-        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-        else app.state.config.OLLAMA_BASE_URL
-    ),
-    (
-        app.state.config.OPENAI_API_KEY
-        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-        else app.state.config.OLLAMA_API_KEY
-    ),
-    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-)
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=CORS_ALLOW_ORIGIN,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-
-class CollectionNameForm(BaseModel):
-    collection_name: Optional[str] = None
-
-
-class ProcessUrlForm(CollectionNameForm):
-    url: str
-
-
-class SearchForm(CollectionNameForm):
-    query: str
-
-
-@app.get("/")
-async def get_status():
-    return {
-        "status": True,
-        "chunk_size": app.state.config.CHUNK_SIZE,
-        "chunk_overlap": app.state.config.CHUNK_OVERLAP,
-        "template": app.state.config.RAG_TEMPLATE,
-        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
-        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
-        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
-        "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-    }
-
-
-@app.get("/embedding")
-async def get_embedding_config(user=Depends(get_admin_user)):
-    return {
-        "status": True,
-        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
-        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
-        "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-        "openai_config": {
-            "url": app.state.config.OPENAI_API_BASE_URL,
-            "key": app.state.config.OPENAI_API_KEY,
-        },
-        "ollama_config": {
-            "url": app.state.config.OLLAMA_BASE_URL,
-            "key": app.state.config.OLLAMA_API_KEY,
-        },
-    }
-
-
-@app.get("/reranking")
-async def get_reraanking_config(user=Depends(get_admin_user)):
-    return {
-        "status": True,
-        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
-    }
-
-
-class OpenAIConfigForm(BaseModel):
-    url: str
-    key: str
-
-
-class OllamaConfigForm(BaseModel):
-    url: str
-    key: str
-
-
-class EmbeddingModelUpdateForm(BaseModel):
-    openai_config: Optional[OpenAIConfigForm] = None
-    ollama_config: Optional[OllamaConfigForm] = None
-    embedding_engine: str
-    embedding_model: str
-    embedding_batch_size: Optional[int] = 1
-
-
-@app.post("/embedding/update")
-async def update_embedding_config(
-    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
-):
-    log.info(
-        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
-    )
-    try:
-        app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
-        app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
-
-        if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
-            if form_data.openai_config is not None:
-                app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
-                app.state.config.OPENAI_API_KEY = form_data.openai_config.key
-
-            if form_data.ollama_config is not None:
-                app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
-                app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
-
-            app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
-
-        update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
-
-        app.state.EMBEDDING_FUNCTION = get_embedding_function(
-            app.state.config.RAG_EMBEDDING_ENGINE,
-            app.state.config.RAG_EMBEDDING_MODEL,
-            app.state.sentence_transformer_ef,
-            (
-                app.state.config.OPENAI_API_BASE_URL
-                if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-                else app.state.config.OLLAMA_BASE_URL
-            ),
-            (
-                app.state.config.OPENAI_API_KEY
-                if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-                else app.state.config.OLLAMA_API_KEY
-            ),
-            app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-        )
-
-        return {
-            "status": True,
-            "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
-            "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
-            "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-            "openai_config": {
-                "url": app.state.config.OPENAI_API_BASE_URL,
-                "key": app.state.config.OPENAI_API_KEY,
-            },
-            "ollama_config": {
-                "url": app.state.config.OLLAMA_BASE_URL,
-                "key": app.state.config.OLLAMA_API_KEY,
-            },
-        }
-    except Exception as e:
-        log.exception(f"Problem updating embedding model: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-class RerankingModelUpdateForm(BaseModel):
-    reranking_model: str
-
-
-@app.post("/reranking/update")
-async def update_reranking_config(
-    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
-):
-    log.info(
-        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
-    )
-    try:
-        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
-
-        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)
-
-        return {
-            "status": True,
-            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
-        }
-    except Exception as e:
-        log.exception(f"Problem updating reranking model: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-@app.get("/config")
-async def get_rag_config(user=Depends(get_admin_user)):
-    return {
-        "status": True,
-        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
-        "content_extraction": {
-            "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
-            "tika_server_url": app.state.config.TIKA_SERVER_URL,
-        },
-        "chunk": {
-            "text_splitter": app.state.config.TEXT_SPLITTER,
-            "chunk_size": app.state.config.CHUNK_SIZE,
-            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
-        },
-        "file": {
-            "max_size": app.state.config.FILE_MAX_SIZE,
-            "max_count": app.state.config.FILE_MAX_COUNT,
-        },
-        "youtube": {
-            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
-            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
-            "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL,
-        },
-        "web": {
-            "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
-            "search": {
-                "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
-                "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
-                "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
-                "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
-                "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
-                "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
-                "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY,
-                "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
-                "serpstack_https": app.state.config.SERPSTACK_HTTPS,
-                "serper_api_key": app.state.config.SERPER_API_KEY,
-                "serply_api_key": app.state.config.SERPLY_API_KEY,
-                "tavily_api_key": app.state.config.TAVILY_API_KEY,
-                "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
-                "seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE,
-                "jina_api_key": app.state.config.JINA_API_KEY,
-                "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
-                "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
-                "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-            },
-        },
-    }
-
-
-class FileConfig(BaseModel):
-    max_size: Optional[int] = None
-    max_count: Optional[int] = None
-
-
-class ContentExtractionConfig(BaseModel):
-    engine: str = ""
-    tika_server_url: Optional[str] = None
-
-
-class ChunkParamUpdateForm(BaseModel):
-    text_splitter: Optional[str] = None
-    chunk_size: int
-    chunk_overlap: int
-
-
-class YoutubeLoaderConfig(BaseModel):
-    language: list[str]
-    translation: Optional[str] = None
-    proxy_url: str = ""
-
-
-class WebSearchConfig(BaseModel):
-    enabled: bool
-    engine: Optional[str] = None
-    searxng_query_url: Optional[str] = None
-    google_pse_api_key: Optional[str] = None
-    google_pse_engine_id: Optional[str] = None
-    brave_search_api_key: Optional[str] = None
-    mojeek_search_api_key: Optional[str] = None
-    serpstack_api_key: Optional[str] = None
-    serpstack_https: Optional[bool] = None
-    serper_api_key: Optional[str] = None
-    serply_api_key: Optional[str] = None
-    tavily_api_key: Optional[str] = None
-    searchapi_api_key: Optional[str] = None
-    searchapi_engine: Optional[str] = None
-    jina_api_key: Optional[str] = None
-    bing_search_v7_endpoint: Optional[str] = None
-    bing_search_v7_subscription_key: Optional[str] = None
-    result_count: Optional[int] = None
-    concurrent_requests: Optional[int] = None
-
-
-class WebConfig(BaseModel):
-    search: WebSearchConfig
-    web_loader_ssl_verification: Optional[bool] = None
-
-
-class ConfigUpdateForm(BaseModel):
-    pdf_extract_images: Optional[bool] = None
-    file: Optional[FileConfig] = None
-    content_extraction: Optional[ContentExtractionConfig] = None
-    chunk: Optional[ChunkParamUpdateForm] = None
-    youtube: Optional[YoutubeLoaderConfig] = None
-    web: Optional[WebConfig] = None
-
-
-@app.post("/config/update")
-async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
-    app.state.config.PDF_EXTRACT_IMAGES = (
-        form_data.pdf_extract_images
-        if form_data.pdf_extract_images is not None
-        else app.state.config.PDF_EXTRACT_IMAGES
-    )
-
-    if form_data.file is not None:
-        app.state.config.FILE_MAX_SIZE = form_data.file.max_size
-        app.state.config.FILE_MAX_COUNT = form_data.file.max_count
-
-    if form_data.content_extraction is not None:
-        log.info(f"Updating text settings: {form_data.content_extraction}")
-        app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine
-        app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
-
-    if form_data.chunk is not None:
-        app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
-        app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
-        app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
-
-    if form_data.youtube is not None:
-        app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
-        app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url
-        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
-
-    if form_data.web is not None:
-        app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
-            # Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
-            form_data.web.web_loader_ssl_verification
-        )
-
-        app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
-        app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
-        app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
-        app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
-        app.state.config.GOOGLE_PSE_ENGINE_ID = (
-            form_data.web.search.google_pse_engine_id
-        )
-        app.state.config.BRAVE_SEARCH_API_KEY = (
-            form_data.web.search.brave_search_api_key
-        )
-        app.state.config.MOJEEK_SEARCH_API_KEY = (
-            form_data.web.search.mojeek_search_api_key
-        )
-        app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
-        app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
-        app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
-        app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
-        app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
-        app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
-        app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
-
-        app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
-        app.state.config.BING_SEARCH_V7_ENDPOINT = (
-            form_data.web.search.bing_search_v7_endpoint
-        )
-        app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = (
-            form_data.web.search.bing_search_v7_subscription_key
-        )
-
-        app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
-        app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
-            form_data.web.search.concurrent_requests
-        )
-
-    return {
-        "status": True,
-        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
-        "file": {
-            "max_size": app.state.config.FILE_MAX_SIZE,
-            "max_count": app.state.config.FILE_MAX_COUNT,
-        },
-        "content_extraction": {
-            "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
-            "tika_server_url": app.state.config.TIKA_SERVER_URL,
-        },
-        "chunk": {
-            "text_splitter": app.state.config.TEXT_SPLITTER,
-            "chunk_size": app.state.config.CHUNK_SIZE,
-            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
-        },
-        "youtube": {
-            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
-            "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL,
-            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
-        },
-        "web": {
-            "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
-            "search": {
-                "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
-                "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
-                "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
-                "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
-                "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
-                "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
-                "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY,
-                "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
-                "serpstack_https": app.state.config.SERPSTACK_HTTPS,
-                "serper_api_key": app.state.config.SERPER_API_KEY,
-                "serply_api_key": app.state.config.SERPLY_API_KEY,
-                "serachapi_api_key": app.state.config.SEARCHAPI_API_KEY,
-                "searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
-                "tavily_api_key": app.state.config.TAVILY_API_KEY,
-                "jina_api_key": app.state.config.JINA_API_KEY,
-                "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
-                "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
-                "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-            },
-        },
-    }
-
-
-@app.get("/template")
-async def get_rag_template(user=Depends(get_verified_user)):
-    return {
-        "status": True,
-        "template": app.state.config.RAG_TEMPLATE,
-    }
-
-
-@app.get("/query/settings")
-async def get_query_settings(user=Depends(get_admin_user)):
-    return {
-        "status": True,
-        "template": app.state.config.RAG_TEMPLATE,
-        "k": app.state.config.TOP_K,
-        "r": app.state.config.RELEVANCE_THRESHOLD,
-        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-    }
-
-
-class QuerySettingsForm(BaseModel):
-    k: Optional[int] = None
-    r: Optional[float] = None
-    template: Optional[str] = None
-    hybrid: Optional[bool] = None
-
-
-@app.post("/query/settings/update")
-async def update_query_settings(
-    form_data: QuerySettingsForm, user=Depends(get_admin_user)
-):
-    app.state.config.RAG_TEMPLATE = form_data.template
-    app.state.config.TOP_K = form_data.k if form_data.k else 4
-    app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
-
-    app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
-        form_data.hybrid if form_data.hybrid else False
-    )
-
-    return {
-        "status": True,
-        "template": app.state.config.RAG_TEMPLATE,
-        "k": app.state.config.TOP_K,
-        "r": app.state.config.RELEVANCE_THRESHOLD,
-        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-    }
-
-
-####################################
-#
-# Document process and retrieval
-#
-####################################
-
-
-def _get_docs_info(docs: list[Document]) -> str:
-    docs_info = set()
-
-    # Trying to select relevant metadata identifying the document.
-    for doc in docs:
-        metadata = getattr(doc, "metadata", {})
-        doc_name = metadata.get("name", "")
-        if not doc_name:
-            doc_name = metadata.get("title", "")
-        if not doc_name:
-            doc_name = metadata.get("source", "")
-        if doc_name:
-            docs_info.add(doc_name)
-
-    return ", ".join(docs_info)
-
-
-def save_docs_to_vector_db(
-    docs,
-    collection_name,
-    metadata: Optional[dict] = None,
-    overwrite: bool = False,
-    split: bool = True,
-    add: bool = False,
-) -> bool:
-    log.info(
-        f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
-    )
-
-    # Check if entries with the same hash (metadata.hash) already exist
-    if metadata and "hash" in metadata:
-        result = VECTOR_DB_CLIENT.query(
-            collection_name=collection_name,
-            filter={"hash": metadata["hash"]},
-        )
-
-        if result is not None:
-            existing_doc_ids = result.ids[0]
-            if existing_doc_ids:
-                log.info(f"Document with hash {metadata['hash']} already exists")
-                raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
-
-    if split:
-        if app.state.config.TEXT_SPLITTER in ["", "character"]:
-            text_splitter = RecursiveCharacterTextSplitter(
-                chunk_size=app.state.config.CHUNK_SIZE,
-                chunk_overlap=app.state.config.CHUNK_OVERLAP,
-                add_start_index=True,
-            )
-        elif app.state.config.TEXT_SPLITTER == "token":
-            log.info(
-                f"Using token text splitter: {app.state.config.TIKTOKEN_ENCODING_NAME}"
-            )
-
-            tiktoken.get_encoding(str(app.state.config.TIKTOKEN_ENCODING_NAME))
-            text_splitter = TokenTextSplitter(
-                encoding_name=str(app.state.config.TIKTOKEN_ENCODING_NAME),
-                chunk_size=app.state.config.CHUNK_SIZE,
-                chunk_overlap=app.state.config.CHUNK_OVERLAP,
-                add_start_index=True,
-            )
-        else:
-            raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
-
-        docs = text_splitter.split_documents(docs)
-
-    if len(docs) == 0:
-        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
-
-    texts = [doc.page_content for doc in docs]
-    metadatas = [
-        {
-            **doc.metadata,
-            **(metadata if metadata else {}),
-            "embedding_config": json.dumps(
-                {
-                    "engine": app.state.config.RAG_EMBEDDING_ENGINE,
-                    "model": app.state.config.RAG_EMBEDDING_MODEL,
-                }
-            ),
-        }
-        for doc in docs
-    ]
-
-    # ChromaDB does not like datetime formats
-    # for meta-data so convert them to string.
-    for metadata in metadatas:
-        for key, value in metadata.items():
-            if isinstance(value, datetime):
-                metadata[key] = str(value)
-
-    try:
-        if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
-            log.info(f"collection {collection_name} already exists")
-
-            if overwrite:
-                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
-                log.info(f"deleting existing collection {collection_name}")
-            elif add is False:
-                log.info(
-                    f"collection {collection_name} already exists, overwrite is False and add is False"
-                )
-                return True
-
-        log.info(f"adding to collection {collection_name}")
-        embedding_function = get_embedding_function(
-            app.state.config.RAG_EMBEDDING_ENGINE,
-            app.state.config.RAG_EMBEDDING_MODEL,
-            app.state.sentence_transformer_ef,
-            (
-                app.state.config.OPENAI_API_BASE_URL
-                if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-                else app.state.config.OLLAMA_BASE_URL
-            ),
-            (
-                app.state.config.OPENAI_API_KEY
-                if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-                else app.state.config.OLLAMA_API_KEY
-            ),
-            app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-        )
-
-        embeddings = embedding_function(
-            list(map(lambda x: x.replace("\n", " "), texts))
-        )
-
-        items = [
-            {
-                "id": str(uuid.uuid4()),
-                "text": text,
-                "vector": embeddings[idx],
-                "metadata": metadatas[idx],
-            }
-            for idx, text in enumerate(texts)
-        ]
-
-        VECTOR_DB_CLIENT.insert(
-            collection_name=collection_name,
-            items=items,
-        )
-
-        return True
-    except Exception as e:
-        log.exception(e)
-        raise e
-
-
-class ProcessFileForm(BaseModel):
-    file_id: str
-    content: Optional[str] = None
-    collection_name: Optional[str] = None
-
-
-@app.post("/process/file")
-def process_file(
-    form_data: ProcessFileForm,
-    user=Depends(get_verified_user),
-):
-    try:
-        file = Files.get_file_by_id(form_data.file_id)
-
-        collection_name = form_data.collection_name
-
-        if collection_name is None:
-            collection_name = f"file-{file.id}"
-
-        if form_data.content:
-            # Update the content in the file
-            # Usage: /files/{file_id}/data/content/update
-
-            VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
-
-            docs = [
-                Document(
-                    page_content=form_data.content.replace("<br/>", "\n"),
-                    metadata={
-                        **file.meta,
-                        "name": file.filename,
-                        "created_by": file.user_id,
-                        "file_id": file.id,
-                        "source": file.filename,
-                    },
-                )
-            ]
-
-            text_content = form_data.content
-        elif form_data.collection_name:
-            # Check if the file has already been processed and save the content
-            # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
-
-            result = VECTOR_DB_CLIENT.query(
-                collection_name=f"file-{file.id}", filter={"file_id": file.id}
-            )
-
-            if result is not None and len(result.ids[0]) > 0:
-                docs = [
-                    Document(
-                        page_content=result.documents[0][idx],
-                        metadata=result.metadatas[0][idx],
-                    )
-                    for idx, id in enumerate(result.ids[0])
-                ]
-            else:
-                docs = [
-                    Document(
-                        page_content=file.data.get("content", ""),
-                        metadata={
-                            **file.meta,
-                            "name": file.filename,
-                            "created_by": file.user_id,
-                            "file_id": file.id,
-                            "source": file.filename,
-                        },
-                    )
-                ]
-
-            text_content = file.data.get("content", "")
-        else:
-            # Process the file and save the content
-            # Usage: /files/
-            file_path = file.path
-            if file_path:
-                file_path = Storage.get_file(file_path)
-                loader = Loader(
-                    engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
-                    TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
-                    PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
-                )
-                docs = loader.load(
-                    file.filename, file.meta.get("content_type"), file_path
-                )
-
-                docs = [
-                    Document(
-                        page_content=doc.page_content,
-                        metadata={
-                            **doc.metadata,
-                            "name": file.filename,
-                            "created_by": file.user_id,
-                            "file_id": file.id,
-                            "source": file.filename,
-                        },
-                    )
-                    for doc in docs
-                ]
-            else:
-                docs = [
-                    Document(
-                        page_content=file.data.get("content", ""),
-                        metadata={
-                            **file.meta,
-                            "name": file.filename,
-                            "created_by": file.user_id,
-                            "file_id": file.id,
-                            "source": file.filename,
-                        },
-                    )
-                ]
-            text_content = " ".join([doc.page_content for doc in docs])
-
-        log.debug(f"text_content: {text_content}")
-        Files.update_file_data_by_id(
-            file.id,
-            {"content": text_content},
-        )
-
-        hash = calculate_sha256_string(text_content)
-        Files.update_file_hash_by_id(file.id, hash)
-
-        try:
-            result = save_docs_to_vector_db(
-                docs=docs,
-                collection_name=collection_name,
-                metadata={
-                    "file_id": file.id,
-                    "name": file.filename,
-                    "hash": hash,
-                },
-                add=(True if form_data.collection_name else False),
-            )
-
-            if result:
-                Files.update_file_metadata_by_id(
-                    file.id,
-                    {
-                        "collection_name": collection_name,
-                    },
-                )
-
-                return {
-                    "status": True,
-                    "collection_name": collection_name,
-                    "filename": file.filename,
-                    "content": text_content,
-                }
-        except Exception as e:
-            raise e
-    except Exception as e:
-        log.exception(e)
-        if "No pandoc was found" in str(e):
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
-            )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=str(e),
-            )
-
-
-class ProcessTextForm(BaseModel):
-    name: str
-    content: str
-    collection_name: Optional[str] = None
-
-
-@app.post("/process/text")
-def process_text(
-    form_data: ProcessTextForm,
-    user=Depends(get_verified_user),
-):
-    collection_name = form_data.collection_name
-    if collection_name is None:
-        collection_name = calculate_sha256_string(form_data.content)
-
-    docs = [
-        Document(
-            page_content=form_data.content,
-            metadata={"name": form_data.name, "created_by": user.id},
-        )
-    ]
-    text_content = form_data.content
-    log.debug(f"text_content: {text_content}")
-
-    result = save_docs_to_vector_db(docs, collection_name)
-
-    if result:
-        return {
-            "status": True,
-            "collection_name": collection_name,
-            "content": text_content,
-        }
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=ERROR_MESSAGES.DEFAULT(),
-        )
-
-
-@app.post("/process/youtube")
-def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
-    try:
-        collection_name = form_data.collection_name
-        if not collection_name:
-            collection_name = calculate_sha256_string(form_data.url)[:63]
-
-        loader = YoutubeLoader(
-            form_data.url,
-            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
-            proxy_url=app.state.config.YOUTUBE_LOADER_PROXY_URL,
-        )
-
-        docs = loader.load()
-        content = " ".join([doc.page_content for doc in docs])
-        log.debug(f"text_content: {content}")
-        save_docs_to_vector_db(docs, collection_name, overwrite=True)
-
-        return {
-            "status": True,
-            "collection_name": collection_name,
-            "filename": form_data.url,
-            "file": {
-                "data": {
-                    "content": content,
-                },
-                "meta": {
-                    "name": form_data.url,
-                },
-            },
-        }
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-@app.post("/process/web")
-def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
-    try:
-        collection_name = form_data.collection_name
-        if not collection_name:
-            collection_name = calculate_sha256_string(form_data.url)[:63]
-
-        loader = get_web_loader(
-            form_data.url,
-            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
-            requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-        )
-        docs = loader.load()
-        content = " ".join([doc.page_content for doc in docs])
-        log.debug(f"text_content: {content}")
-        save_docs_to_vector_db(docs, collection_name, overwrite=True)
-
-        return {
-            "status": True,
-            "collection_name": collection_name,
-            "filename": form_data.url,
-            "file": {
-                "data": {
-                    "content": content,
-                },
-                "meta": {
-                    "name": form_data.url,
-                },
-            },
-        }
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-def search_web(engine: str, query: str) -> list[SearchResult]:
-    """Search the web using a search engine and return the results as a list of SearchResult objects.
-    Will look for a search engine API key in environment variables in the following order:
-    - SEARXNG_QUERY_URL
-    - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
-    - BRAVE_SEARCH_API_KEY
-    - MOJEEK_SEARCH_API_KEY
-    - SERPSTACK_API_KEY
-    - SERPER_API_KEY
-    - SERPLY_API_KEY
-    - TAVILY_API_KEY
-    - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
-    Args:
-        query (str): The query to search for
-    """
-
-    # TODO: add playwright to search the web
-    if engine == "searxng":
-        if app.state.config.SEARXNG_QUERY_URL:
-            return search_searxng(
-                app.state.config.SEARXNG_QUERY_URL,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
-    elif engine == "google_pse":
-        if (
-            app.state.config.GOOGLE_PSE_API_KEY
-            and app.state.config.GOOGLE_PSE_ENGINE_ID
-        ):
-            return search_google_pse(
-                app.state.config.GOOGLE_PSE_API_KEY,
-                app.state.config.GOOGLE_PSE_ENGINE_ID,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception(
-                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
-            )
-    elif engine == "brave":
-        if app.state.config.BRAVE_SEARCH_API_KEY:
-            return search_brave(
-                app.state.config.BRAVE_SEARCH_API_KEY,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
-    elif engine == "mojeek":
-        if app.state.config.MOJEEK_SEARCH_API_KEY:
-            return search_mojeek(
-                app.state.config.MOJEEK_SEARCH_API_KEY,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
-    elif engine == "serpstack":
-        if app.state.config.SERPSTACK_API_KEY:
-            return search_serpstack(
-                app.state.config.SERPSTACK_API_KEY,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-                https_enabled=app.state.config.SERPSTACK_HTTPS,
-            )
-        else:
-            raise Exception("No SERPSTACK_API_KEY found in environment variables")
-    elif engine == "serper":
-        if app.state.config.SERPER_API_KEY:
-            return search_serper(
-                app.state.config.SERPER_API_KEY,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception("No SERPER_API_KEY found in environment variables")
-    elif engine == "serply":
-        if app.state.config.SERPLY_API_KEY:
-            return search_serply(
-                app.state.config.SERPLY_API_KEY,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception("No SERPLY_API_KEY found in environment variables")
-    elif engine == "duckduckgo":
-        return search_duckduckgo(
-            query,
-            app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-            app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-        )
-    elif engine == "tavily":
-        if app.state.config.TAVILY_API_KEY:
-            return search_tavily(
-                app.state.config.TAVILY_API_KEY,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-            )
-        else:
-            raise Exception("No TAVILY_API_KEY found in environment variables")
-    elif engine == "searchapi":
-        if app.state.config.SEARCHAPI_API_KEY:
-            return search_searchapi(
-                app.state.config.SEARCHAPI_API_KEY,
-                app.state.config.SEARCHAPI_ENGINE,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception("No SEARCHAPI_API_KEY found in environment variables")
-    elif engine == "jina":
-        return search_jina(
-            app.state.config.JINA_API_KEY,
-            query,
-            app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-        )
-    elif engine == "bing":
-        return search_bing(
-            app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
-            app.state.config.BING_SEARCH_V7_ENDPOINT,
-            str(DEFAULT_LOCALE),
-            query,
-            app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-            app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-        )
-    else:
-        raise Exception("No search engine API key found in environment variables")
-
-
-@app.post("/process/web/search")
-def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
-    try:
-        logging.info(
-            f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
-        )
-        web_results = search_web(
-            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
-        )
-    except Exception as e:
-        log.exception(e)
-
-        print(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
-        )
-
-    try:
-        collection_name = form_data.collection_name
-        if collection_name == "":
-            collection_name = calculate_sha256_string(form_data.query)[:63]
-
-        urls = [result.link for result in web_results]
-
-        loader = get_web_loader(
-            urls,
-            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
-            requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-        )
-        docs = loader.aload()
-
-        save_docs_to_vector_db(docs, collection_name, overwrite=True)
-
-        return {
-            "status": True,
-            "collection_name": collection_name,
-            "filenames": urls,
-        }
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-class QueryDocForm(BaseModel):
-    collection_name: str
-    query: str
-    k: Optional[int] = None
-    r: Optional[float] = None
-    hybrid: Optional[bool] = None
-
-
-@app.post("/query/doc")
-def query_doc_handler(
-    form_data: QueryDocForm,
-    user=Depends(get_verified_user),
-):
-    try:
-        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-            return query_doc_with_hybrid_search(
-                collection_name=form_data.collection_name,
-                query=form_data.query,
-                embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.config.TOP_K,
-                reranking_function=app.state.sentence_transformer_rf,
-                r=(
-                    form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
-                ),
-            )
-        else:
-            return query_doc(
-                collection_name=form_data.collection_name,
-                query=form_data.query,
-                embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.config.TOP_K,
-            )
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-class QueryCollectionsForm(BaseModel):
-    collection_names: list[str]
-    query: str
-    k: Optional[int] = None
-    r: Optional[float] = None
-    hybrid: Optional[bool] = None
-
-
-@app.post("/query/collection")
-def query_collection_handler(
-    form_data: QueryCollectionsForm,
-    user=Depends(get_verified_user),
-):
-    try:
-        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-            return query_collection_with_hybrid_search(
-                collection_names=form_data.collection_names,
-                queries=[form_data.query],
-                embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.config.TOP_K,
-                reranking_function=app.state.sentence_transformer_rf,
-                r=(
-                    form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
-                ),
-            )
-        else:
-            return query_collection(
-                collection_names=form_data.collection_names,
-                queries=[form_data.query],
-                embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.config.TOP_K,
-            )
-
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-####################################
-#
-# Vector DB operations
-#
-####################################
-
-
-class DeleteForm(BaseModel):
-    collection_name: str
-    file_id: str
-
-
-@app.post("/delete")
-def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
-    try:
-        if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
-            file = Files.get_file_by_id(form_data.file_id)
-            hash = file.hash
-
-            VECTOR_DB_CLIENT.delete(
-                collection_name=form_data.collection_name,
-                metadata={"hash": hash},
-            )
-            return {"status": True}
-        else:
-            return {"status": False}
-    except Exception as e:
-        log.exception(e)
-        return {"status": False}
-
-
-@app.post("/reset/db")
-def reset_vector_db(user=Depends(get_admin_user)):
-    VECTOR_DB_CLIENT.reset()
-    Knowledges.delete_all_knowledge()
-
-
-@app.post("/reset/uploads")
-def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
-    folder = f"{UPLOAD_DIR}"
-    try:
-        # Check if the directory exists
-        if os.path.exists(folder):
-            # Iterate over all the files and directories in the specified directory
-            for filename in os.listdir(folder):
-                file_path = os.path.join(folder, filename)
-                try:
-                    if os.path.isfile(file_path) or os.path.islink(file_path):
-                        os.unlink(file_path)  # Remove the file or link
-                    elif os.path.isdir(file_path):
-                        shutil.rmtree(file_path)  # Remove the directory
-                except Exception as e:
-                    print(f"Failed to delete {file_path}. Reason: {e}")
-        else:
-            print(f"The directory {folder} does not exist")
-    except Exception as e:
-        print(f"Failed to process the directory {folder}. Reason: {e}")
-    return True
-
-
-if ENV == "dev":
-
-    @app.get("/ef")
-    async def get_embeddings():
-        return {"result": app.state.EMBEDDING_FUNCTION("hello world")}
-
-    @app.get("/ef/{text}")
-    async def get_embeddings_text(text: str):
-        return {"result": app.state.EMBEDDING_FUNCTION(text)}

+ 0 - 22
backend/open_webui/apps/retrieval/vector/connector.py

@@ -1,22 +0,0 @@
-from open_webui.config import VECTOR_DB
-
-if VECTOR_DB == "milvus":
-    from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
-
-    VECTOR_DB_CLIENT = MilvusClient()
-elif VECTOR_DB == "qdrant":
-    from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
-
-    VECTOR_DB_CLIENT = QdrantClient()
-elif VECTOR_DB == "opensearch":
-    from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient
-
-    VECTOR_DB_CLIENT = OpenSearchClient()
-elif VECTOR_DB == "pgvector":
-    from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient
-
-    VECTOR_DB_CLIENT = PgvectorClient()
-else:
-    from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
-
-    VECTOR_DB_CLIENT = ChromaClient()

+ 0 - 506
backend/open_webui/apps/webui/main.py

@@ -1,506 +0,0 @@
-import inspect
-import json
-import logging
-import time
-from typing import AsyncGenerator, Generator, Iterator
-
-from open_webui.apps.socket.main import get_event_call, get_event_emitter
-from open_webui.apps.webui.models.functions import Functions
-from open_webui.apps.webui.models.models import Models
-from open_webui.apps.webui.routers import (
-    auths,
-    chats,
-    folders,
-    configs,
-    groups,
-    files,
-    functions,
-    memories,
-    models,
-    knowledge,
-    prompts,
-    evaluations,
-    tools,
-    users,
-    utils,
-)
-from open_webui.apps.webui.utils import load_function_module_by_id
-from open_webui.config import (
-    ADMIN_EMAIL,
-    CORS_ALLOW_ORIGIN,
-    DEFAULT_MODELS,
-    DEFAULT_PROMPT_SUGGESTIONS,
-    DEFAULT_USER_ROLE,
-    MODEL_ORDER_LIST,
-    ENABLE_COMMUNITY_SHARING,
-    ENABLE_LOGIN_FORM,
-    ENABLE_MESSAGE_RATING,
-    ENABLE_SIGNUP,
-    ENABLE_API_KEY,
-    ENABLE_EVALUATION_ARENA_MODELS,
-    EVALUATION_ARENA_MODELS,
-    DEFAULT_ARENA_MODEL,
-    JWT_EXPIRES_IN,
-    ENABLE_OAUTH_ROLE_MANAGEMENT,
-    OAUTH_ROLES_CLAIM,
-    OAUTH_EMAIL_CLAIM,
-    OAUTH_PICTURE_CLAIM,
-    OAUTH_USERNAME_CLAIM,
-    OAUTH_ALLOWED_ROLES,
-    OAUTH_ADMIN_ROLES,
-    SHOW_ADMIN_DETAILS,
-    USER_PERMISSIONS,
-    WEBHOOK_URL,
-    WEBUI_AUTH,
-    WEBUI_BANNERS,
-    ENABLE_LDAP,
-    LDAP_SERVER_LABEL,
-    LDAP_SERVER_HOST,
-    LDAP_SERVER_PORT,
-    LDAP_ATTRIBUTE_FOR_USERNAME,
-    LDAP_SEARCH_FILTERS,
-    LDAP_SEARCH_BASE,
-    LDAP_APP_DN,
-    LDAP_APP_PASSWORD,
-    LDAP_USE_TLS,
-    LDAP_CA_CERT_FILE,
-    LDAP_CIPHERS,
-    AppConfig,
-)
-from open_webui.env import (
-    ENV,
-    SRC_LOG_LEVELS,
-    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
-    WEBUI_AUTH_TRUSTED_NAME_HEADER,
-)
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
-from open_webui.utils.misc import (
-    openai_chat_chunk_message_template,
-    openai_chat_completion_message_template,
-)
-from open_webui.utils.payload import (
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
-)
-
-
-from open_webui.utils.tools import get_tools
-
-app = FastAPI(
-    docs_url="/docs" if ENV == "dev" else None,
-    openapi_url="/openapi.json" if ENV == "dev" else None,
-    redoc_url=None,
-)
-
-log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS["MAIN"])
-
-app.state.config = AppConfig()
-
-app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
-app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
-app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
-
-app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
-app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
-app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
-
-
-app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
-app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
-
-
-app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
-app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
-app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
-
-
-app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
-app.state.config.WEBHOOK_URL = WEBHOOK_URL
-app.state.config.BANNERS = WEBUI_BANNERS
-app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
-
-app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
-app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
-
-app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
-app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS
-
-app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
-app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
-app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
-
-app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
-app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
-app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
-app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
-
-app.state.config.ENABLE_LDAP = ENABLE_LDAP
-app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL
-app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST
-app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT
-app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME
-app.state.config.LDAP_APP_DN = LDAP_APP_DN
-app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD
-app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
-app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
-app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
-app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
-app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
-
-app.state.TOOLS = {}
-app.state.FUNCTIONS = {}
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=CORS_ALLOW_ORIGIN,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-
-app.include_router(configs.router, prefix="/configs", tags=["configs"])
-
-app.include_router(auths.router, prefix="/auths", tags=["auths"])
-app.include_router(users.router, prefix="/users", tags=["users"])
-
-app.include_router(chats.router, prefix="/chats", tags=["chats"])
-
-app.include_router(models.router, prefix="/models", tags=["models"])
-app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
-app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
-app.include_router(tools.router, prefix="/tools", tags=["tools"])
-
-app.include_router(memories.router, prefix="/memories", tags=["memories"])
-app.include_router(folders.router, prefix="/folders", tags=["folders"])
-
-app.include_router(groups.router, prefix="/groups", tags=["groups"])
-app.include_router(files.router, prefix="/files", tags=["files"])
-app.include_router(functions.router, prefix="/functions", tags=["functions"])
-app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
-
-
-app.include_router(utils.router, prefix="/utils", tags=["utils"])
-
-
-@app.get("/")
-async def get_status():
-    return {
-        "status": True,
-        "auth": WEBUI_AUTH,
-        "default_models": app.state.config.DEFAULT_MODELS,
-        "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
-    }
-
-
-async def get_all_models():
-    models = []
-    pipe_models = await get_pipe_models()
-    models = models + pipe_models
-
-    if app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
-        arena_models = []
-        if len(app.state.config.EVALUATION_ARENA_MODELS) > 0:
-            arena_models = [
-                {
-                    "id": model["id"],
-                    "name": model["name"],
-                    "info": {
-                        "meta": model["meta"],
-                    },
-                    "object": "model",
-                    "created": int(time.time()),
-                    "owned_by": "arena",
-                    "arena": True,
-                }
-                for model in app.state.config.EVALUATION_ARENA_MODELS
-            ]
-        else:
-            # Add default arena model
-            arena_models = [
-                {
-                    "id": DEFAULT_ARENA_MODEL["id"],
-                    "name": DEFAULT_ARENA_MODEL["name"],
-                    "info": {
-                        "meta": DEFAULT_ARENA_MODEL["meta"],
-                    },
-                    "object": "model",
-                    "created": int(time.time()),
-                    "owned_by": "arena",
-                    "arena": True,
-                }
-            ]
-        models = models + arena_models
-    return models
-
-
-def get_function_module(pipe_id: str):
-    # Check if function is already loaded
-    if pipe_id not in app.state.FUNCTIONS:
-        function_module, _, _ = load_function_module_by_id(pipe_id)
-        app.state.FUNCTIONS[pipe_id] = function_module
-    else:
-        function_module = app.state.FUNCTIONS[pipe_id]
-
-    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-        valves = Functions.get_function_valves_by_id(pipe_id)
-        function_module.valves = function_module.Valves(**(valves if valves else {}))
-    return function_module
-
-
-async def get_pipe_models():
-    pipes = Functions.get_functions_by_type("pipe", active_only=True)
-    pipe_models = []
-
-    for pipe in pipes:
-        function_module = get_function_module(pipe.id)
-
-        # Check if function is a manifold
-        if hasattr(function_module, "pipes"):
-            sub_pipes = []
-
-            # Check if pipes is a function or a list
-
-            try:
-                if callable(function_module.pipes):
-                    sub_pipes = function_module.pipes()
-                else:
-                    sub_pipes = function_module.pipes
-            except Exception as e:
-                log.exception(e)
-                sub_pipes = []
-
-            log.debug(
-                f"get_pipe_models: function '{pipe.id}' is a manifold of {sub_pipes}"
-            )
-
-            for p in sub_pipes:
-                sub_pipe_id = f'{pipe.id}.{p["id"]}'
-                sub_pipe_name = p["name"]
-
-                if hasattr(function_module, "name"):
-                    sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
-
-                pipe_flag = {"type": pipe.type}
-
-                pipe_models.append(
-                    {
-                        "id": sub_pipe_id,
-                        "name": sub_pipe_name,
-                        "object": "model",
-                        "created": pipe.created_at,
-                        "owned_by": "openai",
-                        "pipe": pipe_flag,
-                    }
-                )
-        else:
-            pipe_flag = {"type": "pipe"}
-
-            log.debug(
-                f"get_pipe_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
-            )
-
-            pipe_models.append(
-                {
-                    "id": pipe.id,
-                    "name": pipe.name,
-                    "object": "model",
-                    "created": pipe.created_at,
-                    "owned_by": "openai",
-                    "pipe": pipe_flag,
-                }
-            )
-
-    return pipe_models
-
-
-async def execute_pipe(pipe, params):
-    if inspect.iscoroutinefunction(pipe):
-        return await pipe(**params)
-    else:
-        return pipe(**params)
-
-
-async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
-    if isinstance(res, str):
-        return res
-    if isinstance(res, Generator):
-        return "".join(map(str, res))
-    if isinstance(res, AsyncGenerator):
-        return "".join([str(stream) async for stream in res])
-
-
-def process_line(form_data: dict, line):
-    if isinstance(line, BaseModel):
-        line = line.model_dump_json()
-        line = f"data: {line}"
-    if isinstance(line, dict):
-        line = f"data: {json.dumps(line)}"
-
-    try:
-        line = line.decode("utf-8")
-    except Exception:
-        pass
-
-    if line.startswith("data:"):
-        return f"{line}\n\n"
-    else:
-        line = openai_chat_chunk_message_template(form_data["model"], line)
-        return f"data: {json.dumps(line)}\n\n"
-
-
-def get_pipe_id(form_data: dict) -> str:
-    pipe_id = form_data["model"]
-    if "." in pipe_id:
-        pipe_id, _ = pipe_id.split(".", 1)
-
-    return pipe_id
-
-
-def get_function_params(function_module, form_data, user, extra_params=None):
-    if extra_params is None:
-        extra_params = {}
-
-    pipe_id = get_pipe_id(form_data)
-
-    # Get the signature of the function
-    sig = inspect.signature(function_module.pipe)
-    params = {"body": form_data} | {
-        k: v for k, v in extra_params.items() if k in sig.parameters
-    }
-
-    if "__user__" in params and hasattr(function_module, "UserValves"):
-        user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
-        try:
-            params["__user__"]["valves"] = function_module.UserValves(**user_valves)
-        except Exception as e:
-            log.exception(e)
-            params["__user__"]["valves"] = function_module.UserValves()
-
-    return params
-
-
-async def generate_function_chat_completion(form_data, user, models: dict = {}):
-    model_id = form_data.get("model")
-    model_info = Models.get_model_by_id(model_id)
-
-    metadata = form_data.pop("metadata", {})
-
-    files = metadata.get("files", [])
-    tool_ids = metadata.get("tool_ids", [])
-    # Check if tool_ids is None
-    if tool_ids is None:
-        tool_ids = []
-
-    __event_emitter__ = None
-    __event_call__ = None
-    __task__ = None
-    __task_body__ = None
-
-    if metadata:
-        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
-            __event_emitter__ = get_event_emitter(metadata)
-            __event_call__ = get_event_call(metadata)
-        __task__ = metadata.get("task", None)
-        __task_body__ = metadata.get("task_body", None)
-
-    extra_params = {
-        "__event_emitter__": __event_emitter__,
-        "__event_call__": __event_call__,
-        "__task__": __task__,
-        "__task_body__": __task_body__,
-        "__files__": files,
-        "__user__": {
-            "id": user.id,
-            "email": user.email,
-            "name": user.name,
-            "role": user.role,
-        },
-        "__metadata__": metadata,
-    }
-    extra_params["__tools__"] = get_tools(
-        app,
-        tool_ids,
-        user,
-        {
-            **extra_params,
-            "__model__": models.get(form_data["model"], None),
-            "__messages__": form_data["messages"],
-            "__files__": files,
-        },
-    )
-
-    if model_info:
-        if model_info.base_model_id:
-            form_data["model"] = model_info.base_model_id
-
-        params = model_info.params.model_dump()
-        form_data = apply_model_params_to_body_openai(params, form_data)
-        form_data = apply_model_system_prompt_to_body(params, form_data, user)
-
-    pipe_id = get_pipe_id(form_data)
-    function_module = get_function_module(pipe_id)
-
-    pipe = function_module.pipe
-    params = get_function_params(function_module, form_data, user, extra_params)
-
-    if form_data.get("stream", False):
-
-        async def stream_content():
-            try:
-                res = await execute_pipe(pipe, params)
-
-                # Directly return if the response is a StreamingResponse
-                if isinstance(res, StreamingResponse):
-                    async for data in res.body_iterator:
-                        yield data
-                    return
-                if isinstance(res, dict):
-                    yield f"data: {json.dumps(res)}\n\n"
-                    return
-
-            except Exception as e:
-                log.error(f"Error: {e}")
-                yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
-                return
-
-            if isinstance(res, str):
-                message = openai_chat_chunk_message_template(form_data["model"], res)
-                yield f"data: {json.dumps(message)}\n\n"
-
-            if isinstance(res, Iterator):
-                for line in res:
-                    yield process_line(form_data, line)
-
-            if isinstance(res, AsyncGenerator):
-                async for line in res:
-                    yield process_line(form_data, line)
-
-            if isinstance(res, str) or isinstance(res, Generator):
-                finish_message = openai_chat_chunk_message_template(
-                    form_data["model"], ""
-                )
-                finish_message["choices"][0]["finish_reason"] = "stop"
-                yield f"data: {json.dumps(finish_message)}\n\n"
-                yield "data: [DONE]"
-
-        return StreamingResponse(stream_content(), media_type="text/event-stream")
-    else:
-        try:
-            res = await execute_pipe(pipe, params)
-
-        except Exception as e:
-            log.error(f"Error: {e}")
-            return {"error": {"detail": str(e)}}
-
-        if isinstance(res, StreamingResponse) or isinstance(res, dict):
-            return res
-        if isinstance(res, BaseModel):
-            return res.model_dump()
-
-        message = await get_message_content(res)
-        return openai_chat_completion_message_template(form_data["model"], message)

+ 121 - 4
backend/open_webui/config.py

@@ -10,7 +10,7 @@ from urllib.parse import urlparse
 import chromadb
 import chromadb
 import requests
 import requests
 import yaml
 import yaml
-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 from open_webui.env import (
 from open_webui.env import (
     OPEN_WEBUI_DIR,
     OPEN_WEBUI_DIR,
     DATA_DIR,
     DATA_DIR,
@@ -21,6 +21,7 @@ from open_webui.env import (
     WEBUI_NAME,
     WEBUI_NAME,
     log,
     log,
     DATABASE_URL,
     DATABASE_URL,
+    OFFLINE_MODE,
 )
 )
 from pydantic import BaseModel
 from pydantic import BaseModel
 from sqlalchemy import JSON, Column, DateTime, Integer, func
 from sqlalchemy import JSON, Column, DateTime, Integer, func
@@ -306,6 +307,18 @@ GOOGLE_CLIENT_SECRET = PersistentConfig(
     os.environ.get("GOOGLE_CLIENT_SECRET", ""),
     os.environ.get("GOOGLE_CLIENT_SECRET", ""),
 )
 )
 
 
+GOOGLE_DRIVE_CLIENT_ID = PersistentConfig(
+    "GOOGLE_DRIVE_CLIENT_ID",
+    "google_drive.client_id",
+    os.environ.get("GOOGLE_DRIVE_CLIENT_ID", ""),
+)
+
+GOOGLE_DRIVE_API_KEY = PersistentConfig(
+    "GOOGLE_DRIVE_API_KEY",
+    "google_drive.api_key",
+    os.environ.get("GOOGLE_DRIVE_API_KEY", ""),
+)
+
 GOOGLE_OAUTH_SCOPE = PersistentConfig(
 GOOGLE_OAUTH_SCOPE = PersistentConfig(
     "GOOGLE_OAUTH_SCOPE",
     "GOOGLE_OAUTH_SCOPE",
     "oauth.google.scope",
     "oauth.google.scope",
@@ -402,12 +415,24 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
     os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
     os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
 )
 )
 
 
+OAUTH_GROUPS_CLAIM = PersistentConfig(
+    "OAUTH_GROUPS_CLAIM",
+    "oauth.oidc.group_claim",
+    os.environ.get("OAUTH_GROUP_CLAIM", "groups"),
+)
+
 ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
 ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
     "ENABLE_OAUTH_ROLE_MANAGEMENT",
     "ENABLE_OAUTH_ROLE_MANAGEMENT",
     "oauth.enable_role_mapping",
     "oauth.enable_role_mapping",
     os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
     os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
 )
 )
 
 
+ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig(
+    "ENABLE_OAUTH_GROUP_MANAGEMENT",
+    "oauth.enable_group_mapping",
+    os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true",
+)
+
 OAUTH_ROLES_CLAIM = PersistentConfig(
 OAUTH_ROLES_CLAIM = PersistentConfig(
     "OAUTH_ROLES_CLAIM",
     "OAUTH_ROLES_CLAIM",
     "oauth.roles_claim",
     "oauth.roles_claim",
@@ -429,6 +454,15 @@ OAUTH_ADMIN_ROLES = PersistentConfig(
     [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
     [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
 )
 )
 
 
+OAUTH_ALLOWED_DOMAINS = PersistentConfig(
+    "OAUTH_ALLOWED_DOMAINS",
+    "oauth.allowed_domains",
+    [
+        domain.strip()
+        for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",")
+    ],
+)
+
 
 
 def load_oauth_providers():
 def load_oauth_providers():
     OAUTH_PROVIDERS.clear()
     OAUTH_PROVIDERS.clear()
@@ -686,6 +720,12 @@ OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 # WEBUI
 # WEBUI
 ####################################
 ####################################
 
 
+
+WEBUI_URL = PersistentConfig(
+    "WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000")
+)
+
+
 ENABLE_SIGNUP = PersistentConfig(
 ENABLE_SIGNUP = PersistentConfig(
     "ENABLE_SIGNUP",
     "ENABLE_SIGNUP",
     "ui.enable_signup",
     "ui.enable_signup",
@@ -813,6 +853,12 @@ USER_PERMISSIONS = PersistentConfig(
     },
     },
 )
 )
 
 
+ENABLE_CHANNELS = PersistentConfig(
+    "ENABLE_CHANNELS",
+    "channels.enable",
+    os.environ.get("ENABLE_CHANNELS", "False").lower() == "true",
+)
+
 
 
 ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig(
 ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig(
     "ENABLE_EVALUATION_ARENA_MODELS",
     "ENABLE_EVALUATION_ARENA_MODELS",
@@ -948,12 +994,45 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
     os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
     os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
 )
 )
 
 
+DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
+
+Examples of titles:
+📉 Stock Market Trends
+🍪 Perfect Chocolate Chip Recipe
+Evolution of Music Streaming
+Remote Work Productivity Tips
+Artificial Intelligence in Healthcare
+🎮 Video Game Development Insights
+
+<chat_history>
+{{MESSAGES:END:2}}
+</chat_history>"""
+
+
 TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
 TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
     "TAGS_GENERATION_PROMPT_TEMPLATE",
     "TAGS_GENERATION_PROMPT_TEMPLATE",
     "task.tags.prompt_template",
     "task.tags.prompt_template",
     os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
     os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
 )
 )
 
 
+DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE = """### Task:
+Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.
+
+### Guidelines:
+- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
+- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
+- If content is too short (less than 3 messages) or too diverse, use only ["General"]
+- Use the chat's primary language; default to English if multilingual
+- Prioritize accuracy over specificity
+
+### Output:
+JSON format: { "tags": ["tag1", "tag2", "tag3"] }
+
+### Chat History:
+<chat_history>
+{{MESSAGES:END:6}}
+</chat_history>"""
+
 ENABLE_TAGS_GENERATION = PersistentConfig(
 ENABLE_TAGS_GENERATION = PersistentConfig(
     "ENABLE_TAGS_GENERATION",
     "ENABLE_TAGS_GENERATION",
     "task.tags.enable",
     "task.tags.enable",
@@ -1072,6 +1151,19 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
 )
 )
 
 
 
 
+DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""
+
+
+DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
+
+Message: ```{{prompt}}```"""
+
+DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
+
+Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
+
+Responses from models: {{responses}}"""
+
 ####################################
 ####################################
 # Vector Database
 # Vector Database
 ####################################
 ####################################
@@ -1123,6 +1215,15 @@ if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
 # Information Retrieval (RAG)
 # Information Retrieval (RAG)
 ####################################
 ####################################
 
 
+
+# If configured, Google Drive will be available as an upload option.
+ENABLE_GOOGLE_DRIVE_INTEGRATION = PersistentConfig(
+    "ENABLE_GOOGLE_DRIVE_INTEGRATION",
+    "google_drive.enable",
+    os.getenv("ENABLE_GOOGLE_DRIVE_INTEGRATION", "False").lower() == "true",
+)
+
+
 # RAG Content Extraction
 # RAG Content Extraction
 CONTENT_EXTRACTION_ENGINE = PersistentConfig(
 CONTENT_EXTRACTION_ENGINE = PersistentConfig(
     "CONTENT_EXTRACTION_ENGINE",
     "CONTENT_EXTRACTION_ENGINE",
@@ -1197,7 +1298,8 @@ RAG_EMBEDDING_MODEL = PersistentConfig(
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
 
 
 RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
 RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
-    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
+    not OFFLINE_MODE
+    and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
 )
 )
 
 
 RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
 RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
@@ -1222,7 +1324,8 @@ if RAG_RERANKING_MODEL.value != "":
     log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
     log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
 
 
 RAG_RERANKING_MODEL_AUTO_UPDATE = (
 RAG_RERANKING_MODEL_AUTO_UPDATE = (
-    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
+    not OFFLINE_MODE
+    and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
 )
 )
 
 
 RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
 RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
@@ -1356,6 +1459,7 @@ RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
     ],
     ],
 )
 )
 
 
+
 SEARXNG_QUERY_URL = PersistentConfig(
 SEARXNG_QUERY_URL = PersistentConfig(
     "SEARXNG_QUERY_URL",
     "SEARXNG_QUERY_URL",
     "rag.web.search.searxng_query_url",
     "rag.web.search.searxng_query_url",
@@ -1380,6 +1484,12 @@ BRAVE_SEARCH_API_KEY = PersistentConfig(
     os.getenv("BRAVE_SEARCH_API_KEY", ""),
     os.getenv("BRAVE_SEARCH_API_KEY", ""),
 )
 )
 
 
+KAGI_SEARCH_API_KEY = PersistentConfig(
+    "KAGI_SEARCH_API_KEY",
+    "rag.web.search.kagi_search_api_key",
+    os.getenv("KAGI_SEARCH_API_KEY", ""),
+)
+
 MOJEEK_SEARCH_API_KEY = PersistentConfig(
 MOJEEK_SEARCH_API_KEY = PersistentConfig(
     "MOJEEK_SEARCH_API_KEY",
     "MOJEEK_SEARCH_API_KEY",
     "rag.web.search.mojeek_search_api_key",
     "rag.web.search.mojeek_search_api_key",
@@ -1525,6 +1635,12 @@ COMFYUI_BASE_URL = PersistentConfig(
     os.getenv("COMFYUI_BASE_URL", ""),
     os.getenv("COMFYUI_BASE_URL", ""),
 )
 )
 
 
+COMFYUI_API_KEY = PersistentConfig(
+    "COMFYUI_API_KEY",
+    "image_generation.comfyui.api_key",
+    os.getenv("COMFYUI_API_KEY", ""),
+)
+
 COMFYUI_DEFAULT_WORKFLOW = """
 COMFYUI_DEFAULT_WORKFLOW = """
 {
 {
   "3": {
   "3": {
@@ -1686,7 +1802,8 @@ WHISPER_MODEL = PersistentConfig(
 
 
 WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
 WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
 WHISPER_MODEL_AUTO_UPDATE = (
 WHISPER_MODEL_AUTO_UPDATE = (
-    os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
+    not OFFLINE_MODE
+    and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
 )
 )
 
 
 
 

+ 1 - 3
backend/open_webui/env.py

@@ -103,8 +103,6 @@ WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
 if WEBUI_NAME != "Open WebUI":
 if WEBUI_NAME != "Open WebUI":
     WEBUI_NAME += " (Open WebUI)"
     WEBUI_NAME += " (Open WebUI)"
 
 
-WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
-
 WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
 WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
 
 
 
 
@@ -376,7 +374,7 @@ else:
         AIOHTTP_CLIENT_TIMEOUT = 300
         AIOHTTP_CLIENT_TIMEOUT = 300
 
 
 AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
 AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
-    "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "5"
+    "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
 )
 )
 
 
 if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
 if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":

+ 316 - 0
backend/open_webui/functions.py

@@ -0,0 +1,316 @@
+import logging
+import sys
+import inspect
+import json
+
+from pydantic import BaseModel
+from typing import AsyncGenerator, Generator, Iterator
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    Form,
+    HTTPException,
+    Request,
+    UploadFile,
+    status,
+)
+from starlette.responses import Response, StreamingResponse
+
+
+from open_webui.socket.main import (
+    get_event_call,
+    get_event_emitter,
+)
+
+
+from open_webui.models.functions import Functions
+from open_webui.models.models import Models
+
+from open_webui.utils.plugin import load_function_module_by_id
+from open_webui.utils.tools import get_tools
+from open_webui.utils.access_control import has_access
+
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
+
+from open_webui.utils.misc import (
+    add_or_update_system_message,
+    get_last_user_message,
+    prepend_to_first_user_message_content,
+    openai_chat_chunk_message_template,
+    openai_chat_completion_message_template,
+)
+from open_webui.utils.payload import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
+
+
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+def get_function_module_by_id(request: Request, pipe_id: str):
+    # Check if function is already loaded
+    if pipe_id not in request.app.state.FUNCTIONS:
+        function_module, _, _ = load_function_module_by_id(pipe_id)
+        request.app.state.FUNCTIONS[pipe_id] = function_module
+    else:
+        function_module = request.app.state.FUNCTIONS[pipe_id]
+
+    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+        valves = Functions.get_function_valves_by_id(pipe_id)
+        function_module.valves = function_module.Valves(**(valves if valves else {}))
+    return function_module
+
+
+async def get_function_models(request):
+    pipes = Functions.get_functions_by_type("pipe", active_only=True)
+    pipe_models = []
+
+    for pipe in pipes:
+        function_module = get_function_module_by_id(request, pipe.id)
+
+        # Check if function is a manifold
+        if hasattr(function_module, "pipes"):
+            sub_pipes = []
+
+            # Check if pipes is a function or a list
+
+            try:
+                if callable(function_module.pipes):
+                    sub_pipes = function_module.pipes()
+                else:
+                    sub_pipes = function_module.pipes
+            except Exception as e:
+                log.exception(e)
+                sub_pipes = []
+
+            log.debug(
+                f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
+            )
+
+            for p in sub_pipes:
+                sub_pipe_id = f'{pipe.id}.{p["id"]}'
+                sub_pipe_name = p["name"]
+
+                if hasattr(function_module, "name"):
+                    sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
+
+                pipe_flag = {"type": pipe.type}
+
+                pipe_models.append(
+                    {
+                        "id": sub_pipe_id,
+                        "name": sub_pipe_name,
+                        "object": "model",
+                        "created": pipe.created_at,
+                        "owned_by": "openai",
+                        "pipe": pipe_flag,
+                    }
+                )
+        else:
+            pipe_flag = {"type": "pipe"}
+
+            log.debug(
+                f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
+            )
+
+            pipe_models.append(
+                {
+                    "id": pipe.id,
+                    "name": pipe.name,
+                    "object": "model",
+                    "created": pipe.created_at,
+                    "owned_by": "openai",
+                    "pipe": pipe_flag,
+                }
+            )
+
+    return pipe_models
+
+
+async def generate_function_chat_completion(
+    request, form_data, user, models: dict = {}
+):
+    async def execute_pipe(pipe, params):
+        if inspect.iscoroutinefunction(pipe):
+            return await pipe(**params)
+        else:
+            return pipe(**params)
+
+    async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
+        if isinstance(res, str):
+            return res
+        if isinstance(res, Generator):
+            return "".join(map(str, res))
+        if isinstance(res, AsyncGenerator):
+            return "".join([str(stream) async for stream in res])
+
+    def process_line(form_data: dict, line):
+        if isinstance(line, BaseModel):
+            line = line.model_dump_json()
+            line = f"data: {line}"
+        if isinstance(line, dict):
+            line = f"data: {json.dumps(line)}"
+
+        try:
+            line = line.decode("utf-8")
+        except Exception:
+            pass
+
+        if line.startswith("data:"):
+            return f"{line}\n\n"
+        else:
+            line = openai_chat_chunk_message_template(form_data["model"], line)
+            return f"data: {json.dumps(line)}\n\n"
+
+    def get_pipe_id(form_data: dict) -> str:
+        pipe_id = form_data["model"]
+        if "." in pipe_id:
+            pipe_id, _ = pipe_id.split(".", 1)
+        return pipe_id
+
+    def get_function_params(function_module, form_data, user, extra_params=None):
+        if extra_params is None:
+            extra_params = {}
+
+        pipe_id = get_pipe_id(form_data)
+
+        # Get the signature of the function
+        sig = inspect.signature(function_module.pipe)
+        params = {"body": form_data} | {
+            k: v for k, v in extra_params.items() if k in sig.parameters
+        }
+
+        if "__user__" in params and hasattr(function_module, "UserValves"):
+            user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
+            try:
+                params["__user__"]["valves"] = function_module.UserValves(**user_valves)
+            except Exception as e:
+                log.exception(e)
+                params["__user__"]["valves"] = function_module.UserValves()
+
+        return params
+
+    model_id = form_data.get("model")
+    model_info = Models.get_model_by_id(model_id)
+
+    metadata = form_data.pop("metadata", {})
+
+    files = metadata.get("files", [])
+    tool_ids = metadata.get("tool_ids", [])
+    # Check if tool_ids is None
+    if tool_ids is None:
+        tool_ids = []
+
+    __event_emitter__ = None
+    __event_call__ = None
+    __task__ = None
+    __task_body__ = None
+
+    if metadata:
+        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
+            __event_emitter__ = get_event_emitter(metadata)
+            __event_call__ = get_event_call(metadata)
+        __task__ = metadata.get("task", None)
+        __task_body__ = metadata.get("task_body", None)
+
+    extra_params = {
+        "__event_emitter__": __event_emitter__,
+        "__event_call__": __event_call__,
+        "__task__": __task__,
+        "__task_body__": __task_body__,
+        "__files__": files,
+        "__user__": {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        },
+        "__metadata__": metadata,
+        "__request__": request,
+    }
+    extra_params["__tools__"] = get_tools(
+        request,
+        tool_ids,
+        user,
+        {
+            **extra_params,
+            "__model__": models.get(form_data["model"], None),
+            "__messages__": form_data["messages"],
+            "__files__": files,
+        },
+    )
+
+    if model_info:
+        if model_info.base_model_id:
+            form_data["model"] = model_info.base_model_id
+
+        params = model_info.params.model_dump()
+        form_data = apply_model_params_to_body_openai(params, form_data)
+        form_data = apply_model_system_prompt_to_body(params, form_data, user)
+
+    pipe_id = get_pipe_id(form_data)
+    function_module = get_function_module_by_id(request, pipe_id)
+
+    pipe = function_module.pipe
+    params = get_function_params(function_module, form_data, user, extra_params)
+
+    if form_data.get("stream", False):
+
+        async def stream_content():
+            try:
+                res = await execute_pipe(pipe, params)
+
+                # Directly return if the response is a StreamingResponse
+                if isinstance(res, StreamingResponse):
+                    async for data in res.body_iterator:
+                        yield data
+                    return
+                if isinstance(res, dict):
+                    yield f"data: {json.dumps(res)}\n\n"
+                    return
+
+            except Exception as e:
+                log.error(f"Error: {e}")
+                yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
+                return
+
+            if isinstance(res, str):
+                message = openai_chat_chunk_message_template(form_data["model"], res)
+                yield f"data: {json.dumps(message)}\n\n"
+
+            if isinstance(res, Iterator):
+                for line in res:
+                    yield process_line(form_data, line)
+
+            if isinstance(res, AsyncGenerator):
+                async for line in res:
+                    yield process_line(form_data, line)
+
+            if isinstance(res, str) or isinstance(res, Generator):
+                finish_message = openai_chat_chunk_message_template(
+                    form_data["model"], ""
+                )
+                finish_message["choices"][0]["finish_reason"] = "stop"
+                yield f"data: {json.dumps(finish_message)}\n\n"
+                yield "data: [DONE]"
+
+        return StreamingResponse(stream_content(), media_type="text/event-stream")
+    else:
+        try:
+            res = await execute_pipe(pipe, params)
+
+        except Exception as e:
+            log.error(f"Error: {e}")
+            return {"error": {"detail": str(e)}}
+
+        if isinstance(res, StreamingResponse) or isinstance(res, dict):
+            return res
+        if isinstance(res, BaseModel):
+            return res.model_dump()
+
+        message = await get_message_content(res)
+        return openai_chat_completion_message_template(form_data["model"], message)

+ 2 - 2
backend/open_webui/apps/webui/internal/db.py → backend/open_webui/internal/db.py

@@ -3,7 +3,7 @@ import logging
 from contextlib import contextmanager
 from contextlib import contextmanager
 from typing import Any, Optional
 from typing import Any, Optional
 
 
-from open_webui.apps.webui.internal.wrappers import register_connection
+from open_webui.internal.wrappers import register_connection
 from open_webui.env import (
 from open_webui.env import (
     OPEN_WEBUI_DIR,
     OPEN_WEBUI_DIR,
     DATABASE_URL,
     DATABASE_URL,
@@ -54,7 +54,7 @@ def handle_peewee_migration(DATABASE_URL):
     try:
     try:
         # Replace the postgresql:// with postgres:// to handle the peewee migration
         # Replace the postgresql:// with postgres:// to handle the peewee migration
         db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
         db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
-        migrate_dir = OPEN_WEBUI_DIR / "apps" / "webui" / "internal" / "migrations"
+        migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
         router = Router(db, logger=log, migrate_dir=migrate_dir)
         router = Router(db, logger=log, migrate_dir=migrate_dir)
         router.run()
         router.run()
         db.close()
         db.close()

+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/001_initial_schema.py → backend/open_webui/internal/migrations/001_initial_schema.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/002_add_local_sharing.py → backend/open_webui/internal/migrations/002_add_local_sharing.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/003_add_auth_api_key.py → backend/open_webui/internal/migrations/003_add_auth_api_key.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/004_add_archived.py → backend/open_webui/internal/migrations/004_add_archived.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/005_add_updated_at.py → backend/open_webui/internal/migrations/005_add_updated_at.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py → backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/007_add_user_last_active_at.py → backend/open_webui/internal/migrations/007_add_user_last_active_at.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/008_add_memory.py → backend/open_webui/internal/migrations/008_add_memory.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/009_add_models.py → backend/open_webui/internal/migrations/009_add_models.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py → backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/011_add_user_settings.py → backend/open_webui/internal/migrations/011_add_user_settings.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/012_add_tools.py → backend/open_webui/internal/migrations/012_add_tools.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/013_add_user_info.py → backend/open_webui/internal/migrations/013_add_user_info.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/014_add_files.py → backend/open_webui/internal/migrations/014_add_files.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/015_add_functions.py → backend/open_webui/internal/migrations/015_add_functions.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/016_add_valves_and_is_active.py → backend/open_webui/internal/migrations/016_add_valves_and_is_active.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/017_add_user_oauth_sub.py → backend/open_webui/internal/migrations/017_add_user_oauth_sub.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/018_add_function_is_global.py → backend/open_webui/internal/migrations/018_add_function_is_global.py


+ 0 - 0
backend/open_webui/apps/webui/internal/wrappers.py → backend/open_webui/internal/wrappers.py


Fichier diff supprimé car celui-ci est trop grand
+ 655 - 2288
backend/open_webui/main.py


+ 1 - 1
backend/open_webui/migrations/env.py

@@ -1,7 +1,7 @@
 from logging.config import fileConfig
 from logging.config import fileConfig
 
 
 from alembic import context
 from alembic import context
-from open_webui.apps.webui.models.auths import Auth
+from open_webui.models.auths import Auth
 from open_webui.env import DATABASE_URL
 from open_webui.env import DATABASE_URL
 from sqlalchemy import engine_from_config, pool
 from sqlalchemy import engine_from_config, pool
 
 

+ 1 - 1
backend/open_webui/migrations/script.py.mako

@@ -9,7 +9,7 @@ from typing import Sequence, Union
 
 
 from alembic import op
 from alembic import op
 import sqlalchemy as sa
 import sqlalchemy as sa
-import open_webui.apps.webui.internal.db
+import open_webui.internal.db
 ${imports if imports else ""}
 ${imports if imports else ""}
 
 
 # revision identifiers, used by Alembic.
 # revision identifiers, used by Alembic.

+ 48 - 0
backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py

@@ -0,0 +1,48 @@
+"""Add channel table
+
+Revision ID: 57c599a3cb57
+Revises: 922e7a387820
+Create Date: 2024-12-22 03:00:00.000000
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+revision = "57c599a3cb57"
+down_revision = "922e7a387820"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.create_table(
+        "channel",
+        sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
+        sa.Column("user_id", sa.Text()),
+        sa.Column("name", sa.Text()),
+        sa.Column("description", sa.Text(), nullable=True),
+        sa.Column("data", sa.JSON(), nullable=True),
+        sa.Column("meta", sa.JSON(), nullable=True),
+        sa.Column("access_control", sa.JSON(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+    )
+
+    op.create_table(
+        "message",
+        sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
+        sa.Column("user_id", sa.Text()),
+        sa.Column("channel_id", sa.Text(), nullable=True),
+        sa.Column("content", sa.Text()),
+        sa.Column("data", sa.JSON(), nullable=True),
+        sa.Column("meta", sa.JSON(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+    )
+
+
+def downgrade():
+    op.drop_table("channel")
+
+    op.drop_table("message")

+ 26 - 0
backend/open_webui/migrations/versions/7826ab40b532_update_file_table.py

@@ -0,0 +1,26 @@
+"""Update file table
+
+Revision ID: 7826ab40b532
+Revises: 57c599a3cb57
+Create Date: 2024-12-23 03:00:00.000000
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+revision = "7826ab40b532"
+down_revision = "57c599a3cb57"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.add_column(
+        "file",
+        sa.Column("access_control", sa.JSON(), nullable=True),
+    )
+
+
+def downgrade():
+    op.drop_column("file", "access_control")

+ 2 - 2
backend/open_webui/migrations/versions/7e5b5dc7342b_init.py

@@ -11,8 +11,8 @@ from typing import Sequence, Union
 import sqlalchemy as sa
 import sqlalchemy as sa
 from alembic import op
 from alembic import op
 
 
-import open_webui.apps.webui.internal.db
-from open_webui.apps.webui.internal.db import JSONField
+import open_webui.internal.db
+from open_webui.internal.db import JSONField
 from open_webui.migrations.util import get_existing_tables
 from open_webui.migrations.util import get_existing_tables
 
 
 # revision identifiers, used by Alembic.
 # revision identifiers, used by Alembic.

+ 3 - 3
backend/open_webui/apps/webui/models/auths.py → backend/open_webui/models/auths.py

@@ -2,12 +2,12 @@ import logging
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.users import UserModel, Users
+from open_webui.internal.db import Base, get_db
+from open_webui.models.users import UserModel, Users
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel
 from pydantic import BaseModel
 from sqlalchemy import Boolean, Column, String, Text
 from sqlalchemy import Boolean, Column, String, Text
-from open_webui.utils.utils import verify_password
+from open_webui.utils.auth import verify_password
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

+ 132 - 0
backend/open_webui/models/channels.py

@@ -0,0 +1,132 @@
+import json
+import time
+import uuid
+from typing import Optional
+
+from open_webui.internal.db import Base, get_db
+from open_webui.utils.access_control import has_access
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
+from sqlalchemy import or_, func, select, and_, text
+from sqlalchemy.sql import exists
+
+####################
+# Channel DB Schema
+####################
+
+
+class Channel(Base):
+    __tablename__ = "channel"
+
+    id = Column(Text, primary_key=True)
+    user_id = Column(Text)
+
+    name = Column(Text)
+    description = Column(Text, nullable=True)
+
+    data = Column(JSON, nullable=True)
+    meta = Column(JSON, nullable=True)
+    access_control = Column(JSON, nullable=True)
+
+    created_at = Column(BigInteger)
+    updated_at = Column(BigInteger)
+
+
+class ChannelModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
+    id: str
+    user_id: str
+    description: Optional[str] = None
+
+    name: str
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+    access_control: Optional[dict] = None
+
+    created_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class ChannelForm(BaseModel):
+    name: str
+    description: Optional[str] = None
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+    access_control: Optional[dict] = None
+
+
+class ChannelTable:
+    def insert_new_channel(
+        self, form_data: ChannelForm, user_id: str
+    ) -> Optional[ChannelModel]:
+        with get_db() as db:
+            channel = ChannelModel(
+                **{
+                    **form_data.model_dump(),
+                    "name": form_data.name.lower(),
+                    "id": str(uuid.uuid4()),
+                    "user_id": user_id,
+                    "created_at": int(time.time_ns()),
+                    "updated_at": int(time.time_ns()),
+                }
+            )
+
+            new_channel = Channel(**channel.model_dump())
+
+            db.add(new_channel)
+            db.commit()
+            return channel
+
+    def get_channels(self) -> list[ChannelModel]:
+        with get_db() as db:
+            channels = db.query(Channel).all()
+            return [ChannelModel.model_validate(channel) for channel in channels]
+
+    def get_channels_by_user_id(
+        self, user_id: str, permission: str = "read"
+    ) -> list[ChannelModel]:
+        channels = self.get_channels()
+        return [
+            channel
+            for channel in channels
+            if channel.user_id == user_id
+            or has_access(user_id, permission, channel.access_control)
+        ]
+
+    def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
+        with get_db() as db:
+            channel = db.query(Channel).filter(Channel.id == id).first()
+            return ChannelModel.model_validate(channel) if channel else None
+
+    def update_channel_by_id(
+        self, id: str, form_data: ChannelForm
+    ) -> Optional[ChannelModel]:
+        with get_db() as db:
+            channel = db.query(Channel).filter(Channel.id == id).first()
+            if not channel:
+                return None
+
+            channel.name = form_data.name
+            channel.data = form_data.data
+            channel.meta = form_data.meta
+            channel.access_control = form_data.access_control
+            channel.updated_at = int(time.time_ns())
+
+            db.commit()
+            return ChannelModel.model_validate(channel) if channel else None
+
+    def delete_channel_by_id(self, id: str):
+        with get_db() as db:
+            db.query(Channel).filter(Channel.id == id).delete()
+            db.commit()
+            return True
+
+
+Channels = ChannelTable()

+ 87 - 2
backend/open_webui/apps/webui/models/chats.py → backend/open_webui/models/chats.py

@@ -3,8 +3,8 @@ import time
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
+from open_webui.internal.db import Base, get_db
+from open_webui.models.tags import TagModel, Tag, Tags
 
 
 
 
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
@@ -168,6 +168,91 @@ class ChatTable:
         except Exception:
         except Exception:
             return None
             return None
 
 
+    def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        chat = chat.chat
+        chat["title"] = title
+
+        return self.update_chat_by_id(id, chat)
+
+    def update_chat_tags_by_id(
+        self, id: str, tags: list[str], user
+    ) -> Optional[ChatModel]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        self.delete_all_tags_by_id_and_user_id(id, user.id)
+
+        for tag in chat.meta.get("tags", []):
+            if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
+                Tags.delete_tag_by_name_and_user_id(tag, user.id)
+
+        for tag_name in tags:
+            if tag_name.lower() == "none":
+                continue
+
+            self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)
+        return self.get_chat_by_id(id)
+
+    def get_chat_title_by_id(self, id: str) -> Optional[str]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        return chat.chat.get("title", "New Chat")
+
+    def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        return chat.chat.get("history", {}).get("messages", {}) or {}
+
+    def upsert_message_to_chat_by_id_and_message_id(
+        self, id: str, message_id: str, message: dict
+    ) -> Optional[ChatModel]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        chat = chat.chat
+        history = chat.get("history", {})
+
+        if message_id in history.get("messages", {}):
+            history["messages"][message_id] = {
+                **history["messages"][message_id],
+                **message,
+            }
+        else:
+            history["messages"][message_id] = message
+
+        history["currentId"] = message_id
+
+        chat["history"] = history
+        return self.update_chat_by_id(id, chat)
+
+    def add_message_status_to_chat_by_id_and_message_id(
+        self, id: str, message_id: str, status: dict
+    ) -> Optional[ChatModel]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        chat = chat.chat
+        history = chat.get("history", {})
+
+        if message_id in history.get("messages", {}):
+            status_history = history["messages"][message_id].get("statusHistory", [])
+            status_history.append(status)
+            history["messages"][message_id]["statusHistory"] = status_history
+
+        chat["history"] = history
+        return self.update_chat_by_id(id, chat)
+
     def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
     def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         with get_db() as db:
         with get_db() as db:
             # Get the existing chat to share
             # Get the existing chat to share

+ 2 - 2
backend/open_webui/apps/webui/models/feedbacks.py → backend/open_webui/models/feedbacks.py

@@ -3,8 +3,8 @@ import time
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.internal.db import Base, get_db
+from open_webui.models.chats import Chats
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict

+ 6 - 1
backend/open_webui/apps/webui/models/files.py → backend/open_webui/models/files.py

@@ -2,7 +2,7 @@ import logging
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
+from open_webui.internal.db import Base, JSONField, get_db
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text, JSON
 from sqlalchemy import BigInteger, Column, String, Text, JSON
@@ -27,6 +27,8 @@ class File(Base):
     data = Column(JSON, nullable=True)
     data = Column(JSON, nullable=True)
     meta = Column(JSON, nullable=True)
     meta = Column(JSON, nullable=True)
 
 
+    access_control = Column(JSON, nullable=True)
+
     created_at = Column(BigInteger)
     created_at = Column(BigInteger)
     updated_at = Column(BigInteger)
     updated_at = Column(BigInteger)
 
 
@@ -44,6 +46,8 @@ class FileModel(BaseModel):
     data: Optional[dict] = None
     data: Optional[dict] = None
     meta: Optional[dict] = None
     meta: Optional[dict] = None
 
 
+    access_control: Optional[dict] = None
+
     created_at: Optional[int]  # timestamp in epoch
     created_at: Optional[int]  # timestamp in epoch
     updated_at: Optional[int]  # timestamp in epoch
     updated_at: Optional[int]  # timestamp in epoch
 
 
@@ -90,6 +94,7 @@ class FileForm(BaseModel):
     path: str
     path: str
     data: dict = {}
     data: dict = {}
     meta: dict = {}
     meta: dict = {}
+    access_control: Optional[dict] = None
 
 
 
 
 class FilesTable:
 class FilesTable:

+ 2 - 2
backend/open_webui/apps/webui/models/folders.py → backend/open_webui/models/folders.py

@@ -3,8 +3,8 @@ import time
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.internal.db import Base, get_db
+from open_webui.models.chats import Chats
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict

+ 2 - 2
backend/open_webui/apps/webui/models/functions.py → backend/open_webui/models/functions.py

@@ -2,8 +2,8 @@ import logging
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
-from open_webui.apps.webui.models.users import Users
+from open_webui.internal.db import Base, JSONField, get_db
+from open_webui.models.users import Users
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Boolean, Column, String, Text
 from sqlalchemy import BigInteger, Boolean, Column, String, Text

+ 9 - 2
backend/open_webui/apps/webui/models/groups.py → backend/open_webui/models/groups.py

@@ -4,10 +4,10 @@ import time
 from typing import Optional
 from typing import Optional
 import uuid
 import uuid
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
-from open_webui.apps.webui.models.files import FileMetadataResponse
+from open_webui.models.files import FileMetadataResponse
 
 
 
 
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
@@ -146,6 +146,13 @@ class GroupTable:
         except Exception:
         except Exception:
             return None
             return None
 
 
+    def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
+        group = self.get_group_by_id(id)
+        if group:
+            return group.user_ids
+        else:
+            return None
+
     def update_group_by_id(
     def update_group_by_id(
         self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
         self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
     ) -> Optional[GroupModel]:
     ) -> Optional[GroupModel]:

+ 3 - 3
backend/open_webui/apps/webui/models/knowledge.py → backend/open_webui/models/knowledge.py

@@ -4,11 +4,11 @@ import time
 from typing import Optional
 from typing import Optional
 import uuid
 import uuid
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
-from open_webui.apps.webui.models.files import FileMetadataResponse
-from open_webui.apps.webui.models.users import Users, UserResponse
+from open_webui.models.files import FileMetadataResponse
+from open_webui.models.users import Users, UserResponse
 
 
 
 
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict

+ 1 - 1
backend/open_webui/apps/webui/models/memories.py → backend/open_webui/models/memories.py

@@ -2,7 +2,7 @@ import time
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text
 from sqlalchemy import BigInteger, Column, String, Text
 
 

+ 141 - 0
backend/open_webui/models/messages.py

@@ -0,0 +1,141 @@
+import json
+import time
+import uuid
+from typing import Optional
+
+from open_webui.internal.db import Base, get_db
+from open_webui.models.tags import TagModel, Tag, Tags
+
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
+from sqlalchemy import or_, func, select, and_, text
+from sqlalchemy.sql import exists
+
+####################
+# Message DB Schema
+####################
+
+
+class Message(Base):
+    __tablename__ = "message"
+    id = Column(Text, primary_key=True)
+
+    user_id = Column(Text)
+    channel_id = Column(Text, nullable=True)
+
+    content = Column(Text)
+    data = Column(JSON, nullable=True)
+    meta = Column(JSON, nullable=True)
+
+    created_at = Column(BigInteger)  # time_ns
+    updated_at = Column(BigInteger)  # time_ns
+
+
+class MessageModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
+    id: str
+    user_id: str
+    channel_id: Optional[str] = None
+
+    content: str
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+
+    created_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class MessageForm(BaseModel):
+    content: str
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+
+
+class MessageTable:
+    def insert_new_message(
+        self, form_data: MessageForm, channel_id: str, user_id: str
+    ) -> Optional[MessageModel]:
+        with get_db() as db:
+            id = str(uuid.uuid4())
+
+            ts = int(time.time_ns())
+            message = MessageModel(
+                **{
+                    "id": id,
+                    "user_id": user_id,
+                    "channel_id": channel_id,
+                    "content": form_data.content,
+                    "data": form_data.data,
+                    "meta": form_data.meta,
+                    "created_at": ts,
+                    "updated_at": ts,
+                }
+            )
+
+            result = Message(**message.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
+            return MessageModel.model_validate(result) if result else None
+
+    def get_message_by_id(self, id: str) -> Optional[MessageModel]:
+        with get_db() as db:
+            message = db.get(Message, id)
+            return MessageModel.model_validate(message) if message else None
+
+    def get_messages_by_channel_id(
+        self, channel_id: str, skip: int = 0, limit: int = 50
+    ) -> list[MessageModel]:
+        with get_db() as db:
+            all_messages = (
+                db.query(Message)
+                .filter_by(channel_id=channel_id)
+                .order_by(Message.created_at.desc())
+                .offset(skip)
+                .limit(limit)
+                .all()
+            )
+            return [MessageModel.model_validate(message) for message in all_messages]
+
+    def get_messages_by_user_id(
+        self, user_id: str, skip: int = 0, limit: int = 50
+    ) -> list[MessageModel]:
+        with get_db() as db:
+            all_messages = (
+                db.query(Message)
+                .filter_by(user_id=user_id)
+                .order_by(Message.created_at.desc())
+                .offset(skip)
+                .limit(limit)
+                .all()
+            )
+            return [MessageModel.model_validate(message) for message in all_messages]
+
+    def update_message_by_id(
+        self, id: str, form_data: MessageForm
+    ) -> Optional[MessageModel]:
+        with get_db() as db:
+            message = db.get(Message, id)
+            message.content = form_data.content
+            message.data = form_data.data
+            message.meta = form_data.meta
+            message.updated_at = int(time.time_ns())
+            db.commit()
+            db.refresh(message)
+            return MessageModel.model_validate(message) if message else None
+
+    def delete_message_by_id(self, id: str) -> bool:
+        with get_db() as db:
+            db.query(Message).filter_by(id=id).delete()
+            db.commit()
+            return True
+
+
+Messages = MessageTable()

+ 2 - 2
backend/open_webui/apps/webui/models/models.py → backend/open_webui/models/models.py

@@ -2,10 +2,10 @@ import logging
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
+from open_webui.internal.db import Base, JSONField, get_db
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
-from open_webui.apps.webui.models.users import Users, UserResponse
+from open_webui.models.users import Users, UserResponse
 
 
 
 
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict

+ 2 - 2
backend/open_webui/apps/webui/models/prompts.py → backend/open_webui/models/prompts.py

@@ -1,8 +1,8 @@
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.users import Users, UserResponse
+from open_webui.internal.db import Base, get_db
+from open_webui.models.users import Users, UserResponse
 
 
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text, JSON
 from sqlalchemy import BigInteger, Column, String, Text, JSON

+ 1 - 1
backend/open_webui/apps/webui/models/tags.py → backend/open_webui/models/tags.py

@@ -3,7 +3,7 @@ import time
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 
 
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS

+ 2 - 2
backend/open_webui/apps/webui/models/tools.py → backend/open_webui/models/tools.py

@@ -2,8 +2,8 @@ import logging
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
-from open_webui.apps.webui.models.users import Users, UserResponse
+from open_webui.internal.db import Base, JSONField, get_db
+from open_webui.models.users import Users, UserResponse
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text, JSON
 from sqlalchemy import BigInteger, Column, String, Text, JSON

+ 43 - 8
backend/open_webui/apps/webui/models/users.py → backend/open_webui/models/users.py

@@ -1,8 +1,8 @@
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.internal.db import Base, JSONField, get_db
+from open_webui.models.chats import Chats
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text
 from sqlalchemy import BigInteger, Column, String, Text
 
 
@@ -70,6 +70,13 @@ class UserResponse(BaseModel):
     profile_image_url: str
     profile_image_url: str
 
 
 
 
+class UserNameResponse(BaseModel):
+    id: str
+    name: str
+    role: str
+    profile_image_url: str
+
+
 class UserRoleUpdateForm(BaseModel):
 class UserRoleUpdateForm(BaseModel):
     id: str
     id: str
     role: str
     role: str
@@ -147,13 +154,25 @@ class UsersTable:
         except Exception:
         except Exception:
             return None
             return None
 
 
-    def get_users(self, skip: int = 0, limit: int = 50) -> list[UserModel]:
+    def get_users(
+        self, skip: Optional[int] = None, limit: Optional[int] = None
+    ) -> list[UserModel]:
         with get_db() as db:
         with get_db() as db:
-            users = (
-                db.query(User)
-                # .offset(skip).limit(limit)
-                .all()
-            )
+
+            query = db.query(User).order_by(User.created_at.desc())
+
+            if skip:
+                query = query.offset(skip)
+            if limit:
+                query = query.limit(limit)
+
+            users = query.all()
+
+            return [UserModel.model_validate(user) for user in users]
+
+    def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
+        with get_db() as db:
+            users = db.query(User).filter(User.id.in_(user_ids)).all()
             return [UserModel.model_validate(user) for user in users]
             return [UserModel.model_validate(user) for user in users]
 
 
     def get_num_users(self) -> Optional[int]:
     def get_num_users(self) -> Optional[int]:
@@ -168,6 +187,22 @@ class UsersTable:
         except Exception:
         except Exception:
             return None
             return None
 
 
+    def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
+        try:
+            with get_db() as db:
+                user = db.query(User).filter_by(id=id).first()
+
+                if user.settings is None:
+                    return None
+                else:
+                    return (
+                        user.settings.get("ui", {})
+                        .get("notifications", {})
+                        .get("webhook_url", None)
+                    )
+        except Exception:
+            return None
+
     def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
     def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
         try:
         try:
             with get_db() as db:
             with get_db() as db:

+ 4 - 2
backend/open_webui/apps/retrieval/loaders/main.py → backend/open_webui/retrieval/loaders/main.py

@@ -1,6 +1,7 @@
 import requests
 import requests
 import logging
 import logging
 import ftfy
 import ftfy
+import sys
 
 
 from langchain_community.document_loaders import (
 from langchain_community.document_loaders import (
     BSHTMLLoader,
     BSHTMLLoader,
@@ -18,8 +19,9 @@ from langchain_community.document_loaders import (
     YoutubeLoader,
     YoutubeLoader,
 )
 )
 from langchain_core.documents import Document
 from langchain_core.documents import Document
-from open_webui.env import SRC_LOG_LEVELS
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
 
 
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
@@ -106,7 +108,7 @@ class TikaLoader:
             if "Content-Type" in raw_metadata:
             if "Content-Type" in raw_metadata:
                 headers["Content-Type"] = raw_metadata["Content-Type"]
                 headers["Content-Type"] = raw_metadata["Content-Type"]
 
 
-            log.info("Tika extracted text: %s", text)
+            log.debug("Tika extracted text: %s", text)
 
 
             return [Document(page_content=text, metadata=headers)]
             return [Document(page_content=text, metadata=headers)]
         else:
         else:

+ 0 - 0
backend/open_webui/apps/retrieval/loaders/youtube.py → backend/open_webui/retrieval/loaders/youtube.py


+ 0 - 0
backend/open_webui/apps/retrieval/models/colbert.py → backend/open_webui/retrieval/models/colbert.py


+ 4 - 2
backend/open_webui/apps/retrieval/utils.py → backend/open_webui/retrieval/utils.py

@@ -11,7 +11,7 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
 from langchain_community.retrievers import BM25Retriever
 from langchain_community.retrievers import BM25Retriever
 from langchain_core.documents import Document
 from langchain_core.documents import Document
 
 
-from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
 from open_webui.utils.misc import get_last_user_message
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
@@ -70,7 +70,9 @@ def query_doc(
             limit=k,
             limit=k,
         )
         )
 
 
-        log.info(f"query_doc:result {result.ids} {result.metadatas}")
+        if result:
+            log.info(f"query_doc:result {result.ids} {result.metadatas}")
+
         return result
         return result
     except Exception as e:
     except Exception as e:
         print(e)
         print(e)

+ 22 - 0
backend/open_webui/retrieval/vector/connector.py

@@ -0,0 +1,22 @@
+from open_webui.config import VECTOR_DB
+
+if VECTOR_DB == "milvus":
+    from open_webui.retrieval.vector.dbs.milvus import MilvusClient
+
+    VECTOR_DB_CLIENT = MilvusClient()
+elif VECTOR_DB == "qdrant":
+    from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
+
+    VECTOR_DB_CLIENT = QdrantClient()
+elif VECTOR_DB == "opensearch":
+    from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
+
+    VECTOR_DB_CLIENT = OpenSearchClient()
+elif VECTOR_DB == "pgvector":
+    from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
+
+    VECTOR_DB_CLIENT = PgvectorClient()
+else:
+    from open_webui.retrieval.vector.dbs.chroma import ChromaClient
+
+    VECTOR_DB_CLIENT = ChromaClient()

+ 1 - 1
backend/open_webui/apps/retrieval/vector/dbs/chroma.py → backend/open_webui/retrieval/vector/dbs/chroma.py

@@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches
 
 
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,
     CHROMA_HTTP_HOST,

+ 1 - 1
backend/open_webui/apps/retrieval/vector/dbs/milvus.py → backend/open_webui/retrieval/vector/dbs/milvus.py

@@ -4,7 +4,7 @@ import json
 
 
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
 from open_webui.config import (
     MILVUS_URI,
     MILVUS_URI,
 )
 )

+ 1 - 1
backend/open_webui/apps/retrieval/vector/dbs/opensearch.py → backend/open_webui/retrieval/vector/dbs/opensearch.py

@@ -1,7 +1,7 @@
 from opensearchpy import OpenSearch
 from opensearchpy import OpenSearch
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
 from open_webui.config import (
     OPENSEARCH_URI,
     OPENSEARCH_URI,
     OPENSEARCH_SSL,
     OPENSEARCH_SSL,

+ 2 - 2
backend/open_webui/apps/retrieval/vector/dbs/pgvector.py → backend/open_webui/retrieval/vector/dbs/pgvector.py

@@ -18,7 +18,7 @@ from sqlalchemy.dialects.postgresql import JSONB, array
 from pgvector.sqlalchemy import Vector
 from pgvector.sqlalchemy import Vector
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.ext.mutable import MutableDict
 
 
-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import PGVECTOR_DB_URL
 from open_webui.config import PGVECTOR_DB_URL
 
 
 VECTOR_LENGTH = 1536
 VECTOR_LENGTH = 1536
@@ -40,7 +40,7 @@ class PgvectorClient:
 
 
         # if no pgvector uri, use the existing database connection
         # if no pgvector uri, use the existing database connection
         if not PGVECTOR_DB_URL:
         if not PGVECTOR_DB_URL:
-            from open_webui.apps.webui.internal.db import Session
+            from open_webui.internal.db import Session
 
 
             self.session = Session
             self.session = Session
         else:
         else:

+ 1 - 1
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py → backend/open_webui/retrieval/vector/dbs/qdrant.py

@@ -4,7 +4,7 @@ from qdrant_client import QdrantClient as Qclient
 from qdrant_client.http.models import PointStruct
 from qdrant_client.http.models import PointStruct
 from qdrant_client.models import models
 from qdrant_client.models import models
 
 
-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import QDRANT_URI, QDRANT_API_KEY
 from open_webui.config import QDRANT_URI, QDRANT_API_KEY
 
 
 NO_LIMIT = 999999999
 NO_LIMIT = 999999999

+ 0 - 0
backend/open_webui/apps/retrieval/vector/main.py → backend/open_webui/retrieval/vector/main.py


+ 1 - 1
backend/open_webui/apps/retrieval/web/bing.py → backend/open_webui/retrieval/web/bing.py

@@ -3,7 +3,7 @@ import os
 from pprint import pprint
 from pprint import pprint
 from typing import Optional
 from typing import Optional
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 import argparse
 import argparse
 
 

+ 1 - 1
backend/open_webui/apps/retrieval/web/brave.py → backend/open_webui/retrieval/web/brave.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/duckduckgo.py → backend/open_webui/retrieval/web/duckduckgo.py

@@ -1,7 +1,7 @@
 import logging
 import logging
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from duckduckgo_search import DDGS
 from duckduckgo_search import DDGS
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 

+ 1 - 1
backend/open_webui/apps/retrieval/web/google_pse.py → backend/open_webui/retrieval/web/google_pse.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/jina_search.py → backend/open_webui/retrieval/web/jina_search.py

@@ -1,7 +1,7 @@
 import logging
 import logging
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult
+from open_webui.retrieval.web.main import SearchResult
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from yarl import URL
 from yarl import URL
 
 

+ 48 - 0
backend/open_webui/retrieval/web/kagi.py

@@ -0,0 +1,48 @@
+import logging
+from typing import Optional
+
+import requests
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def search_kagi(
+    api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
+) -> list[SearchResult]:
+    """Search using Kagi's Search API and return the results as a list of SearchResult objects.
+
+    The Search API will inherit the settings in your account, including results personalization and snippet length.
+
+    Args:
+        api_key (str): A Kagi Search API key
+        query (str): The query to search for
+        count (int): The number of results to return
+    """
+    url = "https://kagi.com/api/v0/search"
+    headers = {
+        "Authorization": f"Bot {api_key}",
+    }
+    params = {"q": query, "limit": count}
+
+    response = requests.get(url, headers=headers, params=params)
+    response.raise_for_status()
+    json_response = response.json()
+    search_results = json_response.get("data", [])
+
+    results = [
+        SearchResult(
+            link=result["url"], title=result["title"], snippet=result.get("snippet")
+        )
+        for result in search_results
+        if result["t"] == 0
+    ]
+
+    print(results)
+
+    if filter_list:
+        results = get_filtered_results(results, filter_list)
+
+    return results

+ 0 - 0
backend/open_webui/apps/retrieval/web/main.py → backend/open_webui/retrieval/web/main.py


+ 1 - 1
backend/open_webui/apps/retrieval/web/mojeek.py → backend/open_webui/retrieval/web/mojeek.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/searchapi.py → backend/open_webui/retrieval/web/searchapi.py

@@ -3,7 +3,7 @@ from typing import Optional
 from urllib.parse import urlencode
 from urllib.parse import urlencode
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/searxng.py → backend/open_webui/retrieval/web/searxng.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/serper.py → backend/open_webui/retrieval/web/serper.py

@@ -3,7 +3,7 @@ import logging
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/serply.py → backend/open_webui/retrieval/web/serply.py

@@ -3,7 +3,7 @@ from typing import Optional
 from urllib.parse import urlencode
 from urllib.parse import urlencode
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/serpstack.py → backend/open_webui/retrieval/web/serpstack.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/tavily.py → backend/open_webui/retrieval/web/tavily.py

@@ -1,7 +1,7 @@
 import logging
 import logging
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult
+from open_webui.retrieval.web.main import SearchResult
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/bing.json → backend/open_webui/retrieval/web/testdata/bing.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/brave.json → backend/open_webui/retrieval/web/testdata/brave.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/google_pse.json → backend/open_webui/retrieval/web/testdata/google_pse.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/searchapi.json → backend/open_webui/retrieval/web/testdata/searchapi.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/searxng.json → backend/open_webui/retrieval/web/testdata/searxng.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/serper.json → backend/open_webui/retrieval/web/testdata/serper.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/serply.json → backend/open_webui/retrieval/web/testdata/serply.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/serpstack.json → backend/open_webui/retrieval/web/testdata/serpstack.json


+ 3 - 3
backend/open_webui/apps/retrieval/web/utils.py → backend/open_webui/retrieval/web/utils.py

@@ -82,15 +82,15 @@ class SafeWebBaseLoader(WebBaseLoader):
 
 
 
 
 def get_web_loader(
 def get_web_loader(
-    url: Union[str, Sequence[str]],
+    urls: Union[str, Sequence[str]],
     verify_ssl: bool = True,
     verify_ssl: bool = True,
     requests_per_second: int = 2,
     requests_per_second: int = 2,
 ):
 ):
     # Check if the URL is valid
     # Check if the URL is valid
-    if not validate_url(url):
+    if not validate_url(urls):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
     return SafeWebBaseLoader(
     return SafeWebBaseLoader(
-        url,
+        urls,
         verify_ssl=verify_ssl,
         verify_ssl=verify_ssl,
         requests_per_second=requests_per_second,
         requests_per_second=requests_per_second,
         continue_on_failure=True,
         continue_on_failure=True,

+ 713 - 0
backend/open_webui/routers/audio.py

@@ -0,0 +1,713 @@
+import hashlib
+import json
+import logging
+import os
+import uuid
+from functools import lru_cache
+from pathlib import Path
+from pydub import AudioSegment
+from pydub.silence import split_on_silence
+
+import aiohttp
+import aiofiles
+import requests
+
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    HTTPException,
+    Request,
+    UploadFile,
+    status,
+    APIRouter,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from pydantic import BaseModel
+
+
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.config import (
+    WHISPER_MODEL_AUTO_UPDATE,
+    WHISPER_MODEL_DIR,
+    CACHE_DIR,
+)
+
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import (
+    ENV,
+    SRC_LOG_LEVELS,
+    DEVICE_TYPE,
+    ENABLE_FORWARD_USER_INFO_HEADERS,
+)
+
+
+router = APIRouter()
+
+# Constants
+MAX_FILE_SIZE_MB = 25
+MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["AUDIO"])
+
+SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
+SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+
+##########################################
+#
+# Utility functions
+#
+##########################################
+
+from pydub import AudioSegment
+from pydub.utils import mediainfo
+
+
+def is_mp4_audio(file_path):
+    """Check if the given file is an MP4 audio file."""
+    if not os.path.isfile(file_path):
+        print(f"File not found: {file_path}")
+        return False
+
+    info = mediainfo(file_path)
+    if (
+        info.get("codec_name") == "aac"
+        and info.get("codec_type") == "audio"
+        and info.get("codec_tag_string") == "mp4a"
+    ):
+        return True
+    return False
+
+
+def convert_mp4_to_wav(file_path, output_path):
+    """Convert MP4 audio file to WAV format."""
+    audio = AudioSegment.from_file(file_path, format="mp4")
+    audio.export(output_path, format="wav")
+    print(f"Converted {file_path} to {output_path}")
+
+
+def set_faster_whisper_model(model: str, auto_update: bool = False):
+    whisper_model = None
+    if model:
+        from faster_whisper import WhisperModel
+
+        faster_whisper_kwargs = {
+            "model_size_or_path": model,
+            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
+            "compute_type": "int8",
+            "download_root": WHISPER_MODEL_DIR,
+            "local_files_only": not auto_update,
+        }
+
+        try:
+            whisper_model = WhisperModel(**faster_whisper_kwargs)
+        except Exception:
+            log.warning(
+                "WhisperModel initialization failed, attempting download with local_files_only=False"
+            )
+            faster_whisper_kwargs["local_files_only"] = False
+            whisper_model = WhisperModel(**faster_whisper_kwargs)
+    return whisper_model
+
+
+##########################################
+#
+# Audio API
+#
+##########################################
+
+
+class TTSConfigForm(BaseModel):
+    OPENAI_API_BASE_URL: str
+    OPENAI_API_KEY: str
+    API_KEY: str
+    ENGINE: str
+    MODEL: str
+    VOICE: str
+    SPLIT_ON: str
+    AZURE_SPEECH_REGION: str
+    AZURE_SPEECH_OUTPUT_FORMAT: str
+
+
+class STTConfigForm(BaseModel):
+    OPENAI_API_BASE_URL: str
+    OPENAI_API_KEY: str
+    ENGINE: str
+    MODEL: str
+    WHISPER_MODEL: str
+
+
+class AudioConfigUpdateForm(BaseModel):
+    tts: TTSConfigForm
+    stt: STTConfigForm
+
+
+@router.get("/config")
+async def get_audio_config(request: Request, user=Depends(get_admin_user)):
+    return {
+        "tts": {
+            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "API_KEY": request.app.state.config.TTS_API_KEY,
+            "ENGINE": request.app.state.config.TTS_ENGINE,
+            "MODEL": request.app.state.config.TTS_MODEL,
+            "VOICE": request.app.state.config.TTS_VOICE,
+            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
+            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
+            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+        },
+        "stt": {
+            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": request.app.state.config.STT_ENGINE,
+            "MODEL": request.app.state.config.STT_MODEL,
+            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
+        },
+    }
+
+
+@router.post("/config/update")
+async def update_audio_config(
+    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
+):
+    request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
+    request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
+    request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
+    request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
+    request.app.state.config.TTS_MODEL = form_data.tts.MODEL
+    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
+    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
+    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
+    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
+        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
+    )
+
+    request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
+    request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
+    request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
+    request.app.state.config.STT_MODEL = form_data.stt.MODEL
+    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
+
+    if request.app.state.config.STT_ENGINE == "":
+        request.app.state.faster_whisper_model = set_faster_whisper_model(
+            form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
+        )
+
+    return {
+        "tts": {
+            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "API_KEY": request.app.state.config.TTS_API_KEY,
+            "ENGINE": request.app.state.config.TTS_ENGINE,
+            "MODEL": request.app.state.config.TTS_MODEL,
+            "VOICE": request.app.state.config.TTS_VOICE,
+            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
+            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
+            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+        },
+        "stt": {
+            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": request.app.state.config.STT_ENGINE,
+            "MODEL": request.app.state.config.STT_MODEL,
+            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
+        },
+    }
+
+
+def load_speech_pipeline(request):
+    from transformers import pipeline
+    from datasets import load_dataset
+
+    if request.app.state.speech_synthesiser is None:
+        request.app.state.speech_synthesiser = pipeline(
+            "text-to-speech", "microsoft/speecht5_tts"
+        )
+
+    if request.app.state.speech_speaker_embeddings_dataset is None:
+        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
+            "Matthijs/cmu-arctic-xvectors", split="validation"
+        )
+
+
+@router.post("/speech")
+async def speech(request: Request, user=Depends(get_verified_user)):
+    body = await request.body()
+    name = hashlib.sha256(
+        body
+        + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
+        + str(request.app.state.config.TTS_MODEL).encode("utf-8")
+    ).hexdigest()
+
+    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
+    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
+
+    # Check if the file already exists in the cache
+    if file_path.is_file():
+        return FileResponse(file_path)
+
+    payload = None
+    try:
+        payload = json.loads(body.decode("utf-8"))
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(status_code=400, detail="Invalid JSON payload")
+
+    if request.app.state.config.TTS_ENGINE == "openai":
+        payload["model"] = request.app.state.config.TTS_MODEL
+
+        try:
+            # print(payload)
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
+                    json=payload,
+                    headers={
+                        "Content-Type": "application/json",
+                        "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
+                        **(
+                            {
+                                "X-OpenWebUI-User-Name": user.name,
+                                "X-OpenWebUI-User-Id": user.id,
+                                "X-OpenWebUI-User-Email": user.email,
+                                "X-OpenWebUI-User-Role": user.role,
+                            }
+                            if ENABLE_FORWARD_USER_INFO_HEADERS
+                            else {}
+                        ),
+                    },
+                ) as r:
+                    r.raise_for_status()
+
+                    async with aiofiles.open(file_path, "wb") as f:
+                        await f.write(await r.read())
+
+                    async with aiofiles.open(file_body_path, "w") as f:
+                        await f.write(json.dumps(payload))
+
+            return FileResponse(file_path)
+
+        except Exception as e:
+            log.exception(e)
+            detail = None
+
+            try:
+                if r.status != 200:
+                    res = await r.json()
+
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+            except Exception:
+                detail = f"External: {e}"
+
+            raise HTTPException(
+                status_code=getattr(r, "status", 500),
+                detail=detail if detail else "Open WebUI: Server Connection Error",
+            )
+
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
+        voice_id = payload.get("voice", "")
+
+        if voice_id not in get_available_voices():
+            raise HTTPException(
+                status_code=400,
+                detail="Invalid voice id",
+            )
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
+                    json={
+                        "text": payload["input"],
+                        "model_id": request.app.state.config.TTS_MODEL,
+                        "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
+                    },
+                    headers={
+                        "Accept": "audio/mpeg",
+                        "Content-Type": "application/json",
+                        "xi-api-key": request.app.state.config.TTS_API_KEY,
+                    },
+                ) as r:
+                    r.raise_for_status()
+
+                    async with aiofiles.open(file_path, "wb") as f:
+                        await f.write(await r.read())
+
+                    async with aiofiles.open(file_body_path, "w") as f:
+                        await f.write(json.dumps(payload))
+
+            return FileResponse(file_path)
+
+        except Exception as e:
+            log.exception(e)
+            detail = None
+
+            try:
+                if r.status != 200:
+                    res = await r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+            except Exception:
+                detail = f"External: {e}"
+
+            raise HTTPException(
+                status_code=getattr(r, "status", 500),
+                detail=detail if detail else "Open WebUI: Server Connection Error",
+            )
+
+    elif request.app.state.config.TTS_ENGINE == "azure":
+        try:
+            payload = json.loads(body.decode("utf-8"))
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(status_code=400, detail="Invalid JSON payload")
+
+        region = request.app.state.config.TTS_AZURE_SPEECH_REGION
+        language = request.app.state.config.TTS_VOICE
+        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
+        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
+
+        try:
+            data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
+                <voice name="{language}">{payload["input"]}</voice>
+            </speak>"""
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
+                    headers={
+                        "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
+                        "Content-Type": "application/ssml+xml",
+                        "X-Microsoft-OutputFormat": output_format,
+                    },
+                    data=data,
+                ) as r:
+                    r.raise_for_status()
+
+                    async with aiofiles.open(file_path, "wb") as f:
+                        await f.write(await r.read())
+
+                    async with aiofiles.open(file_body_path, "w") as f:
+                        await f.write(json.dumps(payload))
+
+                    return FileResponse(file_path)
+
+        except Exception as e:
+            log.exception(e)
+            detail = None
+
+            try:
+                if r.status != 200:
+                    res = await r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+            except Exception:
+                detail = f"External: {e}"
+
+            raise HTTPException(
+                status_code=getattr(r, "status", 500),
+                detail=detail if detail else "Open WebUI: Server Connection Error",
+            )
+
+    elif request.app.state.config.TTS_ENGINE == "transformers":
+        payload = None
+        try:
+            payload = json.loads(body.decode("utf-8"))
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(status_code=400, detail="Invalid JSON payload")
+
+        import torch
+        import soundfile as sf
+
+        load_speech_pipeline(request)
+
+        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
+
+        speaker_index = 6799
+        try:
+            speaker_index = embeddings_dataset["filename"].index(
+                request.app.state.config.TTS_MODEL
+            )
+        except Exception:
+            pass
+
+        speaker_embedding = torch.tensor(
+            embeddings_dataset[speaker_index]["xvector"]
+        ).unsqueeze(0)
+
+        speech = request.app.state.speech_synthesiser(
+            payload["input"],
+            forward_params={"speaker_embeddings": speaker_embedding},
+        )
+
+        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
+
+        async with aiofiles.open(file_body_path, "w") as f:
+            await f.write(json.dumps(payload))
+
+        return FileResponse(file_path)
+
+
+def transcribe(request: Request, file_path):
+    print("transcribe", file_path)
+    filename = os.path.basename(file_path)
+    file_dir = os.path.dirname(file_path)
+    id = filename.split(".")[0]
+
+    if request.app.state.config.STT_ENGINE == "":
+        if request.app.state.faster_whisper_model is None:
+            request.app.state.faster_whisper_model = set_faster_whisper_model(
+                request.app.state.config.WHISPER_MODEL
+            )
+
+        model = request.app.state.faster_whisper_model
+        segments, info = model.transcribe(file_path, beam_size=5)
+        log.info(
+            "Detected language '%s' with probability %f"
+            % (info.language, info.language_probability)
+        )
+
+        transcript = "".join([segment.text for segment in list(segments)])
+        data = {"text": transcript.strip()}
+
+        # save the transcript to a json file
+        transcript_file = f"{file_dir}/{id}.json"
+        with open(transcript_file, "w") as f:
+            json.dump(data, f)
+
+        log.debug(data)
+        return data
+    elif request.app.state.config.STT_ENGINE == "openai":
+        if is_mp4_audio(file_path):
+            os.rename(file_path, file_path.replace(".wav", ".mp4"))
+            # Convert MP4 audio file to WAV format
+            convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
+
+        r = None
+        try:
+            r = requests.post(
+                url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
+                headers={
+                    "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
+                },
+                files={"file": (filename, open(file_path, "rb"))},
+                data={"model": request.app.state.config.STT_MODEL},
+            )
+
+            r.raise_for_status()
+            data = r.json()
+
+            # save the transcript to a json file
+            transcript_file = f"{file_dir}/{id}.json"
+            with open(transcript_file, "w") as f:
+                json.dump(data, f)
+
+            return data
+        except Exception as e:
+            log.exception(e)
+
+            detail = None
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+                except Exception:
+                    detail = f"External: {e}"
+
+            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
+
+
+def compress_audio(file_path):
+    if os.path.getsize(file_path) > MAX_FILE_SIZE:
+        file_dir = os.path.dirname(file_path)
+        audio = AudioSegment.from_file(file_path)
+        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
+        compressed_path = f"{file_dir}/{id}_compressed.opus"
+        audio.export(compressed_path, format="opus", bitrate="32k")
+        log.debug(f"Compressed audio to {compressed_path}")
+
+        if (
+            os.path.getsize(compressed_path) > MAX_FILE_SIZE
+        ):  # Still larger than MAX_FILE_SIZE after compression
+            raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
+        return compressed_path
+    else:
+        return file_path
+
+
+@router.post("/transcriptions")
+def transcription(
+    request: Request,
+    file: UploadFile = File(...),
+    user=Depends(get_verified_user),
+):
+    log.info(f"file.content_type: {file.content_type}")
+
+    if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
+        )
+
+    try:
+        ext = file.filename.split(".")[-1]
+        id = uuid.uuid4()
+
+        filename = f"{id}.{ext}"
+        contents = file.file.read()
+
+        file_dir = f"{CACHE_DIR}/audio/transcriptions"
+        os.makedirs(file_dir, exist_ok=True)
+        file_path = f"{file_dir}/{filename}"
+
+        with open(file_path, "wb") as f:
+            f.write(contents)
+
+        try:
+            try:
+                file_path = compress_audio(file_path)
+            except Exception as e:
+                log.exception(e)
+
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(e),
+                )
+
+            data = transcribe(request, file_path)
+            file_path = file_path.split("/")[-1]
+            return {**data, "filename": file_path}
+        except Exception as e:
+            log.exception(e)
+
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+
+    except Exception as e:
+        log.exception(e)
+
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+def get_available_models(request: Request) -> list[dict]:
+    available_models = []
+    if request.app.state.config.TTS_ENGINE == "openai":
+        available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
+        try:
+            response = requests.get(
+                "https://api.elevenlabs.io/v1/models",
+                headers={
+                    "xi-api-key": request.app.state.config.TTS_API_KEY,
+                    "Content-Type": "application/json",
+                },
+                timeout=5,
+            )
+            response.raise_for_status()
+            models = response.json()
+
+            available_models = [
+                {"name": model["name"], "id": model["model_id"]} for model in models
+            ]
+        except requests.RequestException as e:
+            log.error(f"Error fetching voices: {str(e)}")
+    return available_models
+
+
+@router.get("/models")
+async def get_models(request: Request, user=Depends(get_verified_user)):
+    return {"models": get_available_models(request)}
+
+
+def get_available_voices(request) -> dict:
+    """Returns {voice_id: voice_name} dict"""
+    available_voices = {}
+    if request.app.state.config.TTS_ENGINE == "openai":
+        available_voices = {
+            "alloy": "alloy",
+            "echo": "echo",
+            "fable": "fable",
+            "onyx": "onyx",
+            "nova": "nova",
+            "shimmer": "shimmer",
+        }
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
+        try:
+            available_voices = get_elevenlabs_voices(
+                api_key=request.app.state.config.TTS_API_KEY
+            )
+        except Exception:
+            # Avoided @lru_cache with exception
+            pass
+    elif request.app.state.config.TTS_ENGINE == "azure":
+        try:
+            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
+            url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
+            headers = {
+                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
+            }
+
+            response = requests.get(url, headers=headers)
+            response.raise_for_status()
+            voices = response.json()
+
+            for voice in voices:
+                available_voices[voice["ShortName"]] = (
+                    f"{voice['DisplayName']} ({voice['ShortName']})"
+                )
+        except requests.RequestException as e:
+            log.error(f"Error fetching voices: {str(e)}")
+
+    return available_voices
+
+
+@lru_cache
+def get_elevenlabs_voices(api_key: str) -> dict:
+    """
+    Note, set the following in your .env file to use Elevenlabs:
+    AUDIO_TTS_ENGINE=elevenlabs
+    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
+    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
+    AUDIO_TTS_MODEL=eleven_multilingual_v2
+    """
+
+    try:
+        # TODO: Add retries
+        response = requests.get(
+            "https://api.elevenlabs.io/v1/voices",
+            headers={
+                "xi-api-key": api_key,
+                "Content-Type": "application/json",
+            },
+        )
+        response.raise_for_status()
+        voices_data = response.json()
+
+        voices = {}
+        for voice in voices_data.get("voices", []):
+            voices[voice["voice_id"]] = voice["name"]
+    except requests.RequestException as e:
+        # Avoid @lru_cache with exception
+        log.error(f"Error fetching voices: {str(e)}")
+        raise RuntimeError(f"Error fetching voices: {str(e)}")
+
+    return voices
+
+
+@router.get("/voices")
+async def get_voices(request: Request, user=Depends(get_verified_user)):
+    return {
+        "voices": [
+            {"id": k, "name": v} for k, v in get_available_voices(request).items()
+        ]
+    }

+ 41 - 6
backend/open_webui/apps/webui/routers/auths.py → backend/open_webui/routers/auths.py

@@ -3,8 +3,9 @@ import uuid
 import time
 import time
 import datetime
 import datetime
 import logging
 import logging
+from aiohttp import ClientSession
 
 
-from open_webui.apps.webui.models.auths import (
+from open_webui.models.auths import (
     AddUserForm,
     AddUserForm,
     ApiKey,
     ApiKey,
     Auths,
     Auths,
@@ -17,7 +18,7 @@ from open_webui.apps.webui.models.auths import (
     UpdateProfileForm,
     UpdateProfileForm,
     UserResponse,
     UserResponse,
 )
 )
-from open_webui.apps.webui.models.users import Users
+from open_webui.models.users import Users
 
 
 from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from open_webui.env import (
 from open_webui.env import (
@@ -29,10 +30,14 @@ from open_webui.env import (
     SRC_LOG_LEVELS,
     SRC_LOG_LEVELS,
 )
 )
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from fastapi.responses import Response
+from fastapi.responses import RedirectResponse, Response
+from open_webui.config import (
+    OPENID_PROVIDER_URL,
+    ENABLE_OAUTH_SIGNUP,
+)
 from pydantic import BaseModel
 from pydantic import BaseModel
 from open_webui.utils.misc import parse_duration, validate_email_format
 from open_webui.utils.misc import parse_duration, validate_email_format
-from open_webui.utils.utils import (
+from open_webui.utils.auth import (
     create_api_key,
     create_api_key,
     create_token,
     create_token,
     get_admin_user,
     get_admin_user,
@@ -498,8 +503,31 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
 
 
 
 
 @router.get("/signout")
 @router.get("/signout")
-async def signout(response: Response):
+async def signout(request: Request, response: Response):
     response.delete_cookie("token")
     response.delete_cookie("token")
+
+    if ENABLE_OAUTH_SIGNUP.value:
+        oauth_id_token = request.cookies.get("oauth_id_token")
+        if oauth_id_token:
+            try:
+                async with ClientSession() as session:
+                    async with session.get(OPENID_PROVIDER_URL.value) as resp:
+                        if resp.status == 200:
+                            openid_data = await resp.json()
+                            logout_url = openid_data.get("end_session_endpoint")
+                            if logout_url:
+                                response.delete_cookie("oauth_id_token")
+                                return RedirectResponse(
+                                    url=f"{logout_url}?id_token_hint={oauth_id_token}"
+                                )
+                        else:
+                            raise HTTPException(
+                                status_code=resp.status,
+                                detail="Failed to fetch OpenID configuration",
+                            )
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
     return {"status": True}
     return {"status": True}
 
 
 
 
@@ -519,7 +547,6 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
 
     try:
     try:
-        print(form_data)
         hashed = get_password_hash(form_data.password)
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
         user = Auths.insert_new_auth(
             form_data.email.lower(),
             form_data.email.lower(),
@@ -586,8 +613,10 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
 async def get_admin_config(request: Request, user=Depends(get_admin_user)):
 async def get_admin_config(request: Request, user=Depends(get_admin_user)):
     return {
     return {
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
+        "WEBUI_URL": request.app.state.config.WEBUI_URL,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
         "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
         "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
+        "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@@ -597,8 +626,10 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
 
 
 class AdminConfig(BaseModel):
 class AdminConfig(BaseModel):
     SHOW_ADMIN_DETAILS: bool
     SHOW_ADMIN_DETAILS: bool
+    WEBUI_URL: str
     ENABLE_SIGNUP: bool
     ENABLE_SIGNUP: bool
     ENABLE_API_KEY: bool
     ENABLE_API_KEY: bool
+    ENABLE_CHANNELS: bool
     DEFAULT_USER_ROLE: str
     DEFAULT_USER_ROLE: str
     JWT_EXPIRES_IN: str
     JWT_EXPIRES_IN: str
     ENABLE_COMMUNITY_SHARING: bool
     ENABLE_COMMUNITY_SHARING: bool
@@ -610,8 +641,10 @@ async def update_admin_config(
     request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
     request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
 ):
 ):
     request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
     request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
+    request.app.state.config.WEBUI_URL = form_data.WEBUI_URL
     request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
     request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
     request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
     request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
+    request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
 
 
     if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
     if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
         request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
         request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@@ -629,8 +662,10 @@ async def update_admin_config(
 
 
     return {
     return {
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
+        "WEBUI_URL": request.app.state.config.WEBUI_URL,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
         "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
         "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
+        "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,

+ 394 - 0
backend/open_webui/routers/channels.py

@@ -0,0 +1,394 @@
+import json
+import logging
+from typing import Optional
+
+
+from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
+from pydantic import BaseModel
+
+
+from open_webui.socket.main import sio, get_user_ids_from_room
+from open_webui.models.users import Users, UserNameResponse
+
+from open_webui.models.channels import Channels, ChannelModel, ChannelForm
+from open_webui.models.messages import Messages, MessageModel, MessageForm
+
+
+from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import SRC_LOG_LEVELS
+
+
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.access_control import has_access, get_users_with_access
+from open_webui.utils.webhook import post_webhook
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+router = APIRouter()
+
+############################
+# GetChatList
+############################
+
+
+@router.get("/", response_model=list[ChannelModel])
+async def get_channels(user=Depends(get_verified_user)):
+    if user.role == "admin":
+        return Channels.get_channels()
+    else:
+        return Channels.get_channels_by_user_id(user.id)
+
+
+############################
+# CreateNewChannel
+############################
+
+
+@router.post("/create", response_model=Optional[ChannelModel])
+async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)):
+    try:
+        channel = Channels.insert_new_channel(form_data, user.id)
+        return ChannelModel(**channel.model_dump())
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
+############################
+# GetChannelById
+############################
+
+
+@router.get("/{id}", response_model=Optional[ChannelModel])
+async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if user.role != "admin" and not has_access(
+        user.id, type="read", access_control=channel.access_control
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    return ChannelModel(**channel.model_dump())
+
+
+############################
+# UpdateChannelById
+############################
+
+
+@router.post("/{id}/update", response_model=Optional[ChannelModel])
+async def update_channel_by_id(
+    id: str, form_data: ChannelForm, user=Depends(get_admin_user)
+):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    try:
+        channel = Channels.update_channel_by_id(id, form_data)
+        return ChannelModel(**channel.model_dump())
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
+############################
+# DeleteChannelById
+############################
+
+
+@router.delete("/{id}/delete", response_model=bool)
+async def delete_channel_by_id(id: str, user=Depends(get_admin_user)):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    try:
+        Channels.delete_channel_by_id(id)
+        return True
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
+############################
+# GetChannelMessages
+############################
+
+
+class MessageUserModel(MessageModel):
+    user: UserNameResponse
+
+
+@router.get("/{id}/messages", response_model=list[MessageUserModel])
+async def get_channel_messages(
+    id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user)
+):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if user.role != "admin" and not has_access(
+        user.id, type="read", access_control=channel.access_control
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    message_list = Messages.get_messages_by_channel_id(id, skip, limit)
+    users = {}
+
+    messages = []
+    for message in message_list:
+        if message.user_id not in users:
+            user = Users.get_user_by_id(message.user_id)
+            users[message.user_id] = user
+
+        messages.append(
+            MessageUserModel(
+                **{
+                    **message.model_dump(),
+                    "user": UserNameResponse(**users[message.user_id].model_dump()),
+                }
+            )
+        )
+
+    return messages
+
+
+############################
+# PostNewMessage
+############################
+
+
+async def send_notification(webui_url, channel, message, active_user_ids):
+    users = get_users_with_access("read", channel.access_control)
+
+    for user in users:
+        if user.id in active_user_ids:
+            continue
+        else:
+            if user.settings:
+                webhook_url = user.settings.ui.get("notifications", {}).get(
+                    "webhook_url", None
+                )
+
+                if webhook_url:
+                    post_webhook(
+                        webhook_url,
+                        f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
+                        {
+                            "action": "channel",
+                            "message": message.content,
+                            "title": channel.name,
+                            "url": f"{webui_url}/channels/{channel.id}",
+                        },
+                    )
+
+
+@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
+async def post_new_message(
+    request: Request,
+    id: str,
+    form_data: MessageForm,
+    background_tasks: BackgroundTasks,
+    user=Depends(get_verified_user),
+):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if user.role != "admin" and not has_access(
+        user.id, type="read", access_control=channel.access_control
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    try:
+        message = Messages.insert_new_message(form_data, channel.id, user.id)
+
+        if message:
+            event_data = {
+                "channel_id": channel.id,
+                "message_id": message.id,
+                "data": {
+                    "type": "message",
+                    "data": {
+                        **message.model_dump(),
+                        "user": UserNameResponse(**user.model_dump()).model_dump(),
+                    },
+                },
+                "user": UserNameResponse(**user.model_dump()).model_dump(),
+                "channel": channel.model_dump(),
+            }
+
+            await sio.emit(
+                "channel-events",
+                event_data,
+                to=f"channel:{channel.id}",
+            )
+
+            active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
+
+            background_tasks.add_task(
+                send_notification,
+                request.app.state.config.WEBUI_URL,
+                channel,
+                message,
+                active_user_ids,
+            )
+
+        return MessageModel(**message.model_dump())
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
+############################
+# UpdateMessageById
+############################
+
+
+@router.post(
+    "/{id}/messages/{message_id}/update", response_model=Optional[MessageModel]
+)
+async def update_message_by_id(
+    id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
+):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if user.role != "admin" and not has_access(
+        user.id, type="read", access_control=channel.access_control
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    message = Messages.get_message_by_id(message_id)
+    if not message:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if message.channel_id != id:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    try:
+        message = Messages.update_message_by_id(message_id, form_data)
+        if message:
+            await sio.emit(
+                "channel-events",
+                {
+                    "channel_id": channel.id,
+                    "message_id": message.id,
+                    "data": {
+                        "type": "message:update",
+                        "data": {
+                            **message.model_dump(),
+                            "user": UserNameResponse(**user.model_dump()).model_dump(),
+                        },
+                    },
+                    "user": UserNameResponse(**user.model_dump()).model_dump(),
+                    "channel": channel.model_dump(),
+                },
+                to=f"channel:{channel.id}",
+            )
+
+        return MessageModel(**message.model_dump())
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
+############################
+# DeleteMessageById
+############################
+
+
+@router.delete("/{id}/messages/{message_id}/delete", response_model=bool)
+async def delete_message_by_id(
+    id: str, message_id: str, user=Depends(get_verified_user)
+):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if user.role != "admin" and not has_access(
+        user.id, type="read", access_control=channel.access_control
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    message = Messages.get_message_by_id(message_id)
+    if not message:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if message.channel_id != id:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    try:
+        Messages.delete_message_by_id(message_id)
+        await sio.emit(
+            "channel-events",
+            {
+                "channel_id": channel.id,
+                "message_id": message.id,
+                "data": {
+                    "type": "message:delete",
+                    "data": {
+                        **message.model_dump(),
+                        "user": UserNameResponse(**user.model_dump()).model_dump(),
+                    },
+                },
+                "user": UserNameResponse(**user.model_dump()).model_dump(),
+                "channel": channel.model_dump(),
+            },
+            to=f"channel:{channel.id}",
+        )
+
+        return True
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )

+ 4 - 5
backend/open_webui/apps/webui/routers/chats.py → backend/open_webui/routers/chats.py

@@ -2,15 +2,15 @@ import json
 import logging
 import logging
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.models.chats import (
+from open_webui.models.chats import (
     ChatForm,
     ChatForm,
     ChatImportForm,
     ChatImportForm,
     ChatResponse,
     ChatResponse,
     Chats,
     Chats,
     ChatTitleIdResponse,
     ChatTitleIdResponse,
 )
 )
-from open_webui.apps.webui.models.tags import TagModel, Tags
-from open_webui.apps.webui.models.folders import Folders
+from open_webui.models.tags import TagModel, Tags
+from open_webui.models.folders import Folders
 
 
 from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
 from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
@@ -19,7 +19,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
 from pydantic import BaseModel
 from pydantic import BaseModel
 
 
 
 
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_permission
 from open_webui.utils.access_control import has_permission
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
@@ -607,7 +607,6 @@ async def add_tag_by_id_and_tag_name(
                 detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
                 detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
             )
             )
 
 
-        print(tags, tag_id)
         if tag_id not in tags:
         if tag_id not in tags:
             Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
             Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
                 id, user.id, form_data.name
                 id, user.id, form_data.name

+ 1 - 1
backend/open_webui/apps/webui/routers/configs.py → backend/open_webui/routers/configs.py

@@ -3,7 +3,7 @@ from pydantic import BaseModel
 
 
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.config import get_config, save_config
 from open_webui.config import get_config, save_config
 from open_webui.config import BannerModel
 from open_webui.config import BannerModel
 
 

+ 3 - 3
backend/open_webui/apps/webui/routers/evaluations.py → backend/open_webui/routers/evaluations.py

@@ -2,8 +2,8 @@ from typing import Optional
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 from pydantic import BaseModel
 from pydantic import BaseModel
 
 
-from open_webui.apps.webui.models.users import Users, UserModel
-from open_webui.apps.webui.models.feedbacks import (
+from open_webui.models.users import Users, UserModel
+from open_webui.models.feedbacks import (
     FeedbackModel,
     FeedbackModel,
     FeedbackResponse,
     FeedbackResponse,
     FeedbackForm,
     FeedbackForm,
@@ -11,7 +11,7 @@ from open_webui.apps.webui.models.feedbacks import (
 )
 )
 
 
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 
 
 router = APIRouter()
 router = APIRouter()
 
 

+ 37 - 17
backend/open_webui/apps/webui/routers/files.py → backend/open_webui/routers/files.py

@@ -5,27 +5,28 @@ from pathlib import Path
 from typing import Optional
 from typing import Optional
 from pydantic import BaseModel
 from pydantic import BaseModel
 import mimetypes
 import mimetypes
+from urllib.parse import quote
 
 
 from open_webui.storage.provider import Storage
 from open_webui.storage.provider import Storage
 
 
-from open_webui.apps.webui.models.files import (
+from open_webui.models.files import (
     FileForm,
     FileForm,
     FileModel,
     FileModel,
     FileModelResponse,
     FileModelResponse,
     Files,
     Files,
 )
 )
-from open_webui.apps.retrieval.main import process_file, ProcessFileForm
+from open_webui.routers.retrieval import process_file, ProcessFileForm
 
 
 from open_webui.config import UPLOAD_DIR
 from open_webui.config import UPLOAD_DIR
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
 
 
 
 
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
+from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
 from fastapi.responses import FileResponse, StreamingResponse
 from fastapi.responses import FileResponse, StreamingResponse
 
 
 
 
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -39,7 +40,9 @@ router = APIRouter()
 
 
 
 
 @router.post("/", response_model=FileModelResponse)
 @router.post("/", response_model=FileModelResponse)
-def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
+def upload_file(
+    request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
+):
     log.info(f"file.content_type: {file.content_type}")
     log.info(f"file.content_type: {file.content_type}")
     try:
     try:
         unsanitized_filename = file.filename
         unsanitized_filename = file.filename
@@ -68,7 +71,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
         )
         )
 
 
         try:
         try:
-            process_file(ProcessFileForm(file_id=id))
+            process_file(request, ProcessFileForm(file_id=id))
             file_item = Files.get_file_by_id(id=id)
             file_item = Files.get_file_by_id(id=id)
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
@@ -183,13 +186,15 @@ class ContentForm(BaseModel):
 
 
 @router.post("/{id}/data/content/update")
 @router.post("/{id}/data/content/update")
 async def update_file_data_content_by_id(
 async def update_file_data_content_by_id(
-    id: str, form_data: ContentForm, user=Depends(get_verified_user)
+    request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user)
 ):
 ):
     file = Files.get_file_by_id(id)
     file = Files.get_file_by_id(id)
 
 
     if file and (file.user_id == user.id or user.role == "admin"):
     if file and (file.user_id == user.id or user.role == "admin"):
         try:
         try:
-            process_file(ProcessFileForm(file_id=id, content=form_data.content))
+            process_file(
+                request, ProcessFileForm(file_id=id, content=form_data.content)
+            )
             file = Files.get_file_by_id(id=id)
             file = Files.get_file_by_id(id=id)
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
@@ -218,11 +223,22 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 
 
             # Check if the file already exists in the cache
             # Check if the file already exists in the cache
             if file_path.is_file():
             if file_path.is_file():
-                print(f"file_path: {file_path}")
-                headers = {
-                    "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
-                }
+                # Handle Unicode filenames
+                filename = file.meta.get("name", file.filename)
+                encoded_filename = quote(filename)  # RFC5987 encoding
+
+                headers = {}
+                if file.meta.get("content_type") not in [
+                    "application/pdf",
+                    "text/plain",
+                ]:
+                    headers = {
+                        **headers,
+                        "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
+                    }
+
                 return FileResponse(file_path, headers=headers)
                 return FileResponse(file_path, headers=headers)
+
             else:
             else:
                 raise HTTPException(
                 raise HTTPException(
                     status_code=status.HTTP_404_NOT_FOUND,
                     status_code=status.HTTP_404_NOT_FOUND,
@@ -279,16 +295,20 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 
 
     if file and (file.user_id == user.id or user.role == "admin"):
     if file and (file.user_id == user.id or user.role == "admin"):
         file_path = file.path
         file_path = file.path
+
+        # Handle Unicode filenames
+        filename = file.meta.get("name", file.filename)
+        encoded_filename = quote(filename)  # RFC5987 encoding
+        headers = {
+            "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
+        }
+
         if file_path:
         if file_path:
             file_path = Storage.get_file(file_path)
             file_path = Storage.get_file(file_path)
             file_path = Path(file_path)
             file_path = Path(file_path)
 
 
             # Check if the file already exists in the cache
             # Check if the file already exists in the cache
             if file_path.is_file():
             if file_path.is_file():
-                print(f"file_path: {file_path}")
-                headers = {
-                    "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
-                }
                 return FileResponse(file_path, headers=headers)
                 return FileResponse(file_path, headers=headers)
             else:
             else:
                 raise HTTPException(
                 raise HTTPException(
@@ -307,7 +327,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
             return StreamingResponse(
             return StreamingResponse(
                 generator(),
                 generator(),
                 media_type="text/plain",
                 media_type="text/plain",
-                headers={"Content-Disposition": f"attachment; filename={file_name}"},
+                headers=headers,
             )
             )
     else:
     else:
         raise HTTPException(
         raise HTTPException(

+ 3 - 3
backend/open_webui/apps/webui/routers/folders.py → backend/open_webui/routers/folders.py

@@ -8,12 +8,12 @@ from pydantic import BaseModel
 import mimetypes
 import mimetypes
 
 
 
 
-from open_webui.apps.webui.models.folders import (
+from open_webui.models.folders import (
     FolderForm,
     FolderForm,
     FolderModel,
     FolderModel,
     Folders,
     Folders,
 )
 )
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.models.chats import Chats
 
 
 from open_webui.config import UPLOAD_DIR
 from open_webui.config import UPLOAD_DIR
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
@@ -24,7 +24,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
 from fastapi.responses import FileResponse, StreamingResponse
 from fastapi.responses import FileResponse, StreamingResponse
 
 
 
 
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

+ 3 - 3
backend/open_webui/apps/webui/routers/functions.py → backend/open_webui/routers/functions.py

@@ -2,17 +2,17 @@ import os
 from pathlib import Path
 from pathlib import Path
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.models.functions import (
+from open_webui.models.functions import (
     FunctionForm,
     FunctionForm,
     FunctionModel,
     FunctionModel,
     FunctionResponse,
     FunctionResponse,
     Functions,
     Functions,
 )
 )
-from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
+from open_webui.utils.plugin import load_function_module_by_id, replace_imports
 from open_webui.config import CACHE_DIR
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 
 
 router = APIRouter()
 router = APIRouter()
 
 

+ 2 - 2
backend/open_webui/apps/webui/routers/groups.py → backend/open_webui/routers/groups.py

@@ -2,7 +2,7 @@ import os
 from pathlib import Path
 from pathlib import Path
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.models.groups import (
+from open_webui.models.groups import (
     Groups,
     Groups,
     GroupForm,
     GroupForm,
     GroupUpdateForm,
     GroupUpdateForm,
@@ -12,7 +12,7 @@ from open_webui.apps.webui.models.groups import (
 from open_webui.config import CACHE_DIR
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 
 
 router = APIRouter()
 router = APIRouter()
 
 

+ 170 - 175
backend/open_webui/apps/images/main.py → backend/open_webui/routers/images.py

@@ -9,38 +9,24 @@ from pathlib import Path
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.images.utils.comfyui import (
+
+
+from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+
+
+from open_webui.config import CACHE_DIR
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
+
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.images.comfyui import (
     ComfyUIGenerateImageForm,
     ComfyUIGenerateImageForm,
     ComfyUIWorkflow,
     ComfyUIWorkflow,
     comfyui_generate_image,
     comfyui_generate_image,
 )
 )
-from open_webui.config import (
-    AUTOMATIC1111_API_AUTH,
-    AUTOMATIC1111_BASE_URL,
-    AUTOMATIC1111_CFG_SCALE,
-    AUTOMATIC1111_SAMPLER,
-    AUTOMATIC1111_SCHEDULER,
-    CACHE_DIR,
-    COMFYUI_BASE_URL,
-    COMFYUI_WORKFLOW,
-    COMFYUI_WORKFLOW_NODES,
-    CORS_ALLOW_ORIGIN,
-    ENABLE_IMAGE_GENERATION,
-    IMAGE_GENERATION_ENGINE,
-    IMAGE_GENERATION_MODEL,
-    IMAGE_SIZE,
-    IMAGE_STEPS,
-    IMAGES_OPENAI_API_BASE_URL,
-    IMAGES_OPENAI_API_KEY,
-    AppConfig,
-)
-from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
 
 
-from fastapi import Depends, FastAPI, HTTPException, Request
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from open_webui.utils.utils import get_admin_user, get_verified_user
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["IMAGES"])
 log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -48,63 +34,31 @@ log.setLevel(SRC_LOG_LEVELS["IMAGES"])
 IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
 IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
 IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
 
-app = FastAPI(
-    docs_url="/docs" if ENV == "dev" else None,
-    openapi_url="/openapi.json" if ENV == "dev" else None,
-    redoc_url=None,
-)
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=CORS_ALLOW_ORIGIN,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-app.state.config = AppConfig()
-
-app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
-app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
 
 
-app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
-app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
+router = APIRouter()
 
 
-app.state.config.MODEL = IMAGE_GENERATION_MODEL
 
 
-app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
-app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
-app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
-app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
-app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
-app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
-app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
-app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
-
-app.state.config.IMAGE_SIZE = IMAGE_SIZE
-app.state.config.IMAGE_STEPS = IMAGE_STEPS
-
-
-@app.get("/config")
+@router.get("/config")
 async def get_config(request: Request, user=Depends(get_admin_user)):
 async def get_config(request: Request, user=Depends(get_admin_user)):
     return {
     return {
-        "enabled": app.state.config.ENABLED,
-        "engine": app.state.config.ENGINE,
+        "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
+        "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
         "openai": {
         "openai": {
-            "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
+            "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
         },
         },
         "automatic1111": {
         "automatic1111": {
-            "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
-            "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
-            "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
-            "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
-            "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
+            "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
+            "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
+            "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
+            "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
+            "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
         },
         },
         "comfyui": {
         "comfyui": {
-            "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
-            "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
-            "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
+            "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
+            "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
+            "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
+            "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
         },
         },
     }
     }
 
 
@@ -117,13 +71,14 @@ class OpenAIConfigForm(BaseModel):
 class Automatic1111ConfigForm(BaseModel):
 class Automatic1111ConfigForm(BaseModel):
     AUTOMATIC1111_BASE_URL: str
     AUTOMATIC1111_BASE_URL: str
     AUTOMATIC1111_API_AUTH: str
     AUTOMATIC1111_API_AUTH: str
-    AUTOMATIC1111_CFG_SCALE: Optional[str]
+    AUTOMATIC1111_CFG_SCALE: Optional[str | float | int]
     AUTOMATIC1111_SAMPLER: Optional[str]
     AUTOMATIC1111_SAMPLER: Optional[str]
     AUTOMATIC1111_SCHEDULER: Optional[str]
     AUTOMATIC1111_SCHEDULER: Optional[str]
 
 
 
 
 class ComfyUIConfigForm(BaseModel):
 class ComfyUIConfigForm(BaseModel):
     COMFYUI_BASE_URL: str
     COMFYUI_BASE_URL: str
+    COMFYUI_API_KEY: str
     COMFYUI_WORKFLOW: str
     COMFYUI_WORKFLOW: str
     COMFYUI_WORKFLOW_NODES: list[dict]
     COMFYUI_WORKFLOW_NODES: list[dict]
 
 
@@ -136,133 +91,157 @@ class ConfigForm(BaseModel):
     comfyui: ComfyUIConfigForm
     comfyui: ComfyUIConfigForm
 
 
 
 
-@app.post("/config/update")
-async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
-    app.state.config.ENGINE = form_data.engine
-    app.state.config.ENABLED = form_data.enabled
+@router.post("/config/update")
+async def update_config(
+    request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
+):
+    request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
+    request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
 
 
-    app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
-    app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
+    request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
+        form_data.openai.OPENAI_API_BASE_URL
+    )
+    request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
 
 
-    app.state.config.AUTOMATIC1111_BASE_URL = (
+    request.app.state.config.AUTOMATIC1111_BASE_URL = (
         form_data.automatic1111.AUTOMATIC1111_BASE_URL
         form_data.automatic1111.AUTOMATIC1111_BASE_URL
     )
     )
-    app.state.config.AUTOMATIC1111_API_AUTH = (
+    request.app.state.config.AUTOMATIC1111_API_AUTH = (
         form_data.automatic1111.AUTOMATIC1111_API_AUTH
         form_data.automatic1111.AUTOMATIC1111_API_AUTH
     )
     )
 
 
-    app.state.config.AUTOMATIC1111_CFG_SCALE = (
+    request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
         float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
         float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
         if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
         if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
         else None
         else None
     )
     )
-    app.state.config.AUTOMATIC1111_SAMPLER = (
+    request.app.state.config.AUTOMATIC1111_SAMPLER = (
         form_data.automatic1111.AUTOMATIC1111_SAMPLER
         form_data.automatic1111.AUTOMATIC1111_SAMPLER
         if form_data.automatic1111.AUTOMATIC1111_SAMPLER
         if form_data.automatic1111.AUTOMATIC1111_SAMPLER
         else None
         else None
     )
     )
-    app.state.config.AUTOMATIC1111_SCHEDULER = (
+    request.app.state.config.AUTOMATIC1111_SCHEDULER = (
         form_data.automatic1111.AUTOMATIC1111_SCHEDULER
         form_data.automatic1111.AUTOMATIC1111_SCHEDULER
         if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
         if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
         else None
         else None
     )
     )
 
 
-    app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
-    app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
-    app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
+    request.app.state.config.COMFYUI_BASE_URL = (
+        form_data.comfyui.COMFYUI_BASE_URL.strip("/")
+    )
+    request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
+    request.app.state.config.COMFYUI_WORKFLOW_NODES = (
+        form_data.comfyui.COMFYUI_WORKFLOW_NODES
+    )
 
 
     return {
     return {
-        "enabled": app.state.config.ENABLED,
-        "engine": app.state.config.ENGINE,
+        "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
+        "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
         "openai": {
         "openai": {
-            "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
+            "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
         },
         },
         "automatic1111": {
         "automatic1111": {
-            "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
-            "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
-            "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
-            "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
-            "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
+            "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
+            "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
+            "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
+            "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
+            "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
         },
         },
         "comfyui": {
         "comfyui": {
-            "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
-            "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
-            "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
+            "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
+            "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
+            "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
+            "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
         },
         },
     }
     }
 
 
 
 
-def get_automatic1111_api_auth():
-    if app.state.config.AUTOMATIC1111_API_AUTH is None:
+def get_automatic1111_api_auth(request: Request):
+    if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
         return ""
         return ""
     else:
     else:
-        auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
+        auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
+            "utf-8"
+        )
         auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
         auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
         auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
         auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
         return f"Basic {auth1111_base64_encoded_string}"
         return f"Basic {auth1111_base64_encoded_string}"
 
 
 
 
-@app.get("/config/url/verify")
-async def verify_url(user=Depends(get_admin_user)):
-    if app.state.config.ENGINE == "automatic1111":
+@router.get("/config/url/verify")
+async def verify_url(request: Request, user=Depends(get_admin_user)):
+    if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
         try:
         try:
             r = requests.get(
             r = requests.get(
-                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
-                headers={"authorization": get_automatic1111_api_auth()},
+                url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+                headers={"authorization": get_automatic1111_api_auth(request)},
             )
             )
             r.raise_for_status()
             r.raise_for_status()
             return True
             return True
         except Exception:
         except Exception:
-            app.state.config.ENABLED = False
+            request.app.state.config.ENABLE_IMAGE_GENERATION = False
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
-    elif app.state.config.ENGINE == "comfyui":
+    elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
         try:
         try:
-            r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
+            r = requests.get(
+                url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
+            )
             r.raise_for_status()
             r.raise_for_status()
             return True
             return True
         except Exception:
         except Exception:
-            app.state.config.ENABLED = False
+            request.app.state.config.ENABLE_IMAGE_GENERATION = False
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
     else:
     else:
         return True
         return True
 
 
 
 
-def set_image_model(model: str):
+def set_image_model(request: Request, model: str):
     log.info(f"Setting image model to {model}")
     log.info(f"Setting image model to {model}")
-    app.state.config.MODEL = model
-    if app.state.config.ENGINE in ["", "automatic1111"]:
+    request.app.state.config.IMAGE_GENERATION_MODEL = model
+    if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
         api_auth = get_automatic1111_api_auth()
         api_auth = get_automatic1111_api_auth()
         r = requests.get(
         r = requests.get(
-            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+            url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
             headers={"authorization": api_auth},
             headers={"authorization": api_auth},
         )
         )
         options = r.json()
         options = r.json()
         if model != options["sd_model_checkpoint"]:
         if model != options["sd_model_checkpoint"]:
             options["sd_model_checkpoint"] = model
             options["sd_model_checkpoint"] = model
             r = requests.post(
             r = requests.post(
-                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+                url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                 json=options,
                 json=options,
                 headers={"authorization": api_auth},
                 headers={"authorization": api_auth},
             )
             )
-    return app.state.config.MODEL
+    return request.app.state.config.IMAGE_GENERATION_MODEL
 
 
 
 
-def get_image_model():
-    if app.state.config.ENGINE == "openai":
-        return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
-    elif app.state.config.ENGINE == "comfyui":
-        return app.state.config.MODEL if app.state.config.MODEL else ""
-    elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
+def get_image_model(request):
+    if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
+        return (
+            request.app.state.config.IMAGE_GENERATION_MODEL
+            if request.app.state.config.IMAGE_GENERATION_MODEL
+            else "dall-e-2"
+        )
+    elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
+        return (
+            request.app.state.config.IMAGE_GENERATION_MODEL
+            if request.app.state.config.IMAGE_GENERATION_MODEL
+            else ""
+        )
+    elif (
+        request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
+        or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
+    ):
         try:
         try:
             r = requests.get(
             r = requests.get(
-                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+                url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                 headers={"authorization": get_automatic1111_api_auth()},
                 headers={"authorization": get_automatic1111_api_auth()},
             )
             )
             options = r.json()
             options = r.json()
             return options["sd_model_checkpoint"]
             return options["sd_model_checkpoint"]
         except Exception as e:
         except Exception as e:
-            app.state.config.ENABLED = False
+            request.app.state.config.ENABLE_IMAGE_GENERATION = False
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 
 
 
 
@@ -272,23 +251,25 @@ class ImageConfigForm(BaseModel):
     IMAGE_STEPS: int
     IMAGE_STEPS: int
 
 
 
 
-@app.get("/image/config")
-async def get_image_config(user=Depends(get_admin_user)):
+@router.get("/image/config")
+async def get_image_config(request: Request, user=Depends(get_admin_user)):
     return {
     return {
-        "MODEL": app.state.config.MODEL,
-        "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
-        "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
+        "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
+        "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
+        "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
     }
     }
 
 
 
 
-@app.post("/image/config/update")
-async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
+@router.post("/image/config/update")
+async def update_image_config(
+    request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
+):
 
 
-    set_image_model(form_data.MODEL)
+    set_image_model(request, form_data.MODEL)
 
 
     pattern = r"^\d+x\d+$"
     pattern = r"^\d+x\d+$"
     if re.match(pattern, form_data.IMAGE_SIZE):
     if re.match(pattern, form_data.IMAGE_SIZE):
-        app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
+        request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=400,
             status_code=400,
@@ -296,7 +277,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
         )
         )
 
 
     if form_data.IMAGE_STEPS >= 0:
     if form_data.IMAGE_STEPS >= 0:
-        app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
+        request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=400,
             status_code=400,
@@ -304,29 +285,35 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
         )
         )
 
 
     return {
     return {
-        "MODEL": app.state.config.MODEL,
-        "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
-        "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
+        "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
+        "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
+        "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
     }
     }
 
 
 
 
-@app.get("/models")
-def get_models(user=Depends(get_verified_user)):
+@router.get("/models")
+def get_models(request: Request, user=Depends(get_verified_user)):
     try:
     try:
-        if app.state.config.ENGINE == "openai":
+        if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
             return [
             return [
                 {"id": "dall-e-2", "name": "DALL·E 2"},
                 {"id": "dall-e-2", "name": "DALL·E 2"},
                 {"id": "dall-e-3", "name": "DALL·E 3"},
                 {"id": "dall-e-3", "name": "DALL·E 3"},
             ]
             ]
-        elif app.state.config.ENGINE == "comfyui":
+        elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
             # TODO - get models from comfyui
             # TODO - get models from comfyui
-            r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
+            headers = {
+                "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
+            }
+            r = requests.get(
+                url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
+                headers=headers,
+            )
             info = r.json()
             info = r.json()
 
 
-            workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
+            workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
             model_node_id = None
             model_node_id = None
 
 
-            for node in app.state.config.COMFYUI_WORKFLOW_NODES:
+            for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
                 if node["type"] == "model":
                 if node["type"] == "model":
                     if node["node_ids"]:
                     if node["node_ids"]:
                         model_node_id = node["node_ids"][0]
                         model_node_id = node["node_ids"][0]
@@ -362,10 +349,11 @@ def get_models(user=Depends(get_verified_user)):
                     )
                     )
                 )
                 )
         elif (
         elif (
-            app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
+            request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
+            or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
         ):
         ):
             r = requests.get(
             r = requests.get(
-                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
+                url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
                 headers={"authorization": get_automatic1111_api_auth()},
                 headers={"authorization": get_automatic1111_api_auth()},
             )
             )
             models = r.json()
             models = r.json()
@@ -376,7 +364,7 @@ def get_models(user=Depends(get_verified_user)):
                 )
                 )
             )
             )
     except Exception as e:
     except Exception as e:
-        app.state.config.ENABLED = False
+        request.app.state.config.ENABLE_IMAGE_GENERATION = False
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 
 
 
 
@@ -448,18 +436,21 @@ def save_url_image(url):
         return None
         return None
 
 
 
 
-@app.post("/generations")
+@router.post("/generations")
 async def image_generations(
 async def image_generations(
+    request: Request,
     form_data: GenerateImageForm,
     form_data: GenerateImageForm,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
 ):
 ):
-    width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
+    width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
 
 
     r = None
     r = None
     try:
     try:
-        if app.state.config.ENGINE == "openai":
+        if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
             headers = {}
             headers = {}
-            headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
+            headers["Authorization"] = (
+                f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}"
+            )
             headers["Content-Type"] = "application/json"
             headers["Content-Type"] = "application/json"
 
 
             if ENABLE_FORWARD_USER_INFO_HEADERS:
             if ENABLE_FORWARD_USER_INFO_HEADERS:
@@ -470,14 +461,16 @@ async def image_generations(
 
 
             data = {
             data = {
                 "model": (
                 "model": (
-                    app.state.config.MODEL
-                    if app.state.config.MODEL != ""
+                    request.app.state.config.IMAGE_GENERATION_MODEL
+                    if request.app.state.config.IMAGE_GENERATION_MODEL != ""
                     else "dall-e-2"
                     else "dall-e-2"
                 ),
                 ),
                 "prompt": form_data.prompt,
                 "prompt": form_data.prompt,
                 "n": form_data.n,
                 "n": form_data.n,
                 "size": (
                 "size": (
-                    form_data.size if form_data.size else app.state.config.IMAGE_SIZE
+                    form_data.size
+                    if form_data.size
+                    else request.app.state.config.IMAGE_SIZE
                 ),
                 ),
                 "response_format": "b64_json",
                 "response_format": "b64_json",
             }
             }
@@ -485,7 +478,7 @@ async def image_generations(
             # Use asyncio.to_thread for the requests.post call
             # Use asyncio.to_thread for the requests.post call
             r = await asyncio.to_thread(
             r = await asyncio.to_thread(
                 requests.post,
                 requests.post,
-                url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
+                url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
                 json=data,
                 json=data,
                 headers=headers,
                 headers=headers,
             )
             )
@@ -505,7 +498,7 @@ async def image_generations(
 
 
             return images
             return images
 
 
-        elif app.state.config.ENGINE == "comfyui":
+        elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
             data = {
             data = {
                 "prompt": form_data.prompt,
                 "prompt": form_data.prompt,
                 "width": width,
                 "width": width,
@@ -513,8 +506,8 @@ async def image_generations(
                 "n": form_data.n,
                 "n": form_data.n,
             }
             }
 
 
-            if app.state.config.IMAGE_STEPS is not None:
-                data["steps"] = app.state.config.IMAGE_STEPS
+            if request.app.state.config.IMAGE_STEPS is not None:
+                data["steps"] = request.app.state.config.IMAGE_STEPS
 
 
             if form_data.negative_prompt is not None:
             if form_data.negative_prompt is not None:
                 data["negative_prompt"] = form_data.negative_prompt
                 data["negative_prompt"] = form_data.negative_prompt
@@ -523,18 +516,19 @@ async def image_generations(
                 **{
                 **{
                     "workflow": ComfyUIWorkflow(
                     "workflow": ComfyUIWorkflow(
                         **{
                         **{
-                            "workflow": app.state.config.COMFYUI_WORKFLOW,
-                            "nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
+                            "workflow": request.app.state.config.COMFYUI_WORKFLOW,
+                            "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
                         }
                         }
                     ),
                     ),
                     **data,
                     **data,
                 }
                 }
             )
             )
             res = await comfyui_generate_image(
             res = await comfyui_generate_image(
-                app.state.config.MODEL,
+                request.app.state.config.IMAGE_GENERATION_MODEL,
                 form_data,
                 form_data,
                 user.id,
                 user.id,
-                app.state.config.COMFYUI_BASE_URL,
+                request.app.state.config.COMFYUI_BASE_URL,
+                request.app.state.config.COMFYUI_API_KEY,
             )
             )
             log.debug(f"res: {res}")
             log.debug(f"res: {res}")
 
 
@@ -551,7 +545,8 @@ async def image_generations(
             log.debug(f"images: {images}")
             log.debug(f"images: {images}")
             return images
             return images
         elif (
         elif (
-            app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
+            request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
+            or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
         ):
         ):
             if form_data.model:
             if form_data.model:
                 set_image_model(form_data.model)
                 set_image_model(form_data.model)
@@ -563,25 +558,25 @@ async def image_generations(
                 "height": height,
                 "height": height,
             }
             }
 
 
-            if app.state.config.IMAGE_STEPS is not None:
-                data["steps"] = app.state.config.IMAGE_STEPS
+            if request.app.state.config.IMAGE_STEPS is not None:
+                data["steps"] = request.app.state.config.IMAGE_STEPS
 
 
             if form_data.negative_prompt is not None:
             if form_data.negative_prompt is not None:
                 data["negative_prompt"] = form_data.negative_prompt
                 data["negative_prompt"] = form_data.negative_prompt
 
 
-            if app.state.config.AUTOMATIC1111_CFG_SCALE:
-                data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
+            if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
+                data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE
 
 
-            if app.state.config.AUTOMATIC1111_SAMPLER:
-                data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
+            if request.app.state.config.AUTOMATIC1111_SAMPLER:
+                data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER
 
 
-            if app.state.config.AUTOMATIC1111_SCHEDULER:
-                data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
+            if request.app.state.config.AUTOMATIC1111_SCHEDULER:
+                data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER
 
 
             # Use asyncio.to_thread for the requests.post call
             # Use asyncio.to_thread for the requests.post call
             r = await asyncio.to_thread(
             r = await asyncio.to_thread(
                 requests.post,
                 requests.post,
-                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
+                url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                 json=data,
                 json=data,
                 headers={"authorization": get_automatic1111_api_auth()},
                 headers={"authorization": get_automatic1111_api_auth()},
             )
             )

+ 99 - 9
backend/open_webui/apps/webui/routers/knowledge.py → backend/open_webui/routers/knowledge.py

@@ -1,22 +1,26 @@
-import json
-from typing import Optional, Union
+from typing import List, Optional
 from pydantic import BaseModel
 from pydantic import BaseModel
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 import logging
 import logging
 
 
-from open_webui.apps.webui.models.knowledge import (
+from open_webui.models.knowledge import (
     Knowledges,
     Knowledges,
     KnowledgeForm,
     KnowledgeForm,
     KnowledgeResponse,
     KnowledgeResponse,
     KnowledgeUserResponse,
     KnowledgeUserResponse,
 )
 )
-from open_webui.apps.webui.models.files import Files, FileModel
-from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.apps.retrieval.main import process_file, ProcessFileForm
+from open_webui.models.files import Files, FileModel
+from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.routers.retrieval import (
+    process_file,
+    ProcessFileForm,
+    process_files_batch,
+    BatchProcessFilesForm,
+)
 
 
 
 
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
 from open_webui.utils.access_control import has_access, has_permission
 
 
 
 
@@ -242,6 +246,7 @@ class KnowledgeFileIdForm(BaseModel):
 
 
 @router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse])
 @router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse])
 def add_file_to_knowledge_by_id(
 def add_file_to_knowledge_by_id(
+    request: Request,
     id: str,
     id: str,
     form_data: KnowledgeFileIdForm,
     form_data: KnowledgeFileIdForm,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
@@ -274,7 +279,9 @@ def add_file_to_knowledge_by_id(
 
 
     # Add content to the vector database
     # Add content to the vector database
     try:
     try:
-        process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
+        process_file(
+            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+        )
     except Exception as e:
     except Exception as e:
         log.debug(e)
         log.debug(e)
         raise HTTPException(
         raise HTTPException(
@@ -318,6 +325,7 @@ def add_file_to_knowledge_by_id(
 
 
 @router.post("/{id}/file/update", response_model=Optional[KnowledgeFilesResponse])
 @router.post("/{id}/file/update", response_model=Optional[KnowledgeFilesResponse])
 def update_file_from_knowledge_by_id(
 def update_file_from_knowledge_by_id(
+    request: Request,
     id: str,
     id: str,
     form_data: KnowledgeFileIdForm,
     form_data: KnowledgeFileIdForm,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
@@ -349,7 +357,9 @@ def update_file_from_knowledge_by_id(
 
 
     # Add content to the vector database
     # Add content to the vector database
     try:
     try:
-        process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
+        process_file(
+            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+        )
     except Exception as e:
     except Exception as e:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -508,3 +518,83 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
     knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
     knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
 
 
     return knowledge
     return knowledge
+
+
+############################
+# AddFilesToKnowledge
+############################
+
+
+@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
+def add_files_to_knowledge_batch(
+    id: str,
+    form_data: list[KnowledgeFileIdForm],
+    user=Depends(get_verified_user),
+):
+    """
+    Add multiple files to a knowledge base
+    """
+    knowledge = Knowledges.get_knowledge_by_id(id=id)
+    if not knowledge:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    # Get files content
+    print(f"files/batch/add - {len(form_data)} files")
+    files: List[FileModel] = []
+    for form in form_data:
+        file = Files.get_file_by_id(form.file_id)
+        if not file:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"File {form.file_id} not found",
+            )
+        files.append(file)
+
+    # Process files
+    try:
+        result = process_files_batch(
+            BatchProcessFilesForm(files=files, collection_name=id)
+        )
+    except Exception as e:
+        log.error(
+            f"add_files_to_knowledge_batch: Exception occurred: {e}", exc_info=True
+        )
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+
+    # Add successful files to knowledge base
+    data = knowledge.data or {}
+    existing_file_ids = data.get("file_ids", [])
+
+    # Only add files that were successfully processed
+    successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
+    for file_id in successful_file_ids:
+        if file_id not in existing_file_ids:
+            existing_file_ids.append(file_id)
+
+    data["file_ids"] = existing_file_ids
+    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
+
+    # If there were any errors, include them in the response
+    if result.errors:
+        error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
+        return KnowledgeFilesResponse(
+            **knowledge.model_dump(),
+            files=Files.get_files_by_ids(existing_file_ids),
+            warnings={
+                "message": "Some files failed to process",
+                "errors": error_details,
+            },
+        )
+
+    return KnowledgeFilesResponse(
+        **knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
+    )

+ 3 - 3
backend/open_webui/apps/webui/routers/memories.py → backend/open_webui/routers/memories.py

@@ -3,9 +3,9 @@ from pydantic import BaseModel
 import logging
 import logging
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.models.memories import Memories, MemoryModel
-from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.utils.utils import get_verified_user
+from open_webui.models.memories import Memories, MemoryModel
+from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.utils.auth import get_verified_user
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 
 

+ 2 - 2
backend/open_webui/apps/webui/routers/models.py → backend/open_webui/routers/models.py

@@ -1,6 +1,6 @@
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.models.models import (
+from open_webui.models.models import (
     ModelForm,
     ModelForm,
     ModelModel,
     ModelModel,
     ModelResponse,
     ModelResponse,
@@ -11,7 +11,7 @@ from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 
 
 
 
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
 from open_webui.utils.access_control import has_access, has_permission
 
 
 
 

Fichier diff supprimé car celui-ci est trop grand
+ 353 - 344
backend/open_webui/routers/ollama.py


+ 329 - 272
backend/open_webui/apps/openai/main.py → backend/open_webui/routers/openai.py

@@ -10,15 +10,15 @@ from aiocache import cached
 import requests
 import requests
 
 
 
 
-from open_webui.apps.webui.models.models import Models
+from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse, StreamingResponse
+from pydantic import BaseModel
+from starlette.background import BackgroundTask
+
+from open_webui.models.models import Models
 from open_webui.config import (
 from open_webui.config import (
     CACHE_DIR,
     CACHE_DIR,
-    CORS_ALLOW_ORIGIN,
-    ENABLE_OPENAI_API,
-    OPENAI_API_BASE_URLS,
-    OPENAI_API_KEYS,
-    OPENAI_API_CONFIGS,
-    AppConfig,
 )
 )
 from open_webui.env import (
 from open_webui.env import (
     AIOHTTP_CLIENT_TIMEOUT,
     AIOHTTP_CLIENT_TIMEOUT,
@@ -29,18 +29,14 @@ from open_webui.env import (
 
 
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import ENV, SRC_LOG_LEVELS
 from open_webui.env import ENV, SRC_LOG_LEVELS
-from fastapi import Depends, FastAPI, HTTPException, Request
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse, StreamingResponse
-from pydantic import BaseModel
-from starlette.background import BackgroundTask
+
 
 
 from open_webui.utils.payload import (
 from open_webui.utils.payload import (
     apply_model_params_to_body_openai,
     apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
     apply_model_system_prompt_to_body,
 )
 )
 
 
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access
 from open_webui.utils.access_control import has_access
 
 
 
 
@@ -48,36 +44,69 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OPENAI"])
 log.setLevel(SRC_LOG_LEVELS["OPENAI"])
 
 
 
 
-app = FastAPI(
-    docs_url="/docs" if ENV == "dev" else None,
-    openapi_url="/openapi.json" if ENV == "dev" else None,
-    redoc_url=None,
-)
+##########################################
+#
+# Utility functions
+#
+##########################################
 
 
 
 
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=CORS_ALLOW_ORIGIN,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+async def send_get_request(url, key=None):
+    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+    try:
+        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            async with session.get(
+                url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
+            ) as response:
+                return await response.json()
+    except Exception as e:
+        # Handle connection error here
+        log.error(f"Connection error: {e}")
+        return None
+
+
+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
 
 
-app.state.config = AppConfig()
+def openai_o1_handler(payload):
+    """
+    Handle O1 specific parameters
+    """
+    if "max_tokens" in payload:
+        # Remove "max_tokens" from the payload
+        payload["max_completion_tokens"] = payload["max_tokens"]
+        del payload["max_tokens"]
 
 
-app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
-app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
-app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
-app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
+    # Fix: O1 does not support the "system" parameter, Modify "system" to "user"
+    if payload["messages"][0]["role"] == "system":
+        payload["messages"][0]["role"] = "user"
 
 
+    return payload
 
 
-@app.get("/config")
-async def get_config(user=Depends(get_admin_user)):
+
+##########################################
+#
+# API routes
+#
+##########################################
+
+router = APIRouter()
+
+
+@router.get("/config")
+async def get_config(request: Request, user=Depends(get_admin_user)):
     return {
     return {
-        "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
-        "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
-        "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
-        "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
+        "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
+        "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS,
+        "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS,
+        "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS,
     }
     }
 
 
 
 
@@ -88,50 +117,56 @@ class OpenAIConfigForm(BaseModel):
     OPENAI_API_CONFIGS: dict
     OPENAI_API_CONFIGS: dict
 
 
 
 
-@app.post("/config/update")
-async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
-    app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
-
-    app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
-    app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
+@router.post("/config/update")
+async def update_config(
+    request: Request, form_data: OpenAIConfigForm, user=Depends(get_admin_user)
+):
+    request.app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
+    request.app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
+    request.app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
 
 
     # Check if API KEYS length is same than API URLS length
     # Check if API KEYS length is same than API URLS length
-    if len(app.state.config.OPENAI_API_KEYS) != len(
-        app.state.config.OPENAI_API_BASE_URLS
+    if len(request.app.state.config.OPENAI_API_KEYS) != len(
+        request.app.state.config.OPENAI_API_BASE_URLS
     ):
     ):
-        if len(app.state.config.OPENAI_API_KEYS) > len(
-            app.state.config.OPENAI_API_BASE_URLS
+        if len(request.app.state.config.OPENAI_API_KEYS) > len(
+            request.app.state.config.OPENAI_API_BASE_URLS
         ):
         ):
-            app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
-                : len(app.state.config.OPENAI_API_BASE_URLS)
-            ]
+            request.app.state.config.OPENAI_API_KEYS = (
+                request.app.state.config.OPENAI_API_KEYS[
+                    : len(request.app.state.config.OPENAI_API_BASE_URLS)
+                ]
+            )
         else:
         else:
-            app.state.config.OPENAI_API_KEYS += [""] * (
-                len(app.state.config.OPENAI_API_BASE_URLS)
-                - len(app.state.config.OPENAI_API_KEYS)
+            request.app.state.config.OPENAI_API_KEYS += [""] * (
+                len(request.app.state.config.OPENAI_API_BASE_URLS)
+                - len(request.app.state.config.OPENAI_API_KEYS)
             )
             )
 
 
-    app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
+    request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
 
 
     # Remove any extra configs
     # Remove any extra configs
-    config_urls = app.state.config.OPENAI_API_CONFIGS.keys()
-    for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
+    config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys()
+    for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
         if url not in config_urls:
         if url not in config_urls:
-            app.state.config.OPENAI_API_CONFIGS.pop(url, None)
+            request.app.state.config.OPENAI_API_CONFIGS.pop(url, None)
 
 
     return {
     return {
-        "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
-        "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
-        "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
-        "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
+        "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
+        "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS,
+        "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS,
+        "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS,
     }
     }
 
 
 
 
-@app.post("/audio/speech")
+@router.post("/audio/speech")
 async def speech(request: Request, user=Depends(get_verified_user)):
 async def speech(request: Request, user=Depends(get_verified_user)):
     idx = None
     idx = None
     try:
     try:
-        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
+        idx = request.app.state.config.OPENAI_API_BASE_URLS.index(
+            "https://api.openai.com/v1"
+        )
+
         body = await request.body()
         body = await request.body()
         name = hashlib.sha256(body).hexdigest()
         name = hashlib.sha256(body).hexdigest()
 
 
@@ -144,23 +179,35 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         if file_path.is_file():
         if file_path.is_file():
             return FileResponse(file_path)
             return FileResponse(file_path)
 
 
-        headers = {}
-        headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
-        headers["Content-Type"] = "application/json"
-        if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
-            headers["HTTP-Referer"] = "https://openwebui.com/"
-            headers["X-Title"] = "Open WebUI"
-        if ENABLE_FORWARD_USER_INFO_HEADERS:
-            headers["X-OpenWebUI-User-Name"] = user.name
-            headers["X-OpenWebUI-User-Id"] = user.id
-            headers["X-OpenWebUI-User-Email"] = user.email
-            headers["X-OpenWebUI-User-Role"] = user.role
+        url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
+
         r = None
         r = None
         try:
         try:
             r = requests.post(
             r = requests.post(
-                url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
+                url=f"{url}/audio/speech",
                 data=body,
                 data=body,
-                headers=headers,
+                headers={
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}",
+                    **(
+                        {
+                            "HTTP-Referer": "https://openwebui.com/",
+                            "X-Title": "Open WebUI",
+                        }
+                        if "openrouter.ai" in url
+                        else {}
+                    ),
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS
+                        else {}
+                    ),
+                },
                 stream=True,
                 stream=True,
             )
             )
 
 
@@ -179,115 +226,62 @@ async def speech(request: Request, user=Depends(get_verified_user)):
 
 
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
+
+            detail = None
             if r is not None:
             if r is not None:
                 try:
                 try:
                     res = r.json()
                     res = r.json()
                     if "error" in res:
                     if "error" in res:
-                        error_detail = f"External: {res['error']}"
+                        detail = f"External: {res['error']}"
                 except Exception:
                 except Exception:
-                    error_detail = f"External: {e}"
+                    detail = f"External: {e}"
 
 
             raise HTTPException(
             raise HTTPException(
-                status_code=r.status_code if r else 500, detail=error_detail
+                status_code=r.status_code if r else 500,
+                detail=detail if detail else "Open WebUI: Server Connection Error",
             )
             )
 
 
     except ValueError:
     except ValueError:
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
 
 
 
 
-async def aiohttp_get(url, key=None):
-    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
-    try:
-        headers = {"Authorization": f"Bearer {key}"} if key else {}
-        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
-            async with session.get(url, headers=headers) as response:
-                return await response.json()
-    except Exception as e:
-        # Handle connection error here
-        log.error(f"Connection error: {e}")
-        return None
-
-
-async def cleanup_response(
-    response: Optional[aiohttp.ClientResponse],
-    session: Optional[aiohttp.ClientSession],
-):
-    if response:
-        response.close()
-    if session:
-        await session.close()
-
-
-def merge_models_lists(model_lists):
-    log.debug(f"merge_models_lists {model_lists}")
-    merged_list = []
-
-    for idx, models in enumerate(model_lists):
-        if models is not None and "error" not in models:
-            merged_list.extend(
-                [
-                    {
-                        **model,
-                        "name": model.get("name", model["id"]),
-                        "owned_by": "openai",
-                        "openai": model,
-                        "urlIdx": idx,
-                    }
-                    for model in models
-                    if "api.openai.com"
-                    not in app.state.config.OPENAI_API_BASE_URLS[idx]
-                    or not any(
-                        name in model["id"]
-                        for name in [
-                            "babbage",
-                            "dall-e",
-                            "davinci",
-                            "embedding",
-                            "tts",
-                            "whisper",
-                        ]
-                    )
-                ]
-            )
-
-    return merged_list
-
-
-async def get_all_models_responses() -> list:
-    if not app.state.config.ENABLE_OPENAI_API:
+async def get_all_models_responses(request: Request) -> list:
+    if not request.app.state.config.ENABLE_OPENAI_API:
         return []
         return []
 
 
     # Check if API KEYS length is same than API URLS length
     # Check if API KEYS length is same than API URLS length
-    num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
-    num_keys = len(app.state.config.OPENAI_API_KEYS)
+    num_urls = len(request.app.state.config.OPENAI_API_BASE_URLS)
+    num_keys = len(request.app.state.config.OPENAI_API_KEYS)
 
 
     if num_keys != num_urls:
     if num_keys != num_urls:
         # if there are more keys than urls, remove the extra keys
         # if there are more keys than urls, remove the extra keys
         if num_keys > num_urls:
         if num_keys > num_urls:
-            new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
-            app.state.config.OPENAI_API_KEYS = new_keys
+            new_keys = request.app.state.config.OPENAI_API_KEYS[:num_urls]
+            request.app.state.config.OPENAI_API_KEYS = new_keys
         # if there are more urls than keys, add empty keys
         # if there are more urls than keys, add empty keys
         else:
         else:
-            app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
-
-    tasks = []
-    for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
-        if url not in app.state.config.OPENAI_API_CONFIGS:
-            tasks.append(
-                aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
+            request.app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
+
+    request_tasks = []
+    for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
+        if url not in request.app.state.config.OPENAI_API_CONFIGS:
+            request_tasks.append(
+                send_get_request(
+                    f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
+                )
             )
             )
         else:
         else:
-            api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
+            api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
 
 
             enable = api_config.get("enable", True)
             enable = api_config.get("enable", True)
             model_ids = api_config.get("model_ids", [])
             model_ids = api_config.get("model_ids", [])
 
 
             if enable:
             if enable:
                 if len(model_ids) == 0:
                 if len(model_ids) == 0:
-                    tasks.append(
-                        aiohttp_get(
-                            f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]
+                    request_tasks.append(
+                        send_get_request(
+                            f"{url}/models",
+                            request.app.state.config.OPENAI_API_KEYS[idx],
                         )
                         )
                     )
                     )
                 else:
                 else:
@@ -305,16 +299,18 @@ async def get_all_models_responses() -> list:
                         ],
                         ],
                     }
                     }
 
 
-                    tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list)))
+                    request_tasks.append(
+                        asyncio.ensure_future(asyncio.sleep(0, model_list))
+                    )
             else:
             else:
-                tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
+                request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
 
 
-    responses = await asyncio.gather(*tasks)
+    responses = await asyncio.gather(*request_tasks)
 
 
     for idx, response in enumerate(responses):
     for idx, response in enumerate(responses):
         if response:
         if response:
-            url = app.state.config.OPENAI_API_BASE_URLS[idx]
-            api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
+            url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
+            api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
 
 
             prefix_id = api_config.get("prefix_id", None)
             prefix_id = api_config.get("prefix_id", None)
 
 
@@ -325,18 +321,30 @@ async def get_all_models_responses() -> list:
                     model["id"] = f"{prefix_id}.{model['id']}"
                     model["id"] = f"{prefix_id}.{model['id']}"
 
 
     log.debug(f"get_all_models:responses() {responses}")
     log.debug(f"get_all_models:responses() {responses}")
-
     return responses
     return responses
 
 
 
 
+async def get_filtered_models(models, user):
+    # Filter models based on user access control
+    filtered_models = []
+    for model in models.get("data", []):
+        model_info = Models.get_model_by_id(model["id"])
+        if model_info:
+            if user.id == model_info.user_id or has_access(
+                user.id, type="read", access_control=model_info.access_control
+            ):
+                filtered_models.append(model)
+    return filtered_models
+
+
 @cached(ttl=3)
 @cached(ttl=3)
-async def get_all_models() -> dict[str, list]:
+async def get_all_models(request: Request) -> dict[str, list]:
     log.info("get_all_models()")
     log.info("get_all_models()")
 
 
-    if not app.state.config.ENABLE_OPENAI_API:
+    if not request.app.state.config.ENABLE_OPENAI_API:
         return {"data": []}
         return {"data": []}
 
 
-    responses = await get_all_models_responses()
+    responses = await get_all_models_responses(request)
 
 
     def extract_data(response):
     def extract_data(response):
         if response and "data" in response:
         if response and "data" in response:
@@ -345,41 +353,86 @@ async def get_all_models() -> dict[str, list]:
             return response
             return response
         return None
         return None
 
 
+    def merge_models_lists(model_lists):
+        log.debug(f"merge_models_lists {model_lists}")
+        merged_list = []
+
+        for idx, models in enumerate(model_lists):
+            if models is not None and "error" not in models:
+                merged_list.extend(
+                    [
+                        {
+                            **model,
+                            "name": model.get("name", model["id"]),
+                            "owned_by": "openai",
+                            "openai": model,
+                            "urlIdx": idx,
+                        }
+                        for model in models
+                        if "api.openai.com"
+                        not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
+                        or not any(
+                            name in model["id"]
+                            for name in [
+                                "babbage",
+                                "dall-e",
+                                "davinci",
+                                "embedding",
+                                "tts",
+                                "whisper",
+                            ]
+                        )
+                    ]
+                )
+
+        return merged_list
+
     models = {"data": merge_models_lists(map(extract_data, responses))}
     models = {"data": merge_models_lists(map(extract_data, responses))}
     log.debug(f"models: {models}")
     log.debug(f"models: {models}")
 
 
+    request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]}
     return models
     return models
 
 
 
 
-@app.get("/models")
-@app.get("/models/{url_idx}")
-async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
+@router.get("/models")
+@router.get("/models/{url_idx}")
+async def get_models(
+    request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
+):
     models = {
     models = {
         "data": [],
         "data": [],
     }
     }
 
 
     if url_idx is None:
     if url_idx is None:
-        models = await get_all_models()
+        models = await get_all_models(request)
     else:
     else:
-        url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
-        key = app.state.config.OPENAI_API_KEYS[url_idx]
-
-        headers = {}
-        headers["Authorization"] = f"Bearer {key}"
-        headers["Content-Type"] = "application/json"
-
-        if ENABLE_FORWARD_USER_INFO_HEADERS:
-            headers["X-OpenWebUI-User-Name"] = user.name
-            headers["X-OpenWebUI-User-Id"] = user.id
-            headers["X-OpenWebUI-User-Email"] = user.email
-            headers["X-OpenWebUI-User-Role"] = user.role
+        url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
+        key = request.app.state.config.OPENAI_API_KEYS[url_idx]
 
 
         r = None
         r = None
-
-        timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
-        async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with aiohttp.ClientSession(
+            timeout=aiohttp.ClientTimeout(
+                total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
+            )
+        ) as session:
             try:
             try:
-                async with session.get(f"{url}/models", headers=headers) as r:
+                async with session.get(
+                    f"{url}/models",
+                    headers={
+                        "Authorization": f"Bearer {key}",
+                        "Content-Type": "application/json",
+                        **(
+                            {
+                                "X-OpenWebUI-User-Name": user.name,
+                                "X-OpenWebUI-User-Id": user.id,
+                                "X-OpenWebUI-User-Email": user.email,
+                                "X-OpenWebUI-User-Role": user.role,
+                            }
+                            if ENABLE_FORWARD_USER_INFO_HEADERS
+                            else {}
+                        ),
+                    },
+                ) as r:
                     if r.status != 200:
                     if r.status != 200:
                         # Extract response error details if available
                         # Extract response error details if available
                         error_detail = f"HTTP Error: {r.status}"
                         error_detail = f"HTTP Error: {r.status}"
@@ -413,27 +466,16 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
             except aiohttp.ClientError as e:
             except aiohttp.ClientError as e:
                 # ClientError covers all aiohttp requests issues
                 # ClientError covers all aiohttp requests issues
                 log.exception(f"Client error: {str(e)}")
                 log.exception(f"Client error: {str(e)}")
-                # Handle aiohttp-specific connection issues, timeout etc.
                 raise HTTPException(
                 raise HTTPException(
                     status_code=500, detail="Open WebUI: Server Connection Error"
                     status_code=500, detail="Open WebUI: Server Connection Error"
                 )
                 )
             except Exception as e:
             except Exception as e:
                 log.exception(f"Unexpected error: {e}")
                 log.exception(f"Unexpected error: {e}")
-                # Generic error handler in case parsing JSON or other steps fail
                 error_detail = f"Unexpected error: {str(e)}"
                 error_detail = f"Unexpected error: {str(e)}"
                 raise HTTPException(status_code=500, detail=error_detail)
                 raise HTTPException(status_code=500, detail=error_detail)
 
 
     if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
     if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
-        # Filter models based on user access control
-        filtered_models = []
-        for model in models.get("data", []):
-            model_info = Models.get_model_by_id(model["id"])
-            if model_info:
-                if user.id == model_info.user_id or has_access(
-                    user.id, type="read", access_control=model_info.access_control
-                ):
-                    filtered_models.append(model)
-        models["data"] = filtered_models
+        models["data"] = get_filtered_models(models, user)
 
 
     return models
     return models
 
 
@@ -443,21 +485,24 @@ class ConnectionVerificationForm(BaseModel):
     key: str
     key: str
 
 
 
 
-@app.post("/verify")
+@router.post("/verify")
 async def verify_connection(
 async def verify_connection(
     form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
     form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
 ):
 ):
     url = form_data.url
     url = form_data.url
     key = form_data.key
     key = form_data.key
 
 
-    headers = {}
-    headers["Authorization"] = f"Bearer {key}"
-    headers["Content-Type"] = "application/json"
-
-    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
-    async with aiohttp.ClientSession(timeout=timeout) as session:
+    async with aiohttp.ClientSession(
+        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+    ) as session:
         try:
         try:
-            async with session.get(f"{url}/models", headers=headers) as r:
+            async with session.get(
+                f"{url}/models",
+                headers={
+                    "Authorization": f"Bearer {key}",
+                    "Content-Type": "application/json",
+                },
+            ) as r:
                 if r.status != 200:
                 if r.status != 200:
                     # Extract response error details if available
                     # Extract response error details if available
                     error_detail = f"HTTP Error: {r.status}"
                     error_detail = f"HTTP Error: {r.status}"
@@ -472,26 +517,27 @@ async def verify_connection(
         except aiohttp.ClientError as e:
         except aiohttp.ClientError as e:
             # ClientError covers all aiohttp requests issues
             # ClientError covers all aiohttp requests issues
             log.exception(f"Client error: {str(e)}")
             log.exception(f"Client error: {str(e)}")
-            # Handle aiohttp-specific connection issues, timeout etc.
             raise HTTPException(
             raise HTTPException(
                 status_code=500, detail="Open WebUI: Server Connection Error"
                 status_code=500, detail="Open WebUI: Server Connection Error"
             )
             )
         except Exception as e:
         except Exception as e:
             log.exception(f"Unexpected error: {e}")
             log.exception(f"Unexpected error: {e}")
-            # Generic error handler in case parsing JSON or other steps fail
             error_detail = f"Unexpected error: {str(e)}"
             error_detail = f"Unexpected error: {str(e)}"
             raise HTTPException(status_code=500, detail=error_detail)
             raise HTTPException(status_code=500, detail=error_detail)
 
 
 
 
-@app.post("/chat/completions")
+@router.post("/chat/completions")
 async def generate_chat_completion(
 async def generate_chat_completion(
+    request: Request,
     form_data: dict,
     form_data: dict,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
     bypass_filter: Optional[bool] = False,
     bypass_filter: Optional[bool] = False,
 ):
 ):
+    if BYPASS_MODEL_ACCESS_CONTROL:
+        bypass_filter = True
+
     idx = 0
     idx = 0
     payload = {**form_data}
     payload = {**form_data}
-
     if "metadata" in payload:
     if "metadata" in payload:
         del payload["metadata"]
         del payload["metadata"]
 
 
@@ -502,6 +548,7 @@ async def generate_chat_completion(
     if model_info:
     if model_info:
         if model_info.base_model_id:
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id
             payload["model"] = model_info.base_model_id
+            model_id = model_info.base_model_id
 
 
         params = model_info.params.model_dump()
         params = model_info.params.model_dump()
         payload = apply_model_params_to_body_openai(params, payload)
         payload = apply_model_params_to_body_openai(params, payload)
@@ -526,15 +573,7 @@ async def generate_chat_completion(
                 detail="Model not found",
                 detail="Model not found",
             )
             )
 
 
-    # Attemp to get urlIdx from the model
-    models = await get_all_models()
-
-    # Find the model from the list
-    model = next(
-        (model for model in models["data"] if model["id"] == payload.get("model")),
-        None,
-    )
-
+    model = request.app.state.OPENAI_MODELS.get(model_id)
     if model:
     if model:
         idx = model["urlIdx"]
         idx = model["urlIdx"]
     else:
     else:
@@ -544,11 +583,11 @@ async def generate_chat_completion(
         )
         )
 
 
     # Get the API config for the model
     # Get the API config for the model
-    api_config = app.state.config.OPENAI_API_CONFIGS.get(
-        app.state.config.OPENAI_API_BASE_URLS[idx], {}
+    api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
+        request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
     )
     )
-    prefix_id = api_config.get("prefix_id", None)
 
 
+    prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
 
 
@@ -561,43 +600,26 @@ async def generate_chat_completion(
             "role": user.role,
             "role": user.role,
         }
         }
 
 
-    url = app.state.config.OPENAI_API_BASE_URLS[idx]
-    key = app.state.config.OPENAI_API_KEYS[idx]
+    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
+    key = request.app.state.config.OPENAI_API_KEYS[idx]
 
 
     # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
     # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
     is_o1 = payload["model"].lower().startswith("o1-")
     is_o1 = payload["model"].lower().startswith("o1-")
-    # Change max_completion_tokens to max_tokens (Backward compatible)
-    if "api.openai.com" not in url and not is_o1:
+    if is_o1:
+        payload = openai_o1_handler(payload)
+    elif "api.openai.com" not in url:
+        # Remove "max_completion_tokens" from the payload for backward compatibility
         if "max_completion_tokens" in payload:
         if "max_completion_tokens" in payload:
-            # Remove "max_completion_tokens" from the payload
             payload["max_tokens"] = payload["max_completion_tokens"]
             payload["max_tokens"] = payload["max_completion_tokens"]
             del payload["max_completion_tokens"]
             del payload["max_completion_tokens"]
-    else:
-        if is_o1 and "max_tokens" in payload:
-            payload["max_completion_tokens"] = payload["max_tokens"]
-            del payload["max_tokens"]
-        if "max_tokens" in payload and "max_completion_tokens" in payload:
-            del payload["max_tokens"]
 
 
-    # Fix: O1 does not support the "system" parameter, Modify "system" to "user"
-    if is_o1 and payload["messages"][0]["role"] == "system":
-        payload["messages"][0]["role"] = "user"
+    # TODO: check if below is needed
+    # if "max_tokens" in payload and "max_completion_tokens" in payload:
+    #     del payload["max_tokens"]
 
 
     # Convert the modified body back to JSON
     # Convert the modified body back to JSON
     payload = json.dumps(payload)
     payload = json.dumps(payload)
 
 
-    headers = {}
-    headers["Authorization"] = f"Bearer {key}"
-    headers["Content-Type"] = "application/json"
-    if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
-        headers["HTTP-Referer"] = "https://openwebui.com/"
-        headers["X-Title"] = "Open WebUI"
-    if ENABLE_FORWARD_USER_INFO_HEADERS:
-        headers["X-OpenWebUI-User-Name"] = user.name
-        headers["X-OpenWebUI-User-Id"] = user.id
-        headers["X-OpenWebUI-User-Email"] = user.email
-        headers["X-OpenWebUI-User-Role"] = user.role
-
     r = None
     r = None
     session = None
     session = None
     streaming = False
     streaming = False
@@ -607,11 +629,33 @@ async def generate_chat_completion(
         session = aiohttp.ClientSession(
         session = aiohttp.ClientSession(
             trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
             trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
         )
         )
+
         r = await session.request(
         r = await session.request(
             method="POST",
             method="POST",
             url=f"{url}/chat/completions",
             url=f"{url}/chat/completions",
             data=payload,
             data=payload,
-            headers=headers,
+            headers={
+                "Authorization": f"Bearer {key}",
+                "Content-Type": "application/json",
+                **(
+                    {
+                        "HTTP-Referer": "https://openwebui.com/",
+                        "X-Title": "Open WebUI",
+                    }
+                    if "openrouter.ai" in url
+                    else {}
+                ),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS
+                    else {}
+                ),
+            },
         )
         )
 
 
         # Check if response is SSE
         # Check if response is SSE
@@ -636,14 +680,18 @@ async def generate_chat_completion(
             return response
             return response
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if isinstance(response, dict):
         if isinstance(response, dict):
             if "error" in response:
             if "error" in response:
-                error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
+                detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
         elif isinstance(response, str):
         elif isinstance(response, str):
-            error_detail = response
+            detail = response
 
 
-        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
+        )
     finally:
     finally:
         if not streaming and session:
         if not streaming and session:
             if r:
             if r:
@@ -651,25 +699,17 @@ async def generate_chat_completion(
             await session.close()
             await session.close()
 
 
 
 
-@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
 async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
-    idx = 0
+    """
+    Deprecated: proxy all requests to OpenAI API
+    """
 
 
     body = await request.body()
     body = await request.body()
 
 
-    url = app.state.config.OPENAI_API_BASE_URLS[idx]
-    key = app.state.config.OPENAI_API_KEYS[idx]
-
-    target_url = f"{url}/{path}"
-
-    headers = {}
-    headers["Authorization"] = f"Bearer {key}"
-    headers["Content-Type"] = "application/json"
-    if ENABLE_FORWARD_USER_INFO_HEADERS:
-        headers["X-OpenWebUI-User-Name"] = user.name
-        headers["X-OpenWebUI-User-Id"] = user.id
-        headers["X-OpenWebUI-User-Email"] = user.email
-        headers["X-OpenWebUI-User-Role"] = user.role
+    idx = 0
+    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
+    key = request.app.state.config.OPENAI_API_KEYS[idx]
 
 
     r = None
     r = None
     session = None
     session = None
@@ -679,11 +719,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         session = aiohttp.ClientSession(trust_env=True)
         session = aiohttp.ClientSession(trust_env=True)
         r = await session.request(
         r = await session.request(
             method=request.method,
             method=request.method,
-            url=target_url,
+            url=f"{url}/{path}",
             data=body,
             data=body,
-            headers=headers,
+            headers={
+                "Authorization": f"Bearer {key}",
+                "Content-Type": "application/json",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS
+                    else {}
+                ),
+            },
         )
         )
-
         r.raise_for_status()
         r.raise_for_status()
 
 
         # Check if response is SSE
         # Check if response is SSE
@@ -700,18 +752,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         else:
         else:
             response_data = await r.json()
             response_data = await r.json()
             return response_data
             return response_data
+
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if r is not None:
         if r is not None:
             try:
             try:
                 res = await r.json()
                 res = await r.json()
                 print(res)
                 print(res)
                 if "error" in res:
                 if "error" in res:
-                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+                    detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
             except Exception:
             except Exception:
-                error_detail = f"External: {e}"
-        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+                detail = f"External: {e}"
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
+        )
     finally:
     finally:
         if not streaming and session:
         if not streaming and session:
             if r:
             if r:

Certains fichiers n'ont pas été affichés car il y a eu trop de fichiers modifiés dans ce diff