浏览代码

Fix dev upstream merge conflicts

Taylor Wilsdon 4 月之前
父节点
当前提交
e28427803f
共有 100 个文件被更改,包括 4376 次插入7346 次删除
  1. 2 1
      .dockerignore
  2. 4 0
      .github/workflows/deploy-to-hf-spaces.yml
  3. 168 0
      CHANGELOG.md
  4. 60 38
      CODE_OF_CONDUCT.md
  5. 4 2
      README.md
  6. 0 641
      backend/open_webui/apps/audio/main.py
  7. 0 1123
      backend/open_webui/apps/ollama/main.py
  8. 0 557
      backend/open_webui/apps/openai/main.py
  9. 0 1332
      backend/open_webui/apps/retrieval/main.py
  10. 0 14
      backend/open_webui/apps/retrieval/vector/connector.py
  11. 0 458
      backend/open_webui/apps/webui/main.py
  12. 0 157
      backend/open_webui/apps/webui/models/documents.py
  13. 0 155
      backend/open_webui/apps/webui/routers/documents.py
  14. 0 381
      backend/open_webui/apps/webui/routers/knowledge.py
  15. 0 104
      backend/open_webui/apps/webui/routers/models.py
  16. 0 90
      backend/open_webui/apps/webui/routers/prompts.py
  17. 422 47
      backend/open_webui/config.py
  18. 3 0
      backend/open_webui/constants.py
  19. 14 2
      backend/open_webui/env.py
  20. 316 0
      backend/open_webui/functions.py
  21. 1 1
      backend/open_webui/internal/db.py
  22. 0 0
      backend/open_webui/internal/migrations/001_initial_schema.py
  23. 0 0
      backend/open_webui/internal/migrations/002_add_local_sharing.py
  24. 0 0
      backend/open_webui/internal/migrations/003_add_auth_api_key.py
  25. 0 0
      backend/open_webui/internal/migrations/004_add_archived.py
  26. 0 0
      backend/open_webui/internal/migrations/005_add_updated_at.py
  27. 0 0
      backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py
  28. 0 0
      backend/open_webui/internal/migrations/007_add_user_last_active_at.py
  29. 0 0
      backend/open_webui/internal/migrations/008_add_memory.py
  30. 0 0
      backend/open_webui/internal/migrations/009_add_models.py
  31. 0 0
      backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py
  32. 0 0
      backend/open_webui/internal/migrations/011_add_user_settings.py
  33. 0 0
      backend/open_webui/internal/migrations/012_add_tools.py
  34. 0 0
      backend/open_webui/internal/migrations/013_add_user_info.py
  35. 0 0
      backend/open_webui/internal/migrations/014_add_files.py
  36. 0 0
      backend/open_webui/internal/migrations/015_add_functions.py
  37. 0 0
      backend/open_webui/internal/migrations/016_add_valves_and_is_active.py
  38. 0 0
      backend/open_webui/internal/migrations/017_add_user_oauth_sub.py
  39. 0 0
      backend/open_webui/internal/migrations/018_add_function_is_global.py
  40. 0 0
      backend/open_webui/internal/wrappers.py
  41. 642 1958
      backend/open_webui/main.py
  42. 1 1
      backend/open_webui/migrations/env.py
  43. 1 1
      backend/open_webui/migrations/script.py.mako
  44. 2 2
      backend/open_webui/migrations/versions/7e5b5dc7342b_init.py
  45. 85 0
      backend/open_webui/migrations/versions/922e7a387820_add_group_table.py
  46. 8 3
      backend/open_webui/models/auths.py
  47. 15 8
      backend/open_webui/models/chats.py
  48. 2 2
      backend/open_webui/models/feedbacks.py
  49. 1 1
      backend/open_webui/models/files.py
  50. 2 2
      backend/open_webui/models/folders.py
  51. 2 2
      backend/open_webui/models/functions.py
  52. 186 0
      backend/open_webui/models/groups.py
  53. 78 25
      backend/open_webui/models/knowledge.py
  54. 1 1
      backend/open_webui/models/memories.py
  55. 104 9
      backend/open_webui/models/models.py
  56. 59 10
      backend/open_webui/models/prompts.py
  57. 1 1
      backend/open_webui/models/tags.py
  58. 61 5
      backend/open_webui/models/tools.py
  59. 10 2
      backend/open_webui/models/users.py
  60. 5 3
      backend/open_webui/retrieval/loaders/main.py
  61. 117 0
      backend/open_webui/retrieval/loaders/youtube.py
  62. 0 0
      backend/open_webui/retrieval/models/colbert.py
  63. 83 124
      backend/open_webui/retrieval/utils.py
  64. 22 0
      backend/open_webui/retrieval/vector/connector.py
  65. 16 3
      backend/open_webui/retrieval/vector/dbs/chroma.py
  66. 1 1
      backend/open_webui/retrieval/vector/dbs/milvus.py
  67. 178 0
      backend/open_webui/retrieval/vector/dbs/opensearch.py
  68. 354 0
      backend/open_webui/retrieval/vector/dbs/pgvector.py
  69. 8 3
      backend/open_webui/retrieval/vector/dbs/qdrant.py
  70. 0 0
      backend/open_webui/retrieval/vector/main.py
  71. 73 0
      backend/open_webui/retrieval/web/bing.py
  72. 1 1
      backend/open_webui/retrieval/web/brave.py
  73. 1 1
      backend/open_webui/retrieval/web/duckduckgo.py
  74. 1 1
      backend/open_webui/retrieval/web/google_pse.py
  75. 3 5
      backend/open_webui/retrieval/web/jina_search.py
  76. 48 0
      backend/open_webui/retrieval/web/kagi.py
  77. 0 0
      backend/open_webui/retrieval/web/main.py
  78. 40 0
      backend/open_webui/retrieval/web/mojeek.py
  79. 1 1
      backend/open_webui/retrieval/web/searchapi.py
  80. 1 1
      backend/open_webui/retrieval/web/searxng.py
  81. 1 1
      backend/open_webui/retrieval/web/serper.py
  82. 1 1
      backend/open_webui/retrieval/web/serply.py
  83. 1 1
      backend/open_webui/retrieval/web/serpstack.py
  84. 1 1
      backend/open_webui/retrieval/web/tavily.py
  85. 58 0
      backend/open_webui/retrieval/web/testdata/bing.json
  86. 0 0
      backend/open_webui/retrieval/web/testdata/brave.json
  87. 0 0
      backend/open_webui/retrieval/web/testdata/google_pse.json
  88. 0 0
      backend/open_webui/retrieval/web/testdata/searchapi.json
  89. 0 0
      backend/open_webui/retrieval/web/testdata/searxng.json
  90. 0 0
      backend/open_webui/retrieval/web/testdata/serper.json
  91. 0 0
      backend/open_webui/retrieval/web/testdata/serply.json
  92. 0 0
      backend/open_webui/retrieval/web/testdata/serpstack.json
  93. 3 3
      backend/open_webui/retrieval/web/utils.py
  94. 703 0
      backend/open_webui/routers/audio.py
  95. 320 8
      backend/open_webui/routers/auths.py
  96. 13 10
      backend/open_webui/routers/chats.py
  97. 32 19
      backend/open_webui/routers/configs.py
  98. 3 3
      backend/open_webui/routers/evaluations.py
  99. 29 16
      backend/open_webui/routers/files.py
  100. 3 3
      backend/open_webui/routers/folders.py

+ 2 - 1
.dockerignore

@@ -16,4 +16,5 @@ _old
 uploads
 uploads
 .ipynb_checkpoints
 .ipynb_checkpoints
 **/*.db
 **/*.db
-_test
+_test
+backend/data/*

+ 4 - 0
.github/workflows/deploy-to-hf-spaces.yml

@@ -28,6 +28,8 @@ jobs:
     steps:
     steps:
       - name: Checkout repository
       - name: Checkout repository
         uses: actions/checkout@v4
         uses: actions/checkout@v4
+        with:
+          lfs: true
 
 
       - name: Remove git history
       - name: Remove git history
         run: rm -rf .git
         run: rm -rf .git
@@ -52,7 +54,9 @@ jobs:
       - name: Set up Git and push to Space
       - name: Set up Git and push to Space
         run: |
         run: |
           git init --initial-branch=main
           git init --initial-branch=main
+          git lfs install
           git lfs track "*.ttf"
           git lfs track "*.ttf"
+          git lfs track "*.jpg"
           rm demo.gif
           rm demo.gif
           git add .
           git add .
           git commit -m "GitHub deploy: ${{ github.sha }}"
           git commit -m "GitHub deploy: ${{ github.sha }}"

+ 168 - 0
CHANGELOG.md

@@ -5,10 +5,178 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 
+### Added
+
+- **🌐 Enhanced Translations**: Added Slovak language, improved Czech language.
+
+## [0.4.8] - 2024-12-07
+
+### Added
+
+- **🔓 Bypass Model Access Control**: Introduced the 'BYPASS_MODEL_ACCESS_CONTROL' environment variable. Easily bypass model access controls for user roles when access control isn't required, simplifying workflows for trusted environments.
+- **📝 Markdown in Banners**: Now supports markdown for banners, enabling richer, more visually engaging announcements.
+- **🌐 Internationalization Updates**: Enhanced translations across multiple languages, further improving accessibility and global user experience.
+- **🎨 Styling Enhancements**: General UI style refinements for a cleaner and more polished interface.
+- **📋 Rich Text Reliability**: Improved the reliability and stability of rich text input across chats for smoother interactions.
+
+### Fixed
+
+- **💡 Tailwind Build Issue**: Resolved a breaking bug caused by Tailwind, ensuring smoother builds and overall system reliability.
+- **📚 Knowledge Collection Query Fix**: Addressed API endpoint issues with querying knowledge collections, ensuring accurate and reliable information retrieval.
+
+## [0.4.7] - 2024-12-01
+
+### Added
+
+- **✨ Prompt Input Auto-Completion**: Type a prompt and let AI intelligently suggest and complete your inputs. Simply press 'Tab' or swipe right on mobile to confirm. Available only with Rich Text Input (default setting). Disable via Admin Settings for full control.
+- **🌍 Improved Translations**: Enhanced localization for multiple languages, ensuring a more polished and accessible experience for international users.
+
+### Fixed
+
+- **🛠️ Tools Export Issue**: Resolved a critical issue where exporting tools wasn’t functioning, restoring seamless export capabilities.
+- **🔗 Model ID Registration**: Fixed an issue where model IDs weren’t registering correctly in the model editor, ensuring reliable model setup and tracking.
+- **🖋️ Textarea Auto-Expansion**: Corrected a bug where textareas didn’t expand automatically on certain browsers, improving usability for multi-line inputs.
+- **🔧 Ollama Embed Endpoint**: Addressed the /ollama/embed endpoint malfunction, ensuring consistent performance and functionality.
+
+### Changed
+
+- **🎨 Knowledge Base Styling**: Refined knowledge base visuals for a cleaner, more modern look, laying the groundwork for further enhancements in upcoming releases.
+
+## [0.4.6] - 2024-11-26
+
+### Added
+
+- **🌍 Enhanced Translations**: Various language translations improved to make the WebUI more accessible and user-friendly worldwide.
+
+### Fixed
+
+- **✏️ Textarea Shifting Bug**: Resolved the issue where the textarea shifted unexpectedly, ensuring a smoother typing experience.
+- **⚙️ Model Configuration Modal**: Fixed the issue where the models configuration modal introduced in 0.4.5 wasn’t working for some users.
+- **🔍 Legacy Query Support**: Restored functionality for custom query generation in RAG when using legacy prompts, ensuring both default and custom templates now work seamlessly.
+- **⚡ Improved General Reliability**: Various minor fixes improve platform stability and ensure a smoother overall experience across workflows.
+
+## [0.4.5] - 2024-11-26
+
+### Added
+
+- **🎨 Model Order/Defaults Reintroduced**: Brought back the ability to set model order and default models, now configurable via Admin Settings > Models > Configure (Gear Icon).
+
+### Fixed
+
+- **🔍 Query Generation Issue**: Resolved an error in web search query generation, enhancing search accuracy and ensuring smoother search workflows.
+- **📏 Textarea Auto Height Bug**: Fixed a layout issue where textarea input height was shifting unpredictably, particularly when editing system prompts.
+- **🔑 Ollama Authentication**: Corrected an issue with Ollama’s authorization headers, guaranteeing reliable authentication across all endpoints.
+- **⚙️ Missing Min_P Save**: Resolved an issue where the 'min_p' parameter was not being saved in configurations.
+- **🛠️ Tools Description**: Fixed a key issue that omitted tool descriptions in tools payload.
+
+## [0.4.4] - 2024-11-22
+
+### Added
+
+- **🌐 Translation Updates**: Refreshed Catalan, Brazilian Portuguese, German, and Ukrainian translations, further enhancing the platform's accessibility and improving the experience for international users.
+
+### Fixed
+
+- **📱 Mobile Controls Visibility**: Resolved an issue where the controls button was not displaying on the new chats page for mobile users, ensuring smoother navigation and functionality on smaller screens.
+- **📷 LDAP Profile Image Issue**: Fixed an LDAP integration bug related to profile images, ensuring seamless authentication and a reliable login experience for users.
+- **⏳ RAG Query Generation Issue**: Addressed a significant problem where RAG query generation occurred unnecessarily without attached files, drastically improving speed and reducing delays during chat completions.
+
+### Changed
+
+- **⚙️ Legacy Event Emitter Support**: Reintroduced compatibility with legacy "citation" types for event emitters in tools and functions, providing smoother workflows and broader tool support for users.
+
+## [0.4.3] - 2024-11-21
+
+### Added
+
+- **📚 Inline Citations for RAG Results**: Get seamless inline citations for Retrieval-Augmented Generation (RAG) responses using the default RAG prompt. Note: This feature only supports newly uploaded files, improving traceability and providing source clarity.
+- **🎨 Better Rich Text Input Support**: Enjoy smoother and more reliable rich text formatting for chats, enhancing communication quality.
+- **⚡ Faster Model Retrieval**: Implemented caching optimizations for faster model loading, providing a noticeable speed boost across workflows. Further improvements are on the way!
+
+### Fixed
+
+- **🔗 Pipelines Feature Restored**: Resolved a critical issue that previously prevented Pipelines from functioning, ensuring seamless workflows.
+- **✏️ Missing Suffix Field in Ollama Form**: Added the missing "suffix" field to the Ollama generate form, enhancing customization options.
+
+### Changed
+
+- **🗂️ Renamed "Citations" to "Sources"**: Improved clarity and consistency by renaming the "citations" field to "sources" in messages.
+
+## [0.4.2] - 2024-11-20
+
+### Fixed
+
+- **📁 Knowledge Files Visibility Issue**: Resolved the bug preventing individual files in knowledge collections from displaying when referenced with '#'.
+- **🔗 OpenAI Endpoint Prefix**: Fixed the issue where certain OpenAI connections that deviate from the official API spec weren’t working correctly with prefixes.
+- **⚔️ Arena Model Access Control**: Corrected an issue where arena model access control settings were not being saved.
+- **🔧 Usage Capability Selector**: Fixed the broken usage capabilities selector in the model editor.
+
+## [0.4.1] - 2024-11-19
+
+### Added
+
+- **📊 Enhanced Feedback System**: Introduced a detailed 1-10 rating scale for feedback alongside thumbs up/down, preparing for more precise model fine-tuning and improving feedback quality.
+- **ℹ️ Tool Descriptions on Hover**: Easily access tool descriptions by hovering over the message input, providing a smoother workflow with more context when utilizing tools.
+
+### Fixed
+
+- **🗑️ Graceful Handling of Deleted Users**: Resolved an issue where deleted users caused workspace items (models, knowledge, prompts, tools) to fail, ensuring reliable workspace loading.
+- **🔑 API Key Creation**: Fixed an issue preventing users from creating new API keys, restoring secure and seamless API management.
+- **🔗 HTTPS Proxy Fix**: Corrected HTTPS proxy issues affecting the '/api/v1/models/' endpoint, ensuring smoother, uninterrupted model management.
+
+## [0.4.0] - 2024-11-19
+
+### Added
+
+- **👥 User Groups**: You can now create and manage user groups, making user organization seamless.
+- **🔐 Group-Based Access Control**: Set granular access to models, knowledge, prompts, and tools based on user groups, allowing for more controlled and secure environments.
+- **🛠️ Group-Based User Permissions**: Easily manage workspace permissions. Grant users the ability to upload files, delete, edit, or create temporary chats, as well as define their ability to create models, knowledge, prompts, and tools.
+- **🔑 LDAP Support**: Newly introduced LDAP authentication adds robust security and scalability to user management.
+- **🌐 Enhanced OpenAI-Compatible Connections**: Added prefix ID support to avoid model ID clashes, with explicit model ID support for APIs lacking '/models' endpoint support, ensuring smooth operation with custom setups.
+- **🔐 Ollama API Key Support**: Now manage credentials for Ollama when set behind proxies, including the option to utilize prefix ID for proper distinction across multiple Ollama instances.
+- **🔄 Connection Enable/Disable Toggle**: Easily enable or disable individual OpenAI and Ollama connections as needed.
+- **🎨 Redesigned Model Workspace**: Freshly redesigned to improve usability for managing models across users and groups.
+- **🎨 Redesigned Prompt Workspace**: A fresh UI to conveniently organize and manage prompts.
+- **🧩 Sorted Functions Workspace**: Functions are now automatically categorized by type (Action, Filter, Pipe), streamlining management.
+- **💻 Redesigned Collaborative Workspace**: Enhanced support for multiple users contributing to models, knowledge, prompts, or tools, improving collaboration.
+- **🔧 Auto-Selected Tools in Model Editor**: Tools enabled through the model editor are now automatically selected, whereas previously it only gave users the option to enable the tool, reducing manual steps and enhancing efficiency.
+- **🔔 Web Search & Tools Indicator**: A clear indication now shows when web search or tools are active, reducing confusion.
+- **🔑 Toggle API Key Auth**: Tighten security by easily enabling or disabling API key authentication option for Open WebUI.
+- **🗂️ Agentic Retrieval**: Improve RAG accuracy via smart pre-processing of chat history to determine the best queries before retrieval.
+- **📁 Large Text as File Option**: Optionally convert large pasted text into a file upload, keeping the chat interface cleaner.
+- **🗂️ Toggle Citations for Models**: Ability to disable citations has been introduced in the model editor.
+- **🔍 User Settings Search**: Quickly search for settings fields, improving ease of use and navigation.
+- **🗣️ Experimental SpeechT5 TTS**: Local SpeechT5 support added for improved text-to-speech capabilities.
+- **🔄 Unified Reset for Models**: A one-click option has been introduced to reset and remove all models from the Admin Settings.
+- **🛠️ Initial Setup Wizard**: The setup process now explicitly informs users that they are creating an admin account during the first-time setup, ensuring clarity. Previously, users encountered the login page right away without this distinction.
+- **🌐 Enhanced Translations**: Several language translations, including Ukrainian, Norwegian, and Brazilian Portuguese, were refined for better localization.
+
+### Fixed
+
+- **🎥 YouTube Video Attachments**: Fixed issues preventing proper loading and attachment of YouTube videos as files.
+- **🔄 Shared Chat Update**: Corrected issues where shared chats were not updating, improving collaboration consistency.
+- **🔍 DuckDuckGo Rate Limit Fix**: Addressed issues with DuckDuckGo search integration, enhancing search stability and performance when operating within rate limits.
+- **🧾 Citations Relevance Fix**: Adjusted the relevance percentage calculation for citations, so that Open WebUI properly reflect the accuracy of a retrieved document in RAG, ensuring users get clearer insights into sources.
+- **🔑 Jina Search API Key Requirement**: Added the option to input an API key for Jina Search, ensuring smooth functionality as keys are now mandatory.
+
+### Changed
+
+- **🛠️ Functions Moved to Admin Panel**: As Functions operate as advanced plugins, they are now accessible from the Admin Panel instead of the workspace.
+- **🛠️ Manage Ollama Connections**: The "Models" section in Admin Settings has been relocated to Admin Settings > "Connections" > Ollama Connections. You can now manage Ollama instances via a dedicated "Manage Ollama" modal from "Connections", streamlining the setup and configuration of Ollama models.
+- **📊 Base Models in Admin Settings**: Admins can now find all base models, both connections or functions, in the "Models" Admin setting. Global model accessibility can be enabled or disabled here. Models are private by default, requiring explicit permission assignment for user access.
+- **📌 Sticky Model Selection for New Chats**: The model chosen from a previous chat now persists when creating a new chat. If you click "New Chat" again from the new chat page, it will revert to your default model.
+- **🎨 Design Refactoring**: Overall design refinements across the platform have been made, providing a more cohesive and polished user experience.
+
+### Removed
+
+- **📂 Model List Reordering**: Temporarily removed and will be reintroduced in upcoming user group settings improvements.
+- **⚙️ Default Model Setting**: Removed the ability to set a default model for users, will be reintroduced with user group settings in the future.
+
 ## [0.3.35] - 2024-10-26
 ## [0.3.35] - 2024-10-26
 
 
 ### Added
 ### Added
 
 
+- **🌐 Translation Update**: Added translation labels in the SearchInput and CreateCollection components and updated Brazilian Portuguese translation (pt-BR)
 - **📁 Robust File Handling**: Enhanced file input handling for chat. If the content extraction fails or is empty, users will now receive a clear warning, preventing silent failures and ensuring you always know what's happening with your uploads.
 - **📁 Robust File Handling**: Enhanced file input handling for chat. If the content extraction fails or is empty, users will now receive a clear warning, preventing silent failures and ensuring you always know what's happening with your uploads.
 - **🌍 New Language Support**: Introduced Hungarian translations and updated French translations, expanding the platform's language accessibility for a more global user base.
 - **🌍 New Language Support**: Introduced Hungarian translations and updated French translations, expanding the platform's language accessibility for a more global user base.
 
 

+ 60 - 38
CODE_OF_CONDUCT.md

@@ -2,76 +2,98 @@
 
 
 ## Our Pledge
 ## Our Pledge
 
 
-We as members, contributors, and leaders pledge to make participation in our
-community a harassment-free experience for everyone, regardless of age, body
-size, visible or invisible disability, ethnicity, sex characteristics, gender
-identity and expression, level of experience, education, socio-economic status,
-nationality, personal appearance, race, religion, or sexual identity
-and orientation.
+As members, contributors, and leaders of this community, we pledge to make participation in our open-source project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
 
 
-We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
+We are committed to creating and maintaining an open, respectful, and professional environment where positive contributions and meaningful discussions can flourish. By participating in this project, you agree to uphold these values and align your behavior to the standards outlined in this Code of Conduct.
+
+## Why These Standards Are Important
+
+Open-source projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved.
+
+Maintaining a positive and respectful environment is essential to safeguarding the integrity of this project and protecting contributors' efforts. Behavior that disrupts this atmosphere—whether through hostility, entitlement, or unprofessional conduct—can severely harm the morale and productivity of the community. **Strict enforcement of these standards ensures a safe and supportive space for meaningful collaboration.**
+
+This is a community where **respect and professionalism are mandatory.** Violations of these standards will result in **zero tolerance** and immediate enforcement to prevent disruption and ensure the well-being of all participants.
 
 
 ## Our Standards
 ## Our Standards
 
 
-Examples of behavior that contribute to a positive environment for our community include:
+Examples of behavior that contribute to a positive and professional community include:
 
 
-- Demonstrating empathy and kindness toward other people
-- Being respectful of differing opinions, viewpoints, and experiences
-- Giving and gracefully accepting constructive feedback
-- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
-- Focusing on what is best not just for us as individuals, but for the overall community
+- **Respecting others.** Be considerate, listen actively, and engage with empathy toward others' viewpoints and experiences.
+- **Constructive feedback.** Provide actionable, thoughtful, and respectful feedback that helps improve the project and encourages collaboration. Avoid unproductive negativity or hypercriticism.
+- **Recognizing volunteer contributions.** Appreciate that contributors dedicate their free time and resources selflessly. Approach them with gratitude and patience.
+- **Focusing on shared goals.** Collaborate in ways that prioritize the health, success, and sustainability of the community over individual agendas.
 
 
 Examples of unacceptable behavior include:
 Examples of unacceptable behavior include:
 
 
-- The use of sexualized language or imagery, and sexual attention or advances of any kind
-- Trolling, insulting or derogatory comments, and personal or political attacks
-- Public or private harassment
-- Publishing others' private information, such as a physical or email address, without their explicit permission
-- **Spamming of any kind**
-- Aggressive sales tactics targeting our community members are strictly prohibited. You can mention your product if it's relevant to the discussion, but under no circumstances should you push it forcefully
-- Other conduct which could reasonably be considered inappropriate in a professional setting
+- The use of discriminatory, demeaning, or sexualized language or behavior.
+- Personal attacks, derogatory comments, trolling, or inflammatory political or ideological arguments.
+- Harassment, intimidation, or any behavior intended to create a hostile, uncomfortable, or unsafe environment.
+- Publishing others' private information (e.g., physical or email addresses) without explicit permission.
+- **Entitlement, demand, or aggression toward contributors.** Volunteers are under no obligation to provide immediate or personalized support. Rude or dismissive behavior will not be tolerated.
+- **Unproductive or destructive behavior.** This includes venting frustration as hostility ("tantrums"), hypercriticism, attention-seeking negativity, or anything that distracts from the project's goals.
+- **Spamming and promotional exploitation.** Sharing irrelevant product promotions or self-promotion in the community is not allowed unless it directly contributes value to the discussion.
+
+### Feedback and Community Engagement
+
+- **Constructive feedback is encouraged, but hostile or entitled behavior will result in immediate action.** If you disagree with elements of the project, we encourage you to offer meaningful improvements or fork the project if necessary. Healthy discussions and technical disagreements are welcome only when handled with professionalism.
+- **Respect contributors' time and efforts.** No one is entitled to personalized or on-demand assistance. This is a community built on collaboration and shared effort; demanding or demeaning behavior undermines that trust and will not be allowed.
+
+### Zero Tolerance: No Warnings, Immediate Action
+
+This community operates under a **zero-tolerance policy.** Any behavior deemed unacceptable under this Code of Conduct will result in **immediate enforcement, without prior warning.**
+
+We employ this approach to ensure that unproductive or disruptive behavior does not escalate further or cause unnecessary harm to other contributors. The standards are clear, and violations of any kind—whether mild or severe—will be addressed decisively to protect the community.
 
 
 ## Enforcement Responsibilities
 ## Enforcement Responsibilities
 
 
-Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.
+Community leaders are responsible for upholding and enforcing these standards. They are empowered to take **immediate and appropriate action** to address any behaviors they deem unacceptable under this Code of Conduct. These actions are taken with the goal of protecting the community and preserving its safe, positive, and productive environment.
 
 
 ## Scope
 ## Scope
 
 
-This Code of Conduct applies within all community spaces and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
+This Code of Conduct applies to all community spaces, including forums, repositories, social media accounts, and in-person events. It also applies when an individual represents the community in public settings, such as conferences or official communications.
+
+Additionally, any behavior outside of these defined spaces that negatively impacts the community or its members may fall within the scope of this Code of Conduct.
 
 
-## Enforcement
+## Reporting Violations
 
 
-Instances of abusive, harassing, spamming, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at hello@openwebui.com. All complaints will be reviewed and investigated promptly and fairly.
+Instances of unacceptable behavior can be reported to the leadership team at **hello@openwebui.com**. Reports will be handled promptly, confidentially, and with consideration for the safety and well-being of the reporter.
 
 
-All community leaders are obligated to respect the privacy and security of the reporter of any incident.
+All community leaders are required to uphold confidentiality and impartiality when addressing reports of violations.
 
 
 ## Enforcement Guidelines
 ## Enforcement Guidelines
 
 
-Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
+### Ban
+
+**Community Impact**: Community leaders will issue a ban to any participant whose behavior is deemed unacceptable according to this Code of Conduct. Bans are enforced immediately and without prior notice.
+
+A ban may be temporary or permanent, depending on the severity of the violation. This includes—but is not limited to—behavior such as:
+
+- Harassment or abusive behavior toward contributors.
+- Persistent negativity or hostility that disrupts the collaborative environment.
+- Disrespectful, demanding, or aggressive interactions with others.
+- Attempts to cause harm or sabotage the community.
 
 
-### 1. Temporary Ban
+**Consequence**: A banned individual is immediately removed from access to all community spaces, communication channels, and events. Community leaders reserve the right to enforce either a time-limited suspension or a permanent ban based on the specific circumstances of the violation.
 
 
-**Community Impact**: Any violation of community standards, including but not limited to inappropriate language, unprofessional behavior, harassment, or spamming.
+This approach ensures that disruptive behaviors are addressed swiftly and decisively in order to maintain the integrity and productivity of the community.
 
 
-**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
+## Why Zero Tolerance Is Necessary
 
 
-### 2. Permanent Ban
+Open-source projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do.
 
 
-**Community Impact**: Repeated or severe violations of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
+By enforcing a zero-tolerance policy, we ensure that the community remains a safe, welcoming space for all participants. These measures are not about harshness—they are about protecting contributors and fostering a productive environment where innovation can thrive.
 
 
-**Consequence**: A permanent ban from any sort of public interaction within the community.
+Our expectations are clear, and our enforcement reflects our commitment to this project's long-term success.
 
 
 ## Attribution
 ## Attribution
 
 
-This Code of Conduct is adapted from the [Contributor Covenant][homepage],
-version 2.0, available at
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at  
 https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
 https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
 
 
-Community Impact Guidelines were inspired by [Mozilla's code of conduct
-enforcement ladder](https://github.com/mozilla/diversity).
+Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
 
 
 [homepage]: https://www.contributor-covenant.org
 [homepage]: https://www.contributor-covenant.org
 
 
-For answers to common questions about this code of conduct, see the FAQ at
-https://www.contributor-covenant.org/faq. Translations are available at
+For answers to common questions about this code of conduct, see the FAQ at  
+https://www.contributor-covenant.org/faq. Translations are available at  
 https://www.contributor-covenant.org/translations.
 https://www.contributor-covenant.org/translations.

+ 4 - 2
README.md

@@ -21,7 +21,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
 
 
 - 🤝 **Ollama/OpenAI API Integration**: Effortlessly integrate OpenAI-compatible APIs for versatile conversations alongside Ollama models. Customize the OpenAI API URL to link with **LMStudio, GroqCloud, Mistral, OpenRouter, and more**.
 - 🤝 **Ollama/OpenAI API Integration**: Effortlessly integrate OpenAI-compatible APIs for versatile conversations alongside Ollama models. Customize the OpenAI API URL to link with **LMStudio, GroqCloud, Mistral, OpenRouter, and more**.
 
 
-- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
+- 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users.
 
 
 - 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
 - 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
 
 
@@ -37,7 +37,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
 
 
 - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
 - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
 
 
-- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch` and `SearchApi` and inject the results directly into your chat experience.
+- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience.
 
 
 - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
 - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
 
 
@@ -49,6 +49,8 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
 
 
 - 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
 - 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
 
 
+- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
+
 - 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates, fixes, and new features.
 - 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates, fixes, and new features.
 
 
 Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
 Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!

+ 0 - 641
backend/open_webui/apps/audio/main.py

@@ -1,641 +0,0 @@
-import hashlib
-import json
-import logging
-import os
-import uuid
-from functools import lru_cache
-from pathlib import Path
-from pydub import AudioSegment
-from pydub.silence import split_on_silence
-
-import requests
-from open_webui.config import (
-    AUDIO_STT_ENGINE,
-    AUDIO_STT_MODEL,
-    AUDIO_STT_OPENAI_API_BASE_URL,
-    AUDIO_STT_OPENAI_API_KEY,
-    AUDIO_TTS_API_KEY,
-    AUDIO_TTS_ENGINE,
-    AUDIO_TTS_MODEL,
-    AUDIO_TTS_OPENAI_API_BASE_URL,
-    AUDIO_TTS_OPENAI_API_KEY,
-    AUDIO_TTS_SPLIT_ON,
-    AUDIO_TTS_VOICE,
-    AUDIO_TTS_AZURE_SPEECH_REGION,
-    AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
-    CACHE_DIR,
-    CORS_ALLOW_ORIGIN,
-    WHISPER_MODEL,
-    WHISPER_MODEL_AUTO_UPDATE,
-    WHISPER_MODEL_DIR,
-    AppConfig,
-)
-
-from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS, DEVICE_TYPE
-from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse
-from pydantic import BaseModel
-from open_webui.utils.utils import get_admin_user, get_verified_user
-
-# Constants
-MAX_FILE_SIZE_MB = 25
-MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
-
-
-log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS["AUDIO"])
-
-app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None)
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=CORS_ALLOW_ORIGIN,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-app.state.config = AppConfig()
-
-app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
-app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
-app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
-app.state.config.STT_MODEL = AUDIO_STT_MODEL
-
-app.state.config.WHISPER_MODEL = WHISPER_MODEL
-app.state.faster_whisper_model = None
-
-app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
-app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
-app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
-app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
-app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
-app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
-app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
-
-app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
-app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
-
-# setting device type for whisper model
-whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
-log.info(f"whisper_device_type: {whisper_device_type}")
-
-SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
-SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
-
-
-def set_faster_whisper_model(model: str, auto_update: bool = False):
-    if model and app.state.config.STT_ENGINE == "":
-        from faster_whisper import WhisperModel
-
-        faster_whisper_kwargs = {
-            "model_size_or_path": model,
-            "device": whisper_device_type,
-            "compute_type": "int8",
-            "download_root": WHISPER_MODEL_DIR,
-            "local_files_only": not auto_update,
-        }
-
-        try:
-            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
-        except Exception:
-            log.warning(
-                "WhisperModel initialization failed, attempting download with local_files_only=False"
-            )
-            faster_whisper_kwargs["local_files_only"] = False
-            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
-
-    else:
-        app.state.faster_whisper_model = None
-
-
-class TTSConfigForm(BaseModel):
-    OPENAI_API_BASE_URL: str
-    OPENAI_API_KEY: str
-    API_KEY: str
-    ENGINE: str
-    MODEL: str
-    VOICE: str
-    SPLIT_ON: str
-    AZURE_SPEECH_REGION: str
-    AZURE_SPEECH_OUTPUT_FORMAT: str
-
-
-class STTConfigForm(BaseModel):
-    OPENAI_API_BASE_URL: str
-    OPENAI_API_KEY: str
-    ENGINE: str
-    MODEL: str
-    WHISPER_MODEL: str
-
-
-class AudioConfigUpdateForm(BaseModel):
-    tts: TTSConfigForm
-    stt: STTConfigForm
-
-
-from pydub import AudioSegment
-from pydub.utils import mediainfo
-
-
-def is_mp4_audio(file_path):
-    """Check if the given file is an MP4 audio file."""
-    if not os.path.isfile(file_path):
-        print(f"File not found: {file_path}")
-        return False
-
-    info = mediainfo(file_path)
-    if (
-        info.get("codec_name") == "aac"
-        and info.get("codec_type") == "audio"
-        and info.get("codec_tag_string") == "mp4a"
-    ):
-        return True
-    return False
-
-
-def convert_mp4_to_wav(file_path, output_path):
-    """Convert MP4 audio file to WAV format."""
-    audio = AudioSegment.from_file(file_path, format="mp4")
-    audio.export(output_path, format="wav")
-    print(f"Converted {file_path} to {output_path}")
-
-
-@app.get("/config")
-async def get_audio_config(user=Depends(get_admin_user)):
-    return {
-        "tts": {
-            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
-            "API_KEY": app.state.config.TTS_API_KEY,
-            "ENGINE": app.state.config.TTS_ENGINE,
-            "MODEL": app.state.config.TTS_MODEL,
-            "VOICE": app.state.config.TTS_VOICE,
-            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
-            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
-            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
-        },
-        "stt": {
-            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
-            "ENGINE": app.state.config.STT_ENGINE,
-            "MODEL": app.state.config.STT_MODEL,
-            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
-        },
-    }
-
-
-@app.post("/config/update")
-async def update_audio_config(
-    form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
-):
-    app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
-    app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
-    app.state.config.TTS_API_KEY = form_data.tts.API_KEY
-    app.state.config.TTS_ENGINE = form_data.tts.ENGINE
-    app.state.config.TTS_MODEL = form_data.tts.MODEL
-    app.state.config.TTS_VOICE = form_data.tts.VOICE
-    app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
-    app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
-    app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
-        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
-    )
-
-    app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
-    app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
-    app.state.config.STT_ENGINE = form_data.stt.ENGINE
-    app.state.config.STT_MODEL = form_data.stt.MODEL
-    app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
-    set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
-
-    return {
-        "tts": {
-            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
-            "API_KEY": app.state.config.TTS_API_KEY,
-            "ENGINE": app.state.config.TTS_ENGINE,
-            "MODEL": app.state.config.TTS_MODEL,
-            "VOICE": app.state.config.TTS_VOICE,
-            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
-            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
-            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
-        },
-        "stt": {
-            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
-            "ENGINE": app.state.config.STT_ENGINE,
-            "MODEL": app.state.config.STT_MODEL,
-            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
-        },
-    }
-
-
-@app.post("/speech")
-async def speech(request: Request, user=Depends(get_verified_user)):
-    body = await request.body()
-    name = hashlib.sha256(body).hexdigest()
-
-    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
-    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
-
-    # Check if the file already exists in the cache
-    if file_path.is_file():
-        return FileResponse(file_path)
-
-    if app.state.config.TTS_ENGINE == "openai":
-        headers = {}
-        headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
-        headers["Content-Type"] = "application/json"
-
-        try:
-            body = body.decode("utf-8")
-            body = json.loads(body)
-            body["model"] = app.state.config.TTS_MODEL
-            body = json.dumps(body).encode("utf-8")
-        except Exception:
-            pass
-
-        r = None
-        try:
-            r = requests.post(
-                url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
-                data=body,
-                headers=headers,
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            # Save the streaming content to a file
-            with open(file_path, "wb") as f:
-                for chunk in r.iter_content(chunk_size=8192):
-                    f.write(chunk)
-
-            with open(file_body_path, "w") as f:
-                json.dump(json.loads(body.decode("utf-8")), f)
-
-            # Return the saved file
-            return FileResponse(file_path)
-
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "error" in res:
-                        error_detail = f"External: {res['error']['message']}"
-                except Exception:
-                    error_detail = f"External: {e}"
-
-            raise HTTPException(
-                status_code=r.status_code if r != None else 500,
-                detail=error_detail,
-            )
-
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
-        payload = None
-        try:
-            payload = json.loads(body.decode("utf-8"))
-        except Exception as e:
-            log.exception(e)
-            raise HTTPException(status_code=400, detail="Invalid JSON payload")
-
-        voice_id = payload.get("voice", "")
-
-        if voice_id not in get_available_voices():
-            raise HTTPException(
-                status_code=400,
-                detail="Invalid voice id",
-            )
-
-        url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
-
-        headers = {
-            "Accept": "audio/mpeg",
-            "Content-Type": "application/json",
-            "xi-api-key": app.state.config.TTS_API_KEY,
-        }
-
-        data = {
-            "text": payload["input"],
-            "model_id": app.state.config.TTS_MODEL,
-            "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
-        }
-
-        try:
-            r = requests.post(url, json=data, headers=headers)
-
-            r.raise_for_status()
-
-            # Save the streaming content to a file
-            with open(file_path, "wb") as f:
-                for chunk in r.iter_content(chunk_size=8192):
-                    f.write(chunk)
-
-            with open(file_body_path, "w") as f:
-                json.dump(json.loads(body.decode("utf-8")), f)
-
-            # Return the saved file
-            return FileResponse(file_path)
-
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "error" in res:
-                        error_detail = f"External: {res['error']['message']}"
-                except Exception:
-                    error_detail = f"External: {e}"
-
-            raise HTTPException(
-                status_code=r.status_code if r != None else 500,
-                detail=error_detail,
-            )
-
-    elif app.state.config.TTS_ENGINE == "azure":
-        payload = None
-        try:
-            payload = json.loads(body.decode("utf-8"))
-        except Exception as e:
-            log.exception(e)
-            raise HTTPException(status_code=400, detail="Invalid JSON payload")
-
-        region = app.state.config.TTS_AZURE_SPEECH_REGION
-        language = app.state.config.TTS_VOICE
-        locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
-        output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
-        url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
-
-        headers = {
-            "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
-            "Content-Type": "application/ssml+xml",
-            "X-Microsoft-OutputFormat": output_format,
-        }
-
-        data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
-                <voice name="{language}">{payload["input"]}</voice>
-            </speak>"""
-
-        response = requests.post(url, headers=headers, data=data)
-
-        if response.status_code == 200:
-            with open(file_path, "wb") as f:
-                f.write(response.content)
-            return FileResponse(file_path)
-        else:
-            log.error(f"Error synthesizing speech - {response.reason}")
-            raise HTTPException(
-                status_code=500, detail=f"Error synthesizing speech - {response.reason}"
-            )
-
-
-def transcribe(file_path):
-    print("transcribe", file_path)
-    filename = os.path.basename(file_path)
-    file_dir = os.path.dirname(file_path)
-    id = filename.split(".")[0]
-
-    if app.state.config.STT_ENGINE == "":
-        if app.state.faster_whisper_model is None:
-            set_faster_whisper_model(app.state.config.WHISPER_MODEL)
-
-        model = app.state.faster_whisper_model
-        segments, info = model.transcribe(file_path, beam_size=5)
-        log.info(
-            "Detected language '%s' with probability %f"
-            % (info.language, info.language_probability)
-        )
-
-        transcript = "".join([segment.text for segment in list(segments)])
-        data = {"text": transcript.strip()}
-
-        # save the transcript to a json file
-        transcript_file = f"{file_dir}/{id}.json"
-        with open(transcript_file, "w") as f:
-            json.dump(data, f)
-
-        log.debug(data)
-        return data
-    elif app.state.config.STT_ENGINE == "openai":
-        if is_mp4_audio(file_path):
-            print("is_mp4_audio")
-            os.rename(file_path, file_path.replace(".wav", ".mp4"))
-            # Convert MP4 audio file to WAV format
-            convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
-
-        headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
-
-        files = {"file": (filename, open(file_path, "rb"))}
-        data = {"model": app.state.config.STT_MODEL}
-
-        log.debug(files, data)
-
-        r = None
-        try:
-            r = requests.post(
-                url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
-                headers=headers,
-                files=files,
-                data=data,
-            )
-
-            r.raise_for_status()
-
-            data = r.json()
-
-            # save the transcript to a json file
-            transcript_file = f"{file_dir}/{id}.json"
-            with open(transcript_file, "w") as f:
-                json.dump(data, f)
-
-            print(data)
-            return data
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "error" in res:
-                        error_detail = f"External: {res['error']['message']}"
-                except Exception:
-                    error_detail = f"External: {e}"
-
-            raise Exception(error_detail)
-
-
-@app.post("/transcriptions")
-def transcription(
-    file: UploadFile = File(...),
-    user=Depends(get_verified_user),
-):
-    log.info(f"file.content_type: {file.content_type}")
-
-    if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
-        )
-
-    try:
-        ext = file.filename.split(".")[-1]
-        id = uuid.uuid4()
-
-        filename = f"{id}.{ext}"
-        contents = file.file.read()
-
-        file_dir = f"{CACHE_DIR}/audio/transcriptions"
-        os.makedirs(file_dir, exist_ok=True)
-        file_path = f"{file_dir}/{filename}"
-
-        with open(file_path, "wb") as f:
-            f.write(contents)
-
-        try:
-            if os.path.getsize(file_path) > MAX_FILE_SIZE:  # file is bigger than 25MB
-                log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
-                audio = AudioSegment.from_file(file_path)
-                audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
-                compressed_path = f"{file_dir}/{id}_compressed.opus"
-                audio.export(compressed_path, format="opus", bitrate="32k")
-                log.debug(f"Compressed audio to {compressed_path}")
-                file_path = compressed_path
-
-                if (
-                    os.path.getsize(file_path) > MAX_FILE_SIZE
-                ):  # Still larger than 25MB after compression
-                    log.debug(
-                        f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
-                    )
-                    raise HTTPException(
-                        status_code=status.HTTP_400_BAD_REQUEST,
-                        detail=ERROR_MESSAGES.FILE_TOO_LARGE(
-                            size=f"{MAX_FILE_SIZE_MB}MB"
-                        ),
-                    )
-
-                data = transcribe(file_path)
-            else:
-                data = transcribe(file_path)
-
-            file_path = file_path.split("/")[-1]
-            return {**data, "filename": file_path}
-        except Exception as e:
-            log.exception(e)
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
-            )
-
-    except Exception as e:
-        log.exception(e)
-
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-def get_available_models() -> list[dict]:
-    if app.state.config.TTS_ENGINE == "openai":
-        return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
-        headers = {
-            "xi-api-key": app.state.config.TTS_API_KEY,
-            "Content-Type": "application/json",
-        }
-
-        try:
-            response = requests.get(
-                "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
-            )
-            response.raise_for_status()
-            models = response.json()
-            return [
-                {"name": model["name"], "id": model["model_id"]} for model in models
-            ]
-        except requests.RequestException as e:
-            log.error(f"Error fetching voices: {str(e)}")
-    return []
-
-
-@app.get("/models")
-async def get_models(user=Depends(get_verified_user)):
-    return {"models": get_available_models()}
-
-
-def get_available_voices() -> dict:
-    """Returns {voice_id: voice_name} dict"""
-    ret = {}
-    if app.state.config.TTS_ENGINE == "openai":
-        ret = {
-            "alloy": "alloy",
-            "echo": "echo",
-            "fable": "fable",
-            "onyx": "onyx",
-            "nova": "nova",
-            "shimmer": "shimmer",
-        }
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
-        try:
-            ret = get_elevenlabs_voices()
-        except Exception:
-            # Avoided @lru_cache with exception
-            pass
-    elif app.state.config.TTS_ENGINE == "azure":
-        try:
-            region = app.state.config.TTS_AZURE_SPEECH_REGION
-            url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
-            headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
-
-            response = requests.get(url, headers=headers)
-            response.raise_for_status()
-            voices = response.json()
-            for voice in voices:
-                ret[voice["ShortName"]] = (
-                    f"{voice['DisplayName']} ({voice['ShortName']})"
-                )
-        except requests.RequestException as e:
-            log.error(f"Error fetching voices: {str(e)}")
-
-    return ret
-
-
-@lru_cache
-def get_elevenlabs_voices() -> dict:
-    """
-    Note, set the following in your .env file to use Elevenlabs:
-    AUDIO_TTS_ENGINE=elevenlabs
-    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
-    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
-    AUDIO_TTS_MODEL=eleven_multilingual_v2
-    """
-    headers = {
-        "xi-api-key": app.state.config.TTS_API_KEY,
-        "Content-Type": "application/json",
-    }
-    try:
-        # TODO: Add retries
-        response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
-        response.raise_for_status()
-        voices_data = response.json()
-
-        voices = {}
-        for voice in voices_data.get("voices", []):
-            voices[voice["voice_id"]] = voice["name"]
-    except requests.RequestException as e:
-        # Avoid @lru_cache with exception
-        log.error(f"Error fetching voices: {str(e)}")
-        raise RuntimeError(f"Error fetching voices: {str(e)}")
-
-    return voices
-
-
-@app.get("/voices")
-async def get_voices(user=Depends(get_verified_user)):
-    return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}

+ 0 - 1123
backend/open_webui/apps/ollama/main.py

@@ -1,1123 +0,0 @@
-import asyncio
-import json
-import logging
-import os
-import random
-import re
-import time
-from typing import Optional, Union
-from urllib.parse import urlparse
-
-import aiohttp
-import requests
-from open_webui.apps.webui.models.models import Models
-from open_webui.config import (
-    CORS_ALLOW_ORIGIN,
-    ENABLE_MODEL_FILTER,
-    ENABLE_OLLAMA_API,
-    MODEL_FILTER_LIST,
-    OLLAMA_BASE_URLS,
-    UPLOAD_DIR,
-    AppConfig,
-)
-from open_webui.env import AIOHTTP_CLIENT_TIMEOUT
-
-
-from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS
-from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
-from pydantic import BaseModel, ConfigDict
-from starlette.background import BackgroundTask
-
-
-from open_webui.utils.misc import (
-    calculate_sha256,
-)
-from open_webui.utils.payload import (
-    apply_model_params_to_body_ollama,
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
-)
-from open_webui.utils.utils import get_admin_user, get_verified_user
-
-log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
-
-
-app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None)
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=CORS_ALLOW_ORIGIN,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-app.state.config = AppConfig()
-
-app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
-app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
-
-app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
-app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
-app.state.MODELS = {}
-
-
-# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
-# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
-# least connections, or least response time for better resource utilization and performance optimization.
-
-
-@app.middleware("http")
-async def check_url(request: Request, call_next):
-    if len(app.state.MODELS) == 0:
-        await get_all_models()
-    else:
-        pass
-
-    response = await call_next(request)
-    return response
-
-
-@app.head("/")
-@app.get("/")
-async def get_status():
-    return {"status": True}
-
-
-@app.get("/config")
-async def get_config(user=Depends(get_admin_user)):
-    return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
-
-
-class OllamaConfigForm(BaseModel):
-    enable_ollama_api: Optional[bool] = None
-
-
-@app.post("/config/update")
-async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
-    app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api
-    return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
-
-
-@app.get("/urls")
-async def get_ollama_api_urls(user=Depends(get_admin_user)):
-    return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
-
-
-class UrlUpdateForm(BaseModel):
-    urls: list[str]
-
-
-@app.post("/urls/update")
-async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
-    app.state.config.OLLAMA_BASE_URLS = form_data.urls
-
-    log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
-    return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
-
-
-async def fetch_url(url):
-    timeout = aiohttp.ClientTimeout(total=3)
-    try:
-        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
-            async with session.get(url) as response:
-                return await response.json()
-    except Exception as e:
-        # Handle connection error here
-        log.error(f"Connection error: {e}")
-        return None
-
-
-async def cleanup_response(
-    response: Optional[aiohttp.ClientResponse],
-    session: Optional[aiohttp.ClientSession],
-):
-    if response:
-        response.close()
-    if session:
-        await session.close()
-
-
-async def post_streaming_url(
-    url: str, payload: Union[str, bytes], stream: bool = True, content_type=None
-):
-    r = None
-    try:
-        session = aiohttp.ClientSession(
-            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
-        )
-        r = await session.post(
-            url,
-            data=payload,
-            headers={"Content-Type": "application/json"},
-        )
-        r.raise_for_status()
-
-        if stream:
-            headers = dict(r.headers)
-            if content_type:
-                headers["Content-Type"] = content_type
-            return StreamingResponse(
-                r.content,
-                status_code=r.status,
-                headers=headers,
-                background=BackgroundTask(
-                    cleanup_response, response=r, session=session
-                ),
-            )
-        else:
-            res = await r.json()
-            await cleanup_response(r, session)
-            return res
-
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = await r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status if r else 500,
-            detail=error_detail,
-        )
-
-
-def merge_models_lists(model_lists):
-    merged_models = {}
-
-    for idx, model_list in enumerate(model_lists):
-        if model_list is not None:
-            for model in model_list:
-                digest = model["digest"]
-                if digest not in merged_models:
-                    model["urls"] = [idx]
-                    merged_models[digest] = model
-                else:
-                    merged_models[digest]["urls"].append(idx)
-
-    return list(merged_models.values())
-
-
-async def get_all_models():
-    log.info("get_all_models()")
-
-    if app.state.config.ENABLE_OLLAMA_API:
-        tasks = [
-            fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS
-        ]
-        responses = await asyncio.gather(*tasks)
-
-        models = {
-            "models": merge_models_lists(
-                map(
-                    lambda response: response["models"] if response else None, responses
-                )
-            )
-        }
-
-    else:
-        models = {"models": []}
-
-    app.state.MODELS = {model["model"]: model for model in models["models"]}
-
-    return models
-
-
-@app.get("/api/tags")
-@app.get("/api/tags/{url_idx}")
-async def get_ollama_tags(
-    url_idx: Optional[int] = None, user=Depends(get_verified_user)
-):
-    if url_idx is None:
-        models = await get_all_models()
-
-        if app.state.config.ENABLE_MODEL_FILTER:
-            if user.role == "user":
-                models["models"] = list(
-                    filter(
-                        lambda model: model["name"]
-                        in app.state.config.MODEL_FILTER_LIST,
-                        models["models"],
-                    )
-                )
-                return models
-        return models
-    else:
-        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-
-        r = None
-        try:
-            r = requests.request(method="GET", url=f"{url}/api/tags")
-            r.raise_for_status()
-
-            return r.json()
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "error" in res:
-                        error_detail = f"Ollama: {res['error']}"
-                except Exception:
-                    error_detail = f"Ollama: {e}"
-
-            raise HTTPException(
-                status_code=r.status_code if r else 500,
-                detail=error_detail,
-            )
-
-
-@app.get("/api/version")
-@app.get("/api/version/{url_idx}")
-async def get_ollama_versions(url_idx: Optional[int] = None):
-    if app.state.config.ENABLE_OLLAMA_API:
-        if url_idx is None:
-            # returns lowest version
-            tasks = [
-                fetch_url(f"{url}/api/version")
-                for url in app.state.config.OLLAMA_BASE_URLS
-            ]
-            responses = await asyncio.gather(*tasks)
-            responses = list(filter(lambda x: x is not None, responses))
-
-            if len(responses) > 0:
-                lowest_version = min(
-                    responses,
-                    key=lambda x: tuple(
-                        map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
-                    ),
-                )
-
-                return {"version": lowest_version["version"]}
-            else:
-                raise HTTPException(
-                    status_code=500,
-                    detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
-                )
-        else:
-            url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-
-            r = None
-            try:
-                r = requests.request(method="GET", url=f"{url}/api/version")
-                r.raise_for_status()
-
-                return r.json()
-            except Exception as e:
-                log.exception(e)
-                error_detail = "Open WebUI: Server Connection Error"
-                if r is not None:
-                    try:
-                        res = r.json()
-                        if "error" in res:
-                            error_detail = f"Ollama: {res['error']}"
-                    except Exception:
-                        error_detail = f"Ollama: {e}"
-
-                raise HTTPException(
-                    status_code=r.status_code if r else 500,
-                    detail=error_detail,
-                )
-    else:
-        return {"version": False}
-
-
-class ModelNameForm(BaseModel):
-    name: str
-
-
-@app.post("/api/pull")
-@app.post("/api/pull/{url_idx}")
-async def pull_model(
-    form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
-):
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    # Admin should be able to pull models from any source
-    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
-
-    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
-
-
-class PushModelForm(BaseModel):
-    name: str
-    insecure: Optional[bool] = None
-    stream: Optional[bool] = None
-
-
-@app.delete("/api/push")
-@app.delete("/api/push/{url_idx}")
-async def push_model(
-    form_data: PushModelForm,
-    url_idx: Optional[int] = None,
-    user=Depends(get_admin_user),
-):
-    if url_idx is None:
-        if form_data.name in app.state.MODELS:
-            url_idx = app.state.MODELS[form_data.name]["urls"][0]
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.debug(f"url: {url}")
-
-    return await post_streaming_url(
-        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
-    )
-
-
-class CreateModelForm(BaseModel):
-    name: str
-    modelfile: Optional[str] = None
-    stream: Optional[bool] = None
-    path: Optional[str] = None
-
-
-@app.post("/api/create")
-@app.post("/api/create/{url_idx}")
-async def create_model(
-    form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
-):
-    log.debug(f"form_data: {form_data}")
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    return await post_streaming_url(
-        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
-    )
-
-
-class CopyModelForm(BaseModel):
-    source: str
-    destination: str
-
-
-@app.post("/api/copy")
-@app.post("/api/copy/{url_idx}")
-async def copy_model(
-    form_data: CopyModelForm,
-    url_idx: Optional[int] = None,
-    user=Depends(get_admin_user),
-):
-    if url_idx is None:
-        if form_data.source in app.state.MODELS:
-            url_idx = app.state.MODELS[form_data.source]["urls"][0]
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/copy",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-
-    try:
-        r.raise_for_status()
-
-        log.debug(f"r.text: {r.text}")
-
-        return True
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
-
-
-@app.delete("/api/delete")
-@app.delete("/api/delete/{url_idx}")
-async def delete_model(
-    form_data: ModelNameForm,
-    url_idx: Optional[int] = None,
-    user=Depends(get_admin_user),
-):
-    if url_idx is None:
-        if form_data.name in app.state.MODELS:
-            url_idx = app.state.MODELS[form_data.name]["urls"][0]
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    r = requests.request(
-        method="DELETE",
-        url=f"{url}/api/delete",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-
-        log.debug(f"r.text: {r.text}")
-
-        return True
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
-
-
-@app.post("/api/show")
-async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
-    if form_data.name not in app.state.MODELS:
-        raise HTTPException(
-            status_code=400,
-            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
-        )
-
-    url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/show",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-
-        return r.json()
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
-
-
-class GenerateEmbeddingsForm(BaseModel):
-    model: str
-    prompt: str
-    options: Optional[dict] = None
-    keep_alive: Optional[Union[int, str]] = None
-
-
-class GenerateEmbedForm(BaseModel):
-    model: str
-    input: list[str] | str
-    truncate: Optional[bool] = None
-    options: Optional[dict] = None
-    keep_alive: Optional[Union[int, str]] = None
-
-
-@app.post("/api/embed")
-@app.post("/api/embed/{url_idx}")
-async def generate_embeddings(
-    form_data: GenerateEmbedForm,
-    url_idx: Optional[int] = None,
-    user=Depends(get_verified_user),
-):
-    return generate_ollama_batch_embeddings(form_data, url_idx)
-
-
-@app.post("/api/embeddings")
-@app.post("/api/embeddings/{url_idx}")
-async def generate_embeddings(
-    form_data: GenerateEmbeddingsForm,
-    url_idx: Optional[int] = None,
-    user=Depends(get_verified_user),
-):
-    return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
-
-
-def generate_ollama_embeddings(
-    form_data: GenerateEmbeddingsForm,
-    url_idx: Optional[int] = None,
-):
-    log.info(f"generate_ollama_embeddings {form_data}")
-
-    if url_idx is None:
-        model = form_data.model
-
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-
-        data = r.json()
-
-        log.info(f"generate_ollama_embeddings {data}")
-
-        if "embedding" in data:
-            return data
-        else:
-            raise Exception("Something went wrong :/")
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
-
-
-def generate_ollama_batch_embeddings(
-    form_data: GenerateEmbedForm,
-    url_idx: Optional[int] = None,
-):
-    log.info(f"generate_ollama_batch_embeddings {form_data}")
-
-    if url_idx is None:
-        model = form_data.model
-
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embed",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-
-        data = r.json()
-
-        log.info(f"generate_ollama_batch_embeddings {data}")
-
-        if "embeddings" in data:
-            return data
-        else:
-            raise Exception("Something went wrong :/")
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise Exception(error_detail)
-
-
-class GenerateCompletionForm(BaseModel):
-    model: str
-    prompt: str
-    images: Optional[list[str]] = None
-    format: Optional[str] = None
-    options: Optional[dict] = None
-    system: Optional[str] = None
-    template: Optional[str] = None
-    context: Optional[str] = None
-    stream: Optional[bool] = True
-    raw: Optional[bool] = None
-    keep_alive: Optional[Union[int, str]] = None
-
-
-@app.post("/api/generate")
-@app.post("/api/generate/{url_idx}")
-async def generate_completion(
-    form_data: GenerateCompletionForm,
-    url_idx: Optional[int] = None,
-    user=Depends(get_verified_user),
-):
-    if url_idx is None:
-        model = form_data.model
-
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    return await post_streaming_url(
-        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
-    )
-
-
-class ChatMessage(BaseModel):
-    role: str
-    content: str
-    images: Optional[list[str]] = None
-
-
-class GenerateChatCompletionForm(BaseModel):
-    model: str
-    messages: list[ChatMessage]
-    format: Optional[str] = None
-    options: Optional[dict] = None
-    template: Optional[str] = None
-    stream: Optional[bool] = True
-    keep_alive: Optional[Union[int, str]] = None
-
-
-def get_ollama_url(url_idx: Optional[int], model: str):
-    if url_idx is None:
-        if model not in app.state.MODELS:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
-            )
-        url_idx = random.choice(app.state.MODELS[model]["urls"])
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    return url
-
-
-@app.post("/api/chat")
-@app.post("/api/chat/{url_idx}")
-async def generate_chat_completion(
-    form_data: GenerateChatCompletionForm,
-    url_idx: Optional[int] = None,
-    user=Depends(get_verified_user),
-    bypass_filter: Optional[bool] = False,
-):
-    payload = {**form_data.model_dump(exclude_none=True)}
-    log.debug(f"generate_chat_completion() - 1.payload = {payload}")
-    if "metadata" in payload:
-        del payload["metadata"]
-
-    model_id = form_data.model
-
-    if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
-        if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
-            raise HTTPException(
-                status_code=403,
-                detail="Model not found",
-            )
-
-    model_info = Models.get_model_by_id(model_id)
-
-    if model_info:
-        if model_info.base_model_id:
-            payload["model"] = model_info.base_model_id
-
-        params = model_info.params.model_dump()
-
-        if params:
-            if payload.get("options") is None:
-                payload["options"] = {}
-
-            payload["options"] = apply_model_params_to_body_ollama(
-                params, payload["options"]
-            )
-            payload = apply_model_system_prompt_to_body(params, payload, user)
-
-    if ":" not in payload["model"]:
-        payload["model"] = f"{payload['model']}:latest"
-
-    url = get_ollama_url(url_idx, payload["model"])
-    log.info(f"url: {url}")
-    log.debug(f"generate_chat_completion() - 2.payload = {payload}")
-
-    return await post_streaming_url(
-        f"{url}/api/chat",
-        json.dumps(payload),
-        stream=form_data.stream,
-        content_type="application/x-ndjson",
-    )
-
-
-# TODO: we should update this part once Ollama supports other types
-class OpenAIChatMessageContent(BaseModel):
-    type: str
-    model_config = ConfigDict(extra="allow")
-
-
-class OpenAIChatMessage(BaseModel):
-    role: str
-    content: Union[str, OpenAIChatMessageContent]
-
-    model_config = ConfigDict(extra="allow")
-
-
-class OpenAIChatCompletionForm(BaseModel):
-    model: str
-    messages: list[OpenAIChatMessage]
-
-    model_config = ConfigDict(extra="allow")
-
-
-@app.post("/v1/chat/completions")
-@app.post("/v1/chat/completions/{url_idx}")
-async def generate_openai_chat_completion(
-    form_data: dict,
-    url_idx: Optional[int] = None,
-    user=Depends(get_verified_user),
-):
-    completion_form = OpenAIChatCompletionForm(**form_data)
-    payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
-    if "metadata" in payload:
-        del payload["metadata"]
-
-    model_id = completion_form.model
-
-    if app.state.config.ENABLE_MODEL_FILTER:
-        if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
-            raise HTTPException(
-                status_code=403,
-                detail="Model not found",
-            )
-
-    model_info = Models.get_model_by_id(model_id)
-
-    if model_info:
-        if model_info.base_model_id:
-            payload["model"] = model_info.base_model_id
-
-        params = model_info.params.model_dump()
-
-        if params:
-            payload = apply_model_params_to_body_openai(params, payload)
-            payload = apply_model_system_prompt_to_body(params, payload, user)
-
-    if ":" not in payload["model"]:
-        payload["model"] = f"{payload['model']}:latest"
-
-    url = get_ollama_url(url_idx, payload["model"])
-    log.info(f"url: {url}")
-
-    return await post_streaming_url(
-        f"{url}/v1/chat/completions",
-        json.dumps(payload),
-        stream=payload.get("stream", False),
-    )
-
-
-@app.get("/v1/models")
-@app.get("/v1/models/{url_idx}")
-async def get_openai_models(
-    url_idx: Optional[int] = None,
-    user=Depends(get_verified_user),
-):
-    if url_idx is None:
-        models = await get_all_models()
-
-        if app.state.config.ENABLE_MODEL_FILTER:
-            if user.role == "user":
-                models["models"] = list(
-                    filter(
-                        lambda model: model["name"]
-                        in app.state.config.MODEL_FILTER_LIST,
-                        models["models"],
-                    )
-                )
-
-        return {
-            "data": [
-                {
-                    "id": model["model"],
-                    "object": "model",
-                    "created": int(time.time()),
-                    "owned_by": "openai",
-                }
-                for model in models["models"]
-            ],
-            "object": "list",
-        }
-
-    else:
-        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-        try:
-            r = requests.request(method="GET", url=f"{url}/api/tags")
-            r.raise_for_status()
-
-            models = r.json()
-
-            return {
-                "data": [
-                    {
-                        "id": model["model"],
-                        "object": "model",
-                        "created": int(time.time()),
-                        "owned_by": "openai",
-                    }
-                    for model in models["models"]
-                ],
-                "object": "list",
-            }
-
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "error" in res:
-                        error_detail = f"Ollama: {res['error']}"
-                except Exception:
-                    error_detail = f"Ollama: {e}"
-
-            raise HTTPException(
-                status_code=r.status_code if r else 500,
-                detail=error_detail,
-            )
-
-
-class UrlForm(BaseModel):
-    url: str
-
-
-class UploadBlobForm(BaseModel):
-    filename: str
-
-
-def parse_huggingface_url(hf_url):
-    try:
-        # Parse the URL
-        parsed_url = urlparse(hf_url)
-
-        # Get the path and split it into components
-        path_components = parsed_url.path.split("/")
-
-        # Extract the desired output
-        model_file = path_components[-1]
-
-        return model_file
-    except ValueError:
-        return None
-
-
-async def download_file_stream(
-    ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
-):
-    done = False
-
-    if os.path.exists(file_path):
-        current_size = os.path.getsize(file_path)
-    else:
-        current_size = 0
-
-    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
-
-    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout
-
-    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
-        async with session.get(file_url, headers=headers) as response:
-            total_size = int(response.headers.get("content-length", 0)) + current_size
-
-            with open(file_path, "ab+") as file:
-                async for data in response.content.iter_chunked(chunk_size):
-                    current_size += len(data)
-                    file.write(data)
-
-                    done = current_size == total_size
-                    progress = round((current_size / total_size) * 100, 2)
-
-                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
-
-                if done:
-                    file.seek(0)
-                    hashed = calculate_sha256(file)
-                    file.seek(0)
-
-                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
-                    response = requests.post(url, data=file)
-
-                    if response.ok:
-                        res = {
-                            "done": done,
-                            "blob": f"sha256:{hashed}",
-                            "name": file_name,
-                        }
-                        os.remove(file_path)
-
-                        yield f"data: {json.dumps(res)}\n\n"
-                    else:
-                        raise "Ollama: Could not create blob, Please try again."
-
-
-# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
-@app.post("/models/download")
-@app.post("/models/download/{url_idx}")
-async def download_model(
-    form_data: UrlForm,
-    url_idx: Optional[int] = None,
-    user=Depends(get_admin_user),
-):
-    allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
-
-    if not any(form_data.url.startswith(host) for host in allowed_hosts):
-        raise HTTPException(
-            status_code=400,
-            detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
-        )
-
-    if url_idx is None:
-        url_idx = 0
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-
-    file_name = parse_huggingface_url(form_data.url)
-
-    if file_name:
-        file_path = f"{UPLOAD_DIR}/{file_name}"
-
-        return StreamingResponse(
-            download_file_stream(url, form_data.url, file_path, file_name),
-        )
-    else:
-        return None
-
-
-@app.post("/models/upload")
-@app.post("/models/upload/{url_idx}")
-def upload_model(
-    file: UploadFile = File(...),
-    url_idx: Optional[int] = None,
-    user=Depends(get_admin_user),
-):
-    if url_idx is None:
-        url_idx = 0
-    ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-
-    file_path = f"{UPLOAD_DIR}/{file.filename}"
-
-    # Save file in chunks
-    with open(file_path, "wb+") as f:
-        for chunk in file.file:
-            f.write(chunk)
-
-    def file_process_stream():
-        nonlocal ollama_url
-        total_size = os.path.getsize(file_path)
-        chunk_size = 1024 * 1024
-        try:
-            with open(file_path, "rb") as f:
-                total = 0
-                done = False
-
-                while not done:
-                    chunk = f.read(chunk_size)
-                    if not chunk:
-                        done = True
-                        continue
-
-                    total += len(chunk)
-                    progress = round((total / total_size) * 100, 2)
-
-                    res = {
-                        "progress": progress,
-                        "total": total_size,
-                        "completed": total,
-                    }
-                    yield f"data: {json.dumps(res)}\n\n"
-
-                if done:
-                    f.seek(0)
-                    hashed = calculate_sha256(f)
-                    f.seek(0)
-
-                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
-                    response = requests.post(url, data=f)
-
-                    if response.ok:
-                        res = {
-                            "done": done,
-                            "blob": f"sha256:{hashed}",
-                            "name": file.filename,
-                        }
-                        os.remove(file_path)
-                        yield f"data: {json.dumps(res)}\n\n"
-                    else:
-                        raise Exception(
-                            "Ollama: Could not create blob, Please try again."
-                        )
-
-        except Exception as e:
-            res = {"error": str(e)}
-            yield f"data: {json.dumps(res)}\n\n"
-
-    return StreamingResponse(file_process_stream(), media_type="text/event-stream")

+ 0 - 557
backend/open_webui/apps/openai/main.py

@@ -1,557 +0,0 @@
-import asyncio
-import hashlib
-import json
-import logging
-from pathlib import Path
-from typing import Literal, Optional, overload
-
-import aiohttp
-import requests
-from open_webui.apps.webui.models.models import Models
-from open_webui.config import (
-    CACHE_DIR,
-    CORS_ALLOW_ORIGIN,
-    ENABLE_MODEL_FILTER,
-    ENABLE_OPENAI_API,
-    MODEL_FILTER_LIST,
-    OPENAI_API_BASE_URLS,
-    OPENAI_API_KEYS,
-    AppConfig,
-)
-from open_webui.env import (
-    AIOHTTP_CLIENT_TIMEOUT,
-    AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
-)
-
-from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS
-from fastapi import Depends, FastAPI, HTTPException, Request
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse, StreamingResponse
-from pydantic import BaseModel
-from starlette.background import BackgroundTask
-
-from open_webui.utils.payload import (
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
-)
-
-from open_webui.utils.utils import get_admin_user, get_verified_user
-
-log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS["OPENAI"])
-
-
-app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None)
-
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=CORS_ALLOW_ORIGIN,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-app.state.config = AppConfig()
-
-app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
-app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
-
-app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
-app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
-app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
-
-app.state.MODELS = {}
-
-
-@app.middleware("http")
-async def check_url(request: Request, call_next):
-    if len(app.state.MODELS) == 0:
-        await get_all_models()
-
-    response = await call_next(request)
-    return response
-
-
-@app.get("/config")
-async def get_config(user=Depends(get_admin_user)):
-    return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
-
-
-class OpenAIConfigForm(BaseModel):
-    enable_openai_api: Optional[bool] = None
-
-
-@app.post("/config/update")
-async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
-    app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
-    return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
-
-
-class UrlsUpdateForm(BaseModel):
-    urls: list[str]
-
-
-class KeysUpdateForm(BaseModel):
-    keys: list[str]
-
-
-@app.get("/urls")
-async def get_openai_urls(user=Depends(get_admin_user)):
-    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
-
-
-@app.post("/urls/update")
-async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
-    await get_all_models()
-    app.state.config.OPENAI_API_BASE_URLS = form_data.urls
-    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
-
-
-@app.get("/keys")
-async def get_openai_keys(user=Depends(get_admin_user)):
-    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
-
-
-@app.post("/keys/update")
-async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
-    app.state.config.OPENAI_API_KEYS = form_data.keys
-    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
-
-
-@app.post("/audio/speech")
-async def speech(request: Request, user=Depends(get_verified_user)):
-    idx = None
-    try:
-        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
-        body = await request.body()
-        name = hashlib.sha256(body).hexdigest()
-
-        SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
-        SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
-        file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
-        file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
-
-        # Check if the file already exists in the cache
-        if file_path.is_file():
-            return FileResponse(file_path)
-
-        headers = {}
-        headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
-        headers["Content-Type"] = "application/json"
-        if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
-            headers["HTTP-Referer"] = "https://openwebui.com/"
-            headers["X-Title"] = "Open WebUI"
-        r = None
-        try:
-            r = requests.post(
-                url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
-                data=body,
-                headers=headers,
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            # Save the streaming content to a file
-            with open(file_path, "wb") as f:
-                for chunk in r.iter_content(chunk_size=8192):
-                    f.write(chunk)
-
-            with open(file_body_path, "w") as f:
-                json.dump(json.loads(body.decode("utf-8")), f)
-
-            # Return the saved file
-            return FileResponse(file_path)
-
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "error" in res:
-                        error_detail = f"External: {res['error']}"
-                except Exception:
-                    error_detail = f"External: {e}"
-
-            raise HTTPException(
-                status_code=r.status_code if r else 500, detail=error_detail
-            )
-
-    except ValueError:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
-
-
-async def fetch_url(url, key):
-    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
-    try:
-        headers = {"Authorization": f"Bearer {key}"}
-        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
-            async with session.get(url, headers=headers) as response:
-                return await response.json()
-    except Exception as e:
-        # Handle connection error here
-        log.error(f"Connection error: {e}")
-        return None
-
-
-async def cleanup_response(
-    response: Optional[aiohttp.ClientResponse],
-    session: Optional[aiohttp.ClientSession],
-):
-    if response:
-        response.close()
-    if session:
-        await session.close()
-
-
-def merge_models_lists(model_lists):
-    log.debug(f"merge_models_lists {model_lists}")
-    merged_list = []
-
-    for idx, models in enumerate(model_lists):
-        if models is not None and "error" not in models:
-            merged_list.extend(
-                [
-                    {
-                        **model,
-                        "name": model.get("name", model["id"]),
-                        "owned_by": "openai",
-                        "openai": model,
-                        "urlIdx": idx,
-                    }
-                    for model in models
-                    if "api.openai.com"
-                    not in app.state.config.OPENAI_API_BASE_URLS[idx]
-                    or not any(
-                        name in model["id"]
-                        for name in [
-                            "babbage",
-                            "dall-e",
-                            "davinci",
-                            "embedding",
-                            "tts",
-                            "whisper",
-                        ]
-                    )
-                ]
-            )
-
-    return merged_list
-
-
-def is_openai_api_disabled():
-    return not app.state.config.ENABLE_OPENAI_API
-
-
-async def get_all_models_raw() -> list:
-    if is_openai_api_disabled():
-        return []
-
-    # Check if API KEYS length is same than API URLS length
-    num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
-    num_keys = len(app.state.config.OPENAI_API_KEYS)
-
-    if num_keys != num_urls:
-        # if there are more keys than urls, remove the extra keys
-        if num_keys > num_urls:
-            new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
-            app.state.config.OPENAI_API_KEYS = new_keys
-        # if there are more urls than keys, add empty keys
-        else:
-            app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
-
-    tasks = [
-        fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
-        for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
-    ]
-
-    responses = await asyncio.gather(*tasks)
-    log.debug(f"get_all_models:responses() {responses}")
-
-    return responses
-
-
-@overload
-async def get_all_models(raw: Literal[True]) -> list: ...
-
-
-@overload
-async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
-
-
-async def get_all_models(raw=False) -> dict[str, list] | list:
-    log.info("get_all_models()")
-    if is_openai_api_disabled():
-        return [] if raw else {"data": []}
-
-    responses = await get_all_models_raw()
-    if raw:
-        return responses
-
-    def extract_data(response):
-        if response and "data" in response:
-            return response["data"]
-        if isinstance(response, list):
-            return response
-        return None
-
-    models = {"data": merge_models_lists(map(extract_data, responses))}
-
-    log.debug(f"models: {models}")
-    app.state.MODELS = {model["id"]: model for model in models["data"]}
-
-    return models
-
-
-@app.get("/models")
-@app.get("/models/{url_idx}")
-async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
-    if url_idx is None:
-        models = await get_all_models()
-        if app.state.config.ENABLE_MODEL_FILTER:
-            if user.role == "user":
-                models["data"] = list(
-                    filter(
-                        lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
-                        models["data"],
-                    )
-                )
-                return models
-        return models
-    else:
-        url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
-        key = app.state.config.OPENAI_API_KEYS[url_idx]
-
-        headers = {}
-        headers["Authorization"] = f"Bearer {key}"
-        headers["Content-Type"] = "application/json"
-
-        r = None
-
-        try:
-            r = requests.request(method="GET", url=f"{url}/models", headers=headers)
-            r.raise_for_status()
-
-            response_data = r.json()
-
-            if "api.openai.com" in url:
-                # Filter the response data
-                response_data["data"] = [
-                    model
-                    for model in response_data["data"]
-                    if not any(
-                        name in model["id"]
-                        for name in [
-                            "babbage",
-                            "dall-e",
-                            "davinci",
-                            "embedding",
-                            "tts",
-                            "whisper",
-                        ]
-                    )
-                ]
-
-            return response_data
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "error" in res:
-                        error_detail = f"External: {res['error']}"
-                except Exception:
-                    error_detail = f"External: {e}"
-
-            raise HTTPException(
-                status_code=r.status_code if r else 500,
-                detail=error_detail,
-            )
-
-
-@app.post("/chat/completions")
-@app.post("/chat/completions/{url_idx}")
-async def generate_chat_completion(
-    form_data: dict,
-    url_idx: Optional[int] = None,
-    user=Depends(get_verified_user),
-):
-    idx = 0
-    payload = {**form_data}
-
-    if "metadata" in payload:
-        del payload["metadata"]
-
-    model_id = form_data.get("model")
-    model_info = Models.get_model_by_id(model_id)
-
-    if model_info:
-        if model_info.base_model_id:
-            payload["model"] = model_info.base_model_id
-
-        params = model_info.params.model_dump()
-        payload = apply_model_params_to_body_openai(params, payload)
-        payload = apply_model_system_prompt_to_body(params, payload, user)
-
-    model = app.state.MODELS[payload.get("model")]
-    idx = model["urlIdx"]
-
-    if "pipeline" in model and model.get("pipeline"):
-        payload["user"] = {
-            "name": user.name,
-            "id": user.id,
-            "email": user.email,
-            "role": user.role,
-        }
-
-    url = app.state.config.OPENAI_API_BASE_URLS[idx]
-    key = app.state.config.OPENAI_API_KEYS[idx]
-    is_o1 = payload["model"].lower().startswith("o1-")
-
-    # Change max_completion_tokens to max_tokens (Backward compatible)
-    if "api.openai.com" not in url and not is_o1:
-        if "max_completion_tokens" in payload:
-            # Remove "max_completion_tokens" from the payload
-            payload["max_tokens"] = payload["max_completion_tokens"]
-            del payload["max_completion_tokens"]
-    else:
-        if is_o1 and "max_tokens" in payload:
-            payload["max_completion_tokens"] = payload["max_tokens"]
-            del payload["max_tokens"]
-        if "max_tokens" in payload and "max_completion_tokens" in payload:
-            del payload["max_tokens"]
-
-    # Fix: O1 does not support the "system" parameter, Modify "system" to "user"
-    if is_o1 and payload["messages"][0]["role"] == "system":
-        payload["messages"][0]["role"] = "user"
-
-    # Convert the modified body back to JSON
-    payload = json.dumps(payload)
-
-    log.debug(payload)
-
-    headers = {}
-    headers["Authorization"] = f"Bearer {key}"
-    headers["Content-Type"] = "application/json"
-    if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
-        headers["HTTP-Referer"] = "https://openwebui.com/"
-        headers["X-Title"] = "Open WebUI"
-
-    r = None
-    session = None
-    streaming = False
-    response = None
-
-    try:
-        session = aiohttp.ClientSession(
-            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
-        )
-        r = await session.request(
-            method="POST",
-            url=f"{url}/chat/completions",
-            data=payload,
-            headers=headers,
-        )
-
-        # Check if response is SSE
-        if "text/event-stream" in r.headers.get("Content-Type", ""):
-            streaming = True
-            return StreamingResponse(
-                r.content,
-                status_code=r.status,
-                headers=dict(r.headers),
-                background=BackgroundTask(
-                    cleanup_response, response=r, session=session
-                ),
-            )
-        else:
-            try:
-                response = await r.json()
-            except Exception as e:
-                log.error(e)
-                response = await r.text()
-
-            r.raise_for_status()
-            return response
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if isinstance(response, dict):
-            if "error" in response:
-                error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
-        elif isinstance(response, str):
-            error_detail = response
-
-        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
-    finally:
-        if not streaming and session:
-            if r:
-                r.close()
-            await session.close()
-
-
-@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
-async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
-    idx = 0
-
-    body = await request.body()
-
-    url = app.state.config.OPENAI_API_BASE_URLS[idx]
-    key = app.state.config.OPENAI_API_KEYS[idx]
-
-    target_url = f"{url}/{path}"
-
-    headers = {}
-    headers["Authorization"] = f"Bearer {key}"
-    headers["Content-Type"] = "application/json"
-
-    r = None
-    session = None
-    streaming = False
-
-    try:
-        session = aiohttp.ClientSession(trust_env=True)
-        r = await session.request(
-            method=request.method,
-            url=target_url,
-            data=body,
-            headers=headers,
-        )
-
-        r.raise_for_status()
-
-        # Check if response is SSE
-        if "text/event-stream" in r.headers.get("Content-Type", ""):
-            streaming = True
-            return StreamingResponse(
-                r.content,
-                status_code=r.status,
-                headers=dict(r.headers),
-                background=BackgroundTask(
-                    cleanup_response, response=r, session=session
-                ),
-            )
-        else:
-            response_data = await r.json()
-            return response_data
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = await r.json()
-                print(res)
-                if "error" in res:
-                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-            except Exception:
-                error_detail = f"External: {e}"
-        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
-    finally:
-        if not streaming and session:
-            if r:
-                r.close()
-            await session.close()

+ 0 - 1332
backend/open_webui/apps/retrieval/main.py

@@ -1,1332 +0,0 @@
-# TODO: Merge this with the webui_app and make it a single app
-
-import json
-import logging
-import mimetypes
-import os
-import shutil
-
-import uuid
-from datetime import datetime
-from pathlib import Path
-from typing import Iterator, Optional, Sequence, Union
-
-from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-import tiktoken
-
-
-from open_webui.storage.provider import Storage
-from open_webui.apps.webui.models.knowledge import Knowledges
-from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
-
-# Document loaders
-from open_webui.apps.retrieval.loaders.main import Loader
-
-# Web search engines
-from open_webui.apps.retrieval.web.main import SearchResult
-from open_webui.apps.retrieval.web.utils import get_web_loader
-from open_webui.apps.retrieval.web.brave import search_brave
-from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo
-from open_webui.apps.retrieval.web.google_pse import search_google_pse
-from open_webui.apps.retrieval.web.jina_search import search_jina
-from open_webui.apps.retrieval.web.searchapi import search_searchapi
-from open_webui.apps.retrieval.web.searxng import search_searxng
-from open_webui.apps.retrieval.web.serper import search_serper
-from open_webui.apps.retrieval.web.serply import search_serply
-from open_webui.apps.retrieval.web.serpstack import search_serpstack
-from open_webui.apps.retrieval.web.tavily import search_tavily
-
-
-from open_webui.apps.retrieval.utils import (
-    get_embedding_function,
-    get_model_path,
-    query_collection,
-    query_collection_with_hybrid_search,
-    query_doc,
-    query_doc_with_hybrid_search,
-)
-
-from open_webui.apps.webui.models.files import Files
-from open_webui.config import (
-    BRAVE_SEARCH_API_KEY,
-    TIKTOKEN_ENCODING_NAME,
-    RAG_TEXT_SPLITTER,
-    CHUNK_OVERLAP,
-    CHUNK_SIZE,
-    CONTENT_EXTRACTION_ENGINE,
-    CORS_ALLOW_ORIGIN,
-    ENABLE_RAG_HYBRID_SEARCH,
-    ENABLE_RAG_LOCAL_WEB_FETCH,
-    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
-    ENABLE_RAG_WEB_SEARCH,
-    ENV,
-    GOOGLE_PSE_API_KEY,
-    GOOGLE_PSE_ENGINE_ID,
-    PDF_EXTRACT_IMAGES,
-    RAG_EMBEDDING_ENGINE,
-    RAG_EMBEDDING_MODEL,
-    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
-    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    RAG_EMBEDDING_BATCH_SIZE,
-    RAG_FILE_MAX_COUNT,
-    RAG_FILE_MAX_SIZE,
-    RAG_OPENAI_API_BASE_URL,
-    RAG_OPENAI_API_KEY,
-    RAG_RELEVANCE_THRESHOLD,
-    RAG_RERANKING_MODEL,
-    RAG_RERANKING_MODEL_AUTO_UPDATE,
-    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-    DEFAULT_RAG_TEMPLATE,
-    RAG_TEMPLATE,
-    RAG_TOP_K,
-    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-    RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-    RAG_WEB_SEARCH_ENGINE,
-    RAG_WEB_SEARCH_RESULT_COUNT,
-    SEARCHAPI_API_KEY,
-    SEARCHAPI_ENGINE,
-    SEARXNG_QUERY_URL,
-    SERPER_API_KEY,
-    SERPLY_API_KEY,
-    SERPSTACK_API_KEY,
-    SERPSTACK_HTTPS,
-    TAVILY_API_KEY,
-    TIKA_SERVER_URL,
-    UPLOAD_DIR,
-    YOUTUBE_LOADER_LANGUAGE,
-    AppConfig,
-)
-from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER
-from open_webui.utils.misc import (
-    calculate_sha256,
-    calculate_sha256_string,
-    extract_folders_after_data_docs,
-    sanitize_filename,
-)
-from open_webui.utils.utils import get_admin_user, get_verified_user
-
-from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
-from langchain_community.document_loaders import (
-    YoutubeLoader,
-)
-from langchain_core.documents import Document
-
-
-log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS["RAG"])
-
-app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None)
-
-app.state.config = AppConfig()
-
-app.state.config.TOP_K = RAG_TOP_K
-app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
-app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
-app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
-
-app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
-app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
-    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
-)
-
-app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
-app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
-
-app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
-app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
-
-app.state.config.CHUNK_SIZE = CHUNK_SIZE
-app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
-
-app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
-app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
-app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
-app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
-app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
-
-app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
-app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
-
-app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
-
-app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
-app.state.YOUTUBE_LOADER_TRANSLATION = None
-
-
-app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
-app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
-app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
-
-app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
-app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
-app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
-app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
-app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
-app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
-app.state.config.SERPER_API_KEY = SERPER_API_KEY
-app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
-app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
-app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
-app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
-app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
-app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
-
-
-def update_embedding_model(
-    embedding_model: str,
-    auto_update: bool = False,
-):
-    if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
-        from sentence_transformers import SentenceTransformer
-
-        app.state.sentence_transformer_ef = SentenceTransformer(
-            get_model_path(embedding_model, auto_update),
-            device=DEVICE_TYPE,
-            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-        )
-    else:
-        app.state.sentence_transformer_ef = None
-
-
-def update_reranking_model(
-    reranking_model: str,
-    auto_update: bool = False,
-):
-    if reranking_model:
-        if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
-            try:
-                from open_webui.apps.retrieval.models.colbert import ColBERT
-
-                app.state.sentence_transformer_rf = ColBERT(
-                    get_model_path(reranking_model, auto_update),
-                    env="docker" if DOCKER else None,
-                )
-            except Exception as e:
-                log.error(f"ColBERT: {e}")
-                app.state.sentence_transformer_rf = None
-                app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
-        else:
-            import sentence_transformers
-
-            try:
-                app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-                    get_model_path(reranking_model, auto_update),
-                    device=DEVICE_TYPE,
-                    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-                )
-            except:
-                log.error("CrossEncoder error")
-                app.state.sentence_transformer_rf = None
-                app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
-    else:
-        app.state.sentence_transformer_rf = None
-
-
-update_embedding_model(
-    app.state.config.RAG_EMBEDDING_MODEL,
-    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
-)
-
-update_reranking_model(
-    app.state.config.RAG_RERANKING_MODEL,
-    RAG_RERANKING_MODEL_AUTO_UPDATE,
-)
-
-
-app.state.EMBEDDING_FUNCTION = get_embedding_function(
-    app.state.config.RAG_EMBEDDING_ENGINE,
-    app.state.config.RAG_EMBEDDING_MODEL,
-    app.state.sentence_transformer_ef,
-    app.state.config.OPENAI_API_KEY,
-    app.state.config.OPENAI_API_BASE_URL,
-    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-)
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=CORS_ALLOW_ORIGIN,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-
-class CollectionNameForm(BaseModel):
-    collection_name: Optional[str] = None
-
-
-class ProcessUrlForm(CollectionNameForm):
-    url: str
-
-
-class SearchForm(CollectionNameForm):
-    query: str
-
-
-@app.get("/")
-async def get_status():
-    return {
-        "status": True,
-        "chunk_size": app.state.config.CHUNK_SIZE,
-        "chunk_overlap": app.state.config.CHUNK_OVERLAP,
-        "template": app.state.config.RAG_TEMPLATE,
-        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
-        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
-        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
-        "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-    }
-
-
-@app.get("/embedding")
-async def get_embedding_config(user=Depends(get_admin_user)):
-    return {
-        "status": True,
-        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
-        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
-        "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-        "openai_config": {
-            "url": app.state.config.OPENAI_API_BASE_URL,
-            "key": app.state.config.OPENAI_API_KEY,
-        },
-    }
-
-
-@app.get("/reranking")
-async def get_reraanking_config(user=Depends(get_admin_user)):
-    return {
-        "status": True,
-        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
-    }
-
-
-class OpenAIConfigForm(BaseModel):
-    url: str
-    key: str
-
-
-class EmbeddingModelUpdateForm(BaseModel):
-    openai_config: Optional[OpenAIConfigForm] = None
-    embedding_engine: str
-    embedding_model: str
-    embedding_batch_size: Optional[int] = 1
-
-
-@app.post("/embedding/update")
-async def update_embedding_config(
-    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
-):
-    log.info(
-        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
-    )
-    try:
-        app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
-        app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
-
-        if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
-            if form_data.openai_config is not None:
-                app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
-                app.state.config.OPENAI_API_KEY = form_data.openai_config.key
-            app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
-
-        update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
-
-        app.state.EMBEDDING_FUNCTION = get_embedding_function(
-            app.state.config.RAG_EMBEDDING_ENGINE,
-            app.state.config.RAG_EMBEDDING_MODEL,
-            app.state.sentence_transformer_ef,
-            app.state.config.OPENAI_API_KEY,
-            app.state.config.OPENAI_API_BASE_URL,
-            app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-        )
-
-        return {
-            "status": True,
-            "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
-            "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
-            "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-            "openai_config": {
-                "url": app.state.config.OPENAI_API_BASE_URL,
-                "key": app.state.config.OPENAI_API_KEY,
-            },
-        }
-    except Exception as e:
-        log.exception(f"Problem updating embedding model: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-class RerankingModelUpdateForm(BaseModel):
-    reranking_model: str
-
-
-@app.post("/reranking/update")
-async def update_reranking_config(
-    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
-):
-    log.info(
-        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
-    )
-    try:
-        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
-
-        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)
-
-        return {
-            "status": True,
-            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
-        }
-    except Exception as e:
-        log.exception(f"Problem updating reranking model: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-@app.get("/config")
-async def get_rag_config(user=Depends(get_admin_user)):
-    return {
-        "status": True,
-        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
-        "content_extraction": {
-            "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
-            "tika_server_url": app.state.config.TIKA_SERVER_URL,
-        },
-        "chunk": {
-            "text_splitter": app.state.config.TEXT_SPLITTER,
-            "chunk_size": app.state.config.CHUNK_SIZE,
-            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
-        },
-        "file": {
-            "max_size": app.state.config.FILE_MAX_SIZE,
-            "max_count": app.state.config.FILE_MAX_COUNT,
-        },
-        "youtube": {
-            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
-            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
-        },
-        "web": {
-            "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
-            "search": {
-                "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
-                "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
-                "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
-                "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
-                "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
-                "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
-                "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
-                "serpstack_https": app.state.config.SERPSTACK_HTTPS,
-                "serper_api_key": app.state.config.SERPER_API_KEY,
-                "serply_api_key": app.state.config.SERPLY_API_KEY,
-                "tavily_api_key": app.state.config.TAVILY_API_KEY,
-                "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
-                "seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE,
-                "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-            },
-        },
-    }
-
-
-class FileConfig(BaseModel):
-    max_size: Optional[int] = None
-    max_count: Optional[int] = None
-
-
-class ContentExtractionConfig(BaseModel):
-    engine: str = ""
-    tika_server_url: Optional[str] = None
-
-
-class ChunkParamUpdateForm(BaseModel):
-    text_splitter: Optional[str] = None
-    chunk_size: int
-    chunk_overlap: int
-
-
-class YoutubeLoaderConfig(BaseModel):
-    language: list[str]
-    translation: Optional[str] = None
-
-
-class WebSearchConfig(BaseModel):
-    enabled: bool
-    engine: Optional[str] = None
-    searxng_query_url: Optional[str] = None
-    google_pse_api_key: Optional[str] = None
-    google_pse_engine_id: Optional[str] = None
-    brave_search_api_key: Optional[str] = None
-    serpstack_api_key: Optional[str] = None
-    serpstack_https: Optional[bool] = None
-    serper_api_key: Optional[str] = None
-    serply_api_key: Optional[str] = None
-    tavily_api_key: Optional[str] = None
-    searchapi_api_key: Optional[str] = None
-    searchapi_engine: Optional[str] = None
-    result_count: Optional[int] = None
-    concurrent_requests: Optional[int] = None
-
-
-class WebConfig(BaseModel):
-    search: WebSearchConfig
-    web_loader_ssl_verification: Optional[bool] = None
-
-
-class ConfigUpdateForm(BaseModel):
-    pdf_extract_images: Optional[bool] = None
-    file: Optional[FileConfig] = None
-    content_extraction: Optional[ContentExtractionConfig] = None
-    chunk: Optional[ChunkParamUpdateForm] = None
-    youtube: Optional[YoutubeLoaderConfig] = None
-    web: Optional[WebConfig] = None
-
-
-@app.post("/config/update")
-async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
-    app.state.config.PDF_EXTRACT_IMAGES = (
-        form_data.pdf_extract_images
-        if form_data.pdf_extract_images is not None
-        else app.state.config.PDF_EXTRACT_IMAGES
-    )
-
-    if form_data.file is not None:
-        app.state.config.FILE_MAX_SIZE = form_data.file.max_size
-        app.state.config.FILE_MAX_COUNT = form_data.file.max_count
-
-    if form_data.content_extraction is not None:
-        log.info(f"Updating text settings: {form_data.content_extraction}")
-        app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine
-        app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
-
-    if form_data.chunk is not None:
-        app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
-        app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
-        app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
-
-    if form_data.youtube is not None:
-        app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
-        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
-
-    if form_data.web is not None:
-        app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
-            form_data.web.web_loader_ssl_verification
-        )
-
-        app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
-        app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
-        app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
-        app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
-        app.state.config.GOOGLE_PSE_ENGINE_ID = (
-            form_data.web.search.google_pse_engine_id
-        )
-        app.state.config.BRAVE_SEARCH_API_KEY = (
-            form_data.web.search.brave_search_api_key
-        )
-        app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
-        app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
-        app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
-        app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
-        app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
-        app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
-        app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
-        app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
-        app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
-            form_data.web.search.concurrent_requests
-        )
-
-    return {
-        "status": True,
-        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
-        "file": {
-            "max_size": app.state.config.FILE_MAX_SIZE,
-            "max_count": app.state.config.FILE_MAX_COUNT,
-        },
-        "content_extraction": {
-            "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
-            "tika_server_url": app.state.config.TIKA_SERVER_URL,
-        },
-        "chunk": {
-            "text_splitter": app.state.config.TEXT_SPLITTER,
-            "chunk_size": app.state.config.CHUNK_SIZE,
-            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
-        },
-        "youtube": {
-            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
-            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
-        },
-        "web": {
-            "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
-            "search": {
-                "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
-                "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
-                "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
-                "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
-                "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
-                "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
-                "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
-                "serpstack_https": app.state.config.SERPSTACK_HTTPS,
-                "serper_api_key": app.state.config.SERPER_API_KEY,
-                "serply_api_key": app.state.config.SERPLY_API_KEY,
-                "serachapi_api_key": app.state.config.SEARCHAPI_API_KEY,
-                "searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
-                "tavily_api_key": app.state.config.TAVILY_API_KEY,
-                "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-            },
-        },
-    }
-
-
-@app.get("/template")
-async def get_rag_template(user=Depends(get_verified_user)):
-    return {
-        "status": True,
-        "template": app.state.config.RAG_TEMPLATE,
-    }
-
-
-@app.get("/query/settings")
-async def get_query_settings(user=Depends(get_admin_user)):
-    return {
-        "status": True,
-        "template": app.state.config.RAG_TEMPLATE,
-        "k": app.state.config.TOP_K,
-        "r": app.state.config.RELEVANCE_THRESHOLD,
-        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-    }
-
-
-class QuerySettingsForm(BaseModel):
-    k: Optional[int] = None
-    r: Optional[float] = None
-    template: Optional[str] = None
-    hybrid: Optional[bool] = None
-
-
-@app.post("/query/settings/update")
-async def update_query_settings(
-    form_data: QuerySettingsForm, user=Depends(get_admin_user)
-):
-    app.state.config.RAG_TEMPLATE = form_data.template
-    app.state.config.TOP_K = form_data.k if form_data.k else 4
-    app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
-
-    app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
-        form_data.hybrid if form_data.hybrid else False
-    )
-
-    return {
-        "status": True,
-        "template": app.state.config.RAG_TEMPLATE,
-        "k": app.state.config.TOP_K,
-        "r": app.state.config.RELEVANCE_THRESHOLD,
-        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-    }
-
-
-####################################
-#
-# Document process and retrieval
-#
-####################################
-
-
-def save_docs_to_vector_db(
-    docs,
-    collection_name,
-    metadata: Optional[dict] = None,
-    overwrite: bool = False,
-    split: bool = True,
-    add: bool = False,
-) -> bool:
-    log.info(f"save_docs_to_vector_db {docs} {collection_name}")
-
-    # Check if entries with the same hash (metadata.hash) already exist
-    if metadata and "hash" in metadata:
-        result = VECTOR_DB_CLIENT.query(
-            collection_name=collection_name,
-            filter={"hash": metadata["hash"]},
-        )
-
-        if result is not None:
-            existing_doc_ids = result.ids[0]
-            if existing_doc_ids:
-                log.info(f"Document with hash {metadata['hash']} already exists")
-                raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
-
-    if split:
-        if app.state.config.TEXT_SPLITTER in ["", "character"]:
-            text_splitter = RecursiveCharacterTextSplitter(
-                chunk_size=app.state.config.CHUNK_SIZE,
-                chunk_overlap=app.state.config.CHUNK_OVERLAP,
-                add_start_index=True,
-            )
-        elif app.state.config.TEXT_SPLITTER == "token":
-            log.info(
-                f"Using token text splitter: {app.state.config.TIKTOKEN_ENCODING_NAME}"
-            )
-
-            tiktoken.get_encoding(str(app.state.config.TIKTOKEN_ENCODING_NAME))
-            text_splitter = TokenTextSplitter(
-                encoding_name=str(app.state.config.TIKTOKEN_ENCODING_NAME),
-                chunk_size=app.state.config.CHUNK_SIZE,
-                chunk_overlap=app.state.config.CHUNK_OVERLAP,
-                add_start_index=True,
-            )
-        else:
-            raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
-
-        docs = text_splitter.split_documents(docs)
-
-    if len(docs) == 0:
-        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
-
-    texts = [doc.page_content for doc in docs]
-    metadatas = [
-        {
-            **doc.metadata,
-            **(metadata if metadata else {}),
-            "embedding_config": json.dumps(
-                {
-                    "engine": app.state.config.RAG_EMBEDDING_ENGINE,
-                    "model": app.state.config.RAG_EMBEDDING_MODEL,
-                }
-            ),
-        }
-        for doc in docs
-    ]
-
-    # ChromaDB does not like datetime formats
-    # for meta-data so convert them to string.
-    for metadata in metadatas:
-        for key, value in metadata.items():
-            if isinstance(value, datetime):
-                metadata[key] = str(value)
-
-    try:
-        if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
-            log.info(f"collection {collection_name} already exists")
-
-            if overwrite:
-                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
-                log.info(f"deleting existing collection {collection_name}")
-            elif add is False:
-                log.info(
-                    f"collection {collection_name} already exists, overwrite is False and add is False"
-                )
-                return True
-
-        log.info(f"adding to collection {collection_name}")
-        embedding_function = get_embedding_function(
-            app.state.config.RAG_EMBEDDING_ENGINE,
-            app.state.config.RAG_EMBEDDING_MODEL,
-            app.state.sentence_transformer_ef,
-            app.state.config.OPENAI_API_KEY,
-            app.state.config.OPENAI_API_BASE_URL,
-            app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-        )
-
-        embeddings = embedding_function(
-            list(map(lambda x: x.replace("\n", " "), texts))
-        )
-
-        items = [
-            {
-                "id": str(uuid.uuid4()),
-                "text": text,
-                "vector": embeddings[idx],
-                "metadata": metadatas[idx],
-            }
-            for idx, text in enumerate(texts)
-        ]
-
-        VECTOR_DB_CLIENT.insert(
-            collection_name=collection_name,
-            items=items,
-        )
-
-        return True
-    except Exception as e:
-        log.exception(e)
-        return False
-
-
-class ProcessFileForm(BaseModel):
-    file_id: str
-    content: Optional[str] = None
-    collection_name: Optional[str] = None
-
-
-@app.post("/process/file")
-def process_file(
-    form_data: ProcessFileForm,
-    user=Depends(get_verified_user),
-):
-    try:
-        file = Files.get_file_by_id(form_data.file_id)
-
-        collection_name = form_data.collection_name
-
-        if collection_name is None:
-            collection_name = f"file-{file.id}"
-
-        if form_data.content:
-            # Update the content in the file
-            # Usage: /files/{file_id}/data/content/update
-
-            VECTOR_DB_CLIENT.delete(
-                collection_name=f"file-{file.id}",
-                filter={"file_id": file.id},
-            )
-
-            docs = [
-                Document(
-                    page_content=form_data.content,
-                    metadata={
-                        "name": file.meta.get("name", file.filename),
-                        "created_by": file.user_id,
-                        "file_id": file.id,
-                        **file.meta,
-                    },
-                )
-            ]
-
-            text_content = form_data.content
-        elif form_data.collection_name:
-            # Check if the file has already been processed and save the content
-            # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
-
-            result = VECTOR_DB_CLIENT.query(
-                collection_name=f"file-{file.id}", filter={"file_id": file.id}
-            )
-
-            if result is not None and len(result.ids[0]) > 0:
-                docs = [
-                    Document(
-                        page_content=result.documents[0][idx],
-                        metadata=result.metadatas[0][idx],
-                    )
-                    for idx, id in enumerate(result.ids[0])
-                ]
-            else:
-                docs = [
-                    Document(
-                        page_content=file.data.get("content", ""),
-                        metadata={
-                            "name": file.meta.get("name", file.filename),
-                            "created_by": file.user_id,
-                            "file_id": file.id,
-                            **file.meta,
-                        },
-                    )
-                ]
-
-            text_content = file.data.get("content", "")
-        else:
-            # Process the file and save the content
-            # Usage: /files/
-            file_path = file.path
-            if file_path:
-                file_path = Storage.get_file(file_path)
-                loader = Loader(
-                    engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
-                    TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
-                    PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
-                )
-                docs = loader.load(
-                    file.filename, file.meta.get("content_type"), file_path
-                )
-            else:
-                docs = [
-                    Document(
-                        page_content=file.data.get("content", ""),
-                        metadata={
-                            "name": file.filename,
-                            "created_by": file.user_id,
-                            "file_id": file.id,
-                            **file.meta,
-                        },
-                    )
-                ]
-            text_content = " ".join([doc.page_content for doc in docs])
-
-        log.debug(f"text_content: {text_content}")
-        Files.update_file_data_by_id(
-            file.id,
-            {"content": text_content},
-        )
-
-        hash = calculate_sha256_string(text_content)
-        Files.update_file_hash_by_id(file.id, hash)
-
-        try:
-            result = save_docs_to_vector_db(
-                docs=docs,
-                collection_name=collection_name,
-                metadata={
-                    "file_id": file.id,
-                    "name": file.meta.get("name", file.filename),
-                    "hash": hash,
-                },
-                add=(True if form_data.collection_name else False),
-            )
-
-            if result:
-                Files.update_file_metadata_by_id(
-                    file.id,
-                    {
-                        "collection_name": collection_name,
-                    },
-                )
-
-                return {
-                    "status": True,
-                    "collection_name": collection_name,
-                    "filename": file.meta.get("name", file.filename),
-                    "content": text_content,
-                }
-        except Exception as e:
-            raise e
-    except Exception as e:
-        log.exception(e)
-        if "No pandoc was found" in str(e):
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
-            )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=str(e),
-            )
-
-
-class ProcessTextForm(BaseModel):
-    name: str
-    content: str
-    collection_name: Optional[str] = None
-
-
-@app.post("/process/text")
-def process_text(
-    form_data: ProcessTextForm,
-    user=Depends(get_verified_user),
-):
-    collection_name = form_data.collection_name
-    if collection_name is None:
-        collection_name = calculate_sha256_string(form_data.content)
-
-    docs = [
-        Document(
-            page_content=form_data.content,
-            metadata={"name": form_data.name, "created_by": user.id},
-        )
-    ]
-    text_content = form_data.content
-    log.debug(f"text_content: {text_content}")
-
-    result = save_docs_to_vector_db(docs, collection_name)
-
-    if result:
-        return {
-            "status": True,
-            "collection_name": collection_name,
-            "content": text_content,
-        }
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=ERROR_MESSAGES.DEFAULT(),
-        )
-
-
-@app.post("/process/youtube")
-def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
-    try:
-        collection_name = form_data.collection_name
-        if not collection_name:
-            collection_name = calculate_sha256_string(form_data.url)[:63]
-
-        loader = YoutubeLoader.from_youtube_url(
-            form_data.url,
-            add_video_info=True,
-            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
-            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
-        )
-        docs = loader.load()
-        content = " ".join([doc.page_content for doc in docs])
-        log.debug(f"text_content: {content}")
-        save_docs_to_vector_db(docs, collection_name, overwrite=True)
-
-        return {
-            "status": True,
-            "collection_name": collection_name,
-            "filename": form_data.url,
-            "file": {
-                "data": {
-                    "content": content,
-                },
-                "meta": {
-                    "name": form_data.url,
-                },
-            },
-        }
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-@app.post("/process/web")
-def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
-    try:
-        collection_name = form_data.collection_name
-        if not collection_name:
-            collection_name = calculate_sha256_string(form_data.url)[:63]
-
-        loader = get_web_loader(
-            form_data.url,
-            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
-            requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-        )
-        docs = loader.load()
-        content = " ".join([doc.page_content for doc in docs])
-        log.debug(f"text_content: {content}")
-        save_docs_to_vector_db(docs, collection_name, overwrite=True)
-
-        return {
-            "status": True,
-            "collection_name": collection_name,
-            "filename": form_data.url,
-            "file": {
-                "data": {
-                    "content": content,
-                },
-                "meta": {
-                    "name": form_data.url,
-                },
-            },
-        }
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-def search_web(engine: str, query: str) -> list[SearchResult]:
-    """Search the web using a search engine and return the results as a list of SearchResult objects.
-    Will look for a search engine API key in environment variables in the following order:
-    - SEARXNG_QUERY_URL
-    - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
-    - BRAVE_SEARCH_API_KEY
-    - SERPSTACK_API_KEY
-    - SERPER_API_KEY
-    - SERPLY_API_KEY
-    - TAVILY_API_KEY
-    - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
-    Args:
-        query (str): The query to search for
-    """
-
-    # TODO: add playwright to search the web
-    if engine == "searxng":
-        if app.state.config.SEARXNG_QUERY_URL:
-            return search_searxng(
-                app.state.config.SEARXNG_QUERY_URL,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
-    elif engine == "google_pse":
-        if (
-            app.state.config.GOOGLE_PSE_API_KEY
-            and app.state.config.GOOGLE_PSE_ENGINE_ID
-        ):
-            return search_google_pse(
-                app.state.config.GOOGLE_PSE_API_KEY,
-                app.state.config.GOOGLE_PSE_ENGINE_ID,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception(
-                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
-            )
-    elif engine == "brave":
-        if app.state.config.BRAVE_SEARCH_API_KEY:
-            return search_brave(
-                app.state.config.BRAVE_SEARCH_API_KEY,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
-    elif engine == "serpstack":
-        if app.state.config.SERPSTACK_API_KEY:
-            return search_serpstack(
-                app.state.config.SERPSTACK_API_KEY,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-                https_enabled=app.state.config.SERPSTACK_HTTPS,
-            )
-        else:
-            raise Exception("No SERPSTACK_API_KEY found in environment variables")
-    elif engine == "serper":
-        if app.state.config.SERPER_API_KEY:
-            return search_serper(
-                app.state.config.SERPER_API_KEY,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception("No SERPER_API_KEY found in environment variables")
-    elif engine == "serply":
-        if app.state.config.SERPLY_API_KEY:
-            return search_serply(
-                app.state.config.SERPLY_API_KEY,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception("No SERPLY_API_KEY found in environment variables")
-    elif engine == "duckduckgo":
-        return search_duckduckgo(
-            query,
-            app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-            app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-        )
-    elif engine == "tavily":
-        if app.state.config.TAVILY_API_KEY:
-            return search_tavily(
-                app.state.config.TAVILY_API_KEY,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-            )
-        else:
-            raise Exception("No TAVILY_API_KEY found in environment variables")
-    elif engine == "searchapi":
-        if app.state.config.SEARCHAPI_API_KEY:
-            return search_searchapi(
-                app.state.config.SEARCHAPI_API_KEY,
-                app.state.config.SEARCHAPI_ENGINE,
-                query,
-                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
-                app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
-            )
-        else:
-            raise Exception("No SEARCHAPI_API_KEY found in environment variables")
-    elif engine == "jina":
-        return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
-    else:
-        raise Exception("No search engine API key found in environment variables")
-
-
-@app.post("/process/web/search")
-def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
-    try:
-        logging.info(
-            f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
-        )
-        web_results = search_web(
-            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
-        )
-    except Exception as e:
-        log.exception(e)
-
-        print(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
-        )
-
-    try:
-        collection_name = form_data.collection_name
-        if collection_name == "":
-            collection_name = calculate_sha256_string(form_data.query)[:63]
-
-        urls = [result.link for result in web_results]
-
-        loader = get_web_loader(urls)
-        docs = loader.load()
-
-        save_docs_to_vector_db(docs, collection_name, overwrite=True)
-
-        return {
-            "status": True,
-            "collection_name": collection_name,
-            "filenames": urls,
-        }
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-class QueryDocForm(BaseModel):
-    collection_name: str
-    query: str
-    k: Optional[int] = None
-    r: Optional[float] = None
-    hybrid: Optional[bool] = None
-
-
-@app.post("/query/doc")
-def query_doc_handler(
-    form_data: QueryDocForm,
-    user=Depends(get_verified_user),
-):
-    try:
-        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-            return query_doc_with_hybrid_search(
-                collection_name=form_data.collection_name,
-                query=form_data.query,
-                embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.config.TOP_K,
-                reranking_function=app.state.sentence_transformer_rf,
-                r=(
-                    form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
-                ),
-            )
-        else:
-            return query_doc(
-                collection_name=form_data.collection_name,
-                query=form_data.query,
-                embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.config.TOP_K,
-            )
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-class QueryCollectionsForm(BaseModel):
-    collection_names: list[str]
-    query: str
-    k: Optional[int] = None
-    r: Optional[float] = None
-    hybrid: Optional[bool] = None
-
-
-@app.post("/query/collection")
-def query_collection_handler(
-    form_data: QueryCollectionsForm,
-    user=Depends(get_verified_user),
-):
-    try:
-        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-            return query_collection_with_hybrid_search(
-                collection_names=form_data.collection_names,
-                query=form_data.query,
-                embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.config.TOP_K,
-                reranking_function=app.state.sentence_transformer_rf,
-                r=(
-                    form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
-                ),
-            )
-        else:
-            return query_collection(
-                collection_names=form_data.collection_names,
-                query=form_data.query,
-                embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.config.TOP_K,
-            )
-
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
-####################################
-#
-# Vector DB operations
-#
-####################################
-
-
-class DeleteForm(BaseModel):
-    collection_name: str
-    file_id: str
-
-
-@app.post("/delete")
-def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
-    try:
-        if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
-            file = Files.get_file_by_id(form_data.file_id)
-            hash = file.hash
-
-            VECTOR_DB_CLIENT.delete(
-                collection_name=form_data.collection_name,
-                metadata={"hash": hash},
-            )
-            return {"status": True}
-        else:
-            return {"status": False}
-    except Exception as e:
-        log.exception(e)
-        return {"status": False}
-
-
-@app.post("/reset/db")
-def reset_vector_db(user=Depends(get_admin_user)):
-    VECTOR_DB_CLIENT.reset()
-    Knowledges.delete_all_knowledge()
-
-
-@app.post("/reset/uploads")
-def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
-    folder = f"{UPLOAD_DIR}"
-    try:
-        # Check if the directory exists
-        if os.path.exists(folder):
-            # Iterate over all the files and directories in the specified directory
-            for filename in os.listdir(folder):
-                file_path = os.path.join(folder, filename)
-                try:
-                    if os.path.isfile(file_path) or os.path.islink(file_path):
-                        os.unlink(file_path)  # Remove the file or link
-                    elif os.path.isdir(file_path):
-                        shutil.rmtree(file_path)  # Remove the directory
-                except Exception as e:
-                    print(f"Failed to delete {file_path}. Reason: {e}")
-        else:
-            print(f"The directory {folder} does not exist")
-    except Exception as e:
-        print(f"Failed to process the directory {folder}. Reason: {e}")
-    return True
-
-
-if ENV == "dev":
-
-    @app.get("/ef")
-    async def get_embeddings():
-        return {"result": app.state.EMBEDDING_FUNCTION("hello world")}
-
-    @app.get("/ef/{text}")
-    async def get_embeddings_text(text: str):
-        return {"result": app.state.EMBEDDING_FUNCTION(text)}

+ 0 - 14
backend/open_webui/apps/retrieval/vector/connector.py

@@ -1,14 +0,0 @@
-from open_webui.config import VECTOR_DB
-
-if VECTOR_DB == "milvus":
-    from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
-
-    VECTOR_DB_CLIENT = MilvusClient()
-elif VECTOR_DB == "qdrant":
-    from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
-
-    VECTOR_DB_CLIENT = QdrantClient()
-else:
-    from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
-
-    VECTOR_DB_CLIENT = ChromaClient()

+ 0 - 458
backend/open_webui/apps/webui/main.py

@@ -1,458 +0,0 @@
-import inspect
-import json
-import logging
-import time
-from typing import AsyncGenerator, Generator, Iterator
-
-from open_webui.apps.socket.main import get_event_call, get_event_emitter
-from open_webui.apps.webui.models.functions import Functions
-from open_webui.apps.webui.models.models import Models
-from open_webui.apps.webui.routers import (
-    auths,
-    chats,
-    folders,
-    configs,
-    files,
-    functions,
-    memories,
-    models,
-    knowledge,
-    prompts,
-    evaluations,
-    tools,
-    users,
-    utils,
-)
-from open_webui.apps.webui.utils import load_function_module_by_id
-from open_webui.config import (
-    ADMIN_EMAIL,
-    CORS_ALLOW_ORIGIN,
-    DEFAULT_MODELS,
-    DEFAULT_PROMPT_SUGGESTIONS,
-    DEFAULT_USER_ROLE,
-    ENABLE_COMMUNITY_SHARING,
-    ENABLE_LOGIN_FORM,
-    ENABLE_MESSAGE_RATING,
-    ENABLE_SIGNUP,
-    ENABLE_EVALUATION_ARENA_MODELS,
-    EVALUATION_ARENA_MODELS,
-    DEFAULT_ARENA_MODEL,
-    JWT_EXPIRES_IN,
-    ENABLE_OAUTH_ROLE_MANAGEMENT,
-    OAUTH_ROLES_CLAIM,
-    OAUTH_EMAIL_CLAIM,
-    OAUTH_PICTURE_CLAIM,
-    OAUTH_USERNAME_CLAIM,
-    OAUTH_ALLOWED_ROLES,
-    OAUTH_ADMIN_ROLES,
-    SHOW_ADMIN_DETAILS,
-    USER_PERMISSIONS,
-    WEBHOOK_URL,
-    WEBUI_AUTH,
-    WEBUI_BANNERS,
-    AppConfig,
-)
-from open_webui.env import (
-    ENV,
-    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
-    WEBUI_AUTH_TRUSTED_NAME_HEADER,
-)
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
-from open_webui.utils.misc import (
-    openai_chat_chunk_message_template,
-    openai_chat_completion_message_template,
-)
-from open_webui.utils.payload import (
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
-)
-
-
-from open_webui.utils.tools import get_tools
-
-app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None)
-
-log = logging.getLogger(__name__)
-
-app.state.config = AppConfig()
-
-app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
-app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
-app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
-app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
-app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
-
-
-app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
-app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
-
-
-app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
-app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
-app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
-app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
-app.state.config.WEBHOOK_URL = WEBHOOK_URL
-app.state.config.BANNERS = WEBUI_BANNERS
-
-app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
-app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
-
-app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
-app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS
-
-app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
-app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
-app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
-
-app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
-app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
-app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
-app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
-
-app.state.MODELS = {}
-app.state.TOOLS = {}
-app.state.FUNCTIONS = {}
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=CORS_ALLOW_ORIGIN,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-
-app.include_router(configs.router, prefix="/configs", tags=["configs"])
-
-app.include_router(auths.router, prefix="/auths", tags=["auths"])
-app.include_router(users.router, prefix="/users", tags=["users"])
-
-app.include_router(chats.router, prefix="/chats", tags=["chats"])
-
-app.include_router(models.router, prefix="/models", tags=["models"])
-app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
-app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
-app.include_router(tools.router, prefix="/tools", tags=["tools"])
-app.include_router(functions.router, prefix="/functions", tags=["functions"])
-
-app.include_router(memories.router, prefix="/memories", tags=["memories"])
-app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
-
-app.include_router(folders.router, prefix="/folders", tags=["folders"])
-app.include_router(files.router, prefix="/files", tags=["files"])
-
-app.include_router(utils.router, prefix="/utils", tags=["utils"])
-
-
-@app.get("/")
-async def get_status():
-    return {
-        "status": True,
-        "auth": WEBUI_AUTH,
-        "default_models": app.state.config.DEFAULT_MODELS,
-        "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
-    }
-
-
-async def get_all_models():
-    models = []
-    pipe_models = await get_pipe_models()
-    models = models + pipe_models
-
-    if app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
-        arena_models = []
-        if len(app.state.config.EVALUATION_ARENA_MODELS) > 0:
-            arena_models = [
-                {
-                    "id": model["id"],
-                    "name": model["name"],
-                    "info": {
-                        "meta": model["meta"],
-                    },
-                    "object": "model",
-                    "created": int(time.time()),
-                    "owned_by": "arena",
-                    "arena": True,
-                }
-                for model in app.state.config.EVALUATION_ARENA_MODELS
-            ]
-        else:
-            # Add default arena model
-            arena_models = [
-                {
-                    "id": DEFAULT_ARENA_MODEL["id"],
-                    "name": DEFAULT_ARENA_MODEL["name"],
-                    "info": {
-                        "meta": DEFAULT_ARENA_MODEL["meta"],
-                    },
-                    "object": "model",
-                    "created": int(time.time()),
-                    "owned_by": "arena",
-                    "arena": True,
-                }
-            ]
-        models = models + arena_models
-    return models
-
-
-def get_function_module(pipe_id: str):
-    # Check if function is already loaded
-    if pipe_id not in app.state.FUNCTIONS:
-        function_module, _, _ = load_function_module_by_id(pipe_id)
-        app.state.FUNCTIONS[pipe_id] = function_module
-    else:
-        function_module = app.state.FUNCTIONS[pipe_id]
-
-    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-        valves = Functions.get_function_valves_by_id(pipe_id)
-        function_module.valves = function_module.Valves(**(valves if valves else {}))
-    return function_module
-
-
-async def get_pipe_models():
-    pipes = Functions.get_functions_by_type("pipe", active_only=True)
-    pipe_models = []
-
-    for pipe in pipes:
-        function_module = get_function_module(pipe.id)
-
-        # Check if function is a manifold
-        if hasattr(function_module, "pipes"):
-            sub_pipes = []
-
-            # Check if pipes is a function or a list
-
-            try:
-                if callable(function_module.pipes):
-                    sub_pipes = function_module.pipes()
-                else:
-                    sub_pipes = function_module.pipes
-            except Exception as e:
-                log.exception(e)
-                sub_pipes = []
-
-            print(sub_pipes)
-
-            for p in sub_pipes:
-                sub_pipe_id = f'{pipe.id}.{p["id"]}'
-                sub_pipe_name = p["name"]
-
-                if hasattr(function_module, "name"):
-                    sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
-
-                pipe_flag = {"type": pipe.type}
-                pipe_models.append(
-                    {
-                        "id": sub_pipe_id,
-                        "name": sub_pipe_name,
-                        "object": "model",
-                        "created": pipe.created_at,
-                        "owned_by": "openai",
-                        "pipe": pipe_flag,
-                    }
-                )
-        else:
-            pipe_flag = {"type": "pipe"}
-
-            pipe_models.append(
-                {
-                    "id": pipe.id,
-                    "name": pipe.name,
-                    "object": "model",
-                    "created": pipe.created_at,
-                    "owned_by": "openai",
-                    "pipe": pipe_flag,
-                }
-            )
-
-    return pipe_models
-
-
-async def execute_pipe(pipe, params):
-    if inspect.iscoroutinefunction(pipe):
-        return await pipe(**params)
-    else:
-        return pipe(**params)
-
-
-async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
-    if isinstance(res, str):
-        return res
-    if isinstance(res, Generator):
-        return "".join(map(str, res))
-    if isinstance(res, AsyncGenerator):
-        return "".join([str(stream) async for stream in res])
-
-
-def process_line(form_data: dict, line):
-    if isinstance(line, BaseModel):
-        line = line.model_dump_json()
-        line = f"data: {line}"
-    if isinstance(line, dict):
-        line = f"data: {json.dumps(line)}"
-
-    try:
-        line = line.decode("utf-8")
-    except Exception:
-        pass
-
-    if line.startswith("data:"):
-        return f"{line}\n\n"
-    else:
-        line = openai_chat_chunk_message_template(form_data["model"], line)
-        return f"data: {json.dumps(line)}\n\n"
-
-
-def get_pipe_id(form_data: dict) -> str:
-    pipe_id = form_data["model"]
-    if "." in pipe_id:
-        pipe_id, _ = pipe_id.split(".", 1)
-    print(pipe_id)
-    return pipe_id
-
-
-def get_function_params(function_module, form_data, user, extra_params=None):
-    if extra_params is None:
-        extra_params = {}
-
-    pipe_id = get_pipe_id(form_data)
-
-    # Get the signature of the function
-    sig = inspect.signature(function_module.pipe)
-    params = {"body": form_data} | {
-        k: v for k, v in extra_params.items() if k in sig.parameters
-    }
-
-    if "__user__" in params and hasattr(function_module, "UserValves"):
-        user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
-        try:
-            params["__user__"]["valves"] = function_module.UserValves(**user_valves)
-        except Exception as e:
-            log.exception(e)
-            params["__user__"]["valves"] = function_module.UserValves()
-
-    return params
-
-
-async def generate_function_chat_completion(form_data, user):
-    model_id = form_data.get("model")
-    model_info = Models.get_model_by_id(model_id)
-
-    metadata = form_data.pop("metadata", {})
-
-    files = metadata.get("files", [])
-    tool_ids = metadata.get("tool_ids", [])
-    # Check if tool_ids is None
-    if tool_ids is None:
-        tool_ids = []
-
-    __event_emitter__ = None
-    __event_call__ = None
-    __task__ = None
-    __task_body__ = None
-
-    if metadata:
-        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
-            __event_emitter__ = get_event_emitter(metadata)
-            __event_call__ = get_event_call(metadata)
-        __task__ = metadata.get("task", None)
-        __task_body__ = metadata.get("task_body", None)
-
-    extra_params = {
-        "__event_emitter__": __event_emitter__,
-        "__event_call__": __event_call__,
-        "__task__": __task__,
-        "__task_body__": __task_body__,
-        "__files__": files,
-        "__user__": {
-            "id": user.id,
-            "email": user.email,
-            "name": user.name,
-            "role": user.role,
-        },
-    }
-    extra_params["__tools__"] = get_tools(
-        app,
-        tool_ids,
-        user,
-        {
-            **extra_params,
-            "__model__": app.state.MODELS[form_data["model"]],
-            "__messages__": form_data["messages"],
-            "__files__": files,
-        },
-    )
-
-    if model_info:
-        if model_info.base_model_id:
-            form_data["model"] = model_info.base_model_id
-
-        params = model_info.params.model_dump()
-        form_data = apply_model_params_to_body_openai(params, form_data)
-        form_data = apply_model_system_prompt_to_body(params, form_data, user)
-
-    pipe_id = get_pipe_id(form_data)
-    function_module = get_function_module(pipe_id)
-
-    pipe = function_module.pipe
-    params = get_function_params(function_module, form_data, user, extra_params)
-
-    if form_data.get("stream", False):
-
-        async def stream_content():
-            try:
-                res = await execute_pipe(pipe, params)
-
-                # Directly return if the response is a StreamingResponse
-                if isinstance(res, StreamingResponse):
-                    async for data in res.body_iterator:
-                        yield data
-                    return
-                if isinstance(res, dict):
-                    yield f"data: {json.dumps(res)}\n\n"
-                    return
-
-            except Exception as e:
-                print(f"Error: {e}")
-                yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
-                return
-
-            if isinstance(res, str):
-                message = openai_chat_chunk_message_template(form_data["model"], res)
-                yield f"data: {json.dumps(message)}\n\n"
-
-            if isinstance(res, Iterator):
-                for line in res:
-                    yield process_line(form_data, line)
-
-            if isinstance(res, AsyncGenerator):
-                async for line in res:
-                    yield process_line(form_data, line)
-
-            if isinstance(res, str) or isinstance(res, Generator):
-                finish_message = openai_chat_chunk_message_template(
-                    form_data["model"], ""
-                )
-                finish_message["choices"][0]["finish_reason"] = "stop"
-                yield f"data: {json.dumps(finish_message)}\n\n"
-                yield "data: [DONE]"
-
-        return StreamingResponse(stream_content(), media_type="text/event-stream")
-    else:
-        try:
-            res = await execute_pipe(pipe, params)
-
-        except Exception as e:
-            print(f"Error: {e}")
-            return {"error": {"detail": str(e)}}
-
-        if isinstance(res, StreamingResponse) or isinstance(res, dict):
-            return res
-        if isinstance(res, BaseModel):
-            return res.model_dump()
-
-        message = await get_message_content(res)
-        return openai_chat_completion_message_template(form_data["model"], message)

+ 0 - 157
backend/open_webui/apps/webui/models/documents.py

@@ -1,157 +0,0 @@
-import json
-import logging
-import time
-from typing import Optional
-
-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.env import SRC_LOG_LEVELS
-from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, String, Text
-
-log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS["MODELS"])
-
-####################
-# Documents DB Schema
-####################
-
-
-class Document(Base):
-    __tablename__ = "document"
-
-    collection_name = Column(String, primary_key=True)
-    name = Column(String, unique=True)
-    title = Column(Text)
-    filename = Column(Text)
-    content = Column(Text, nullable=True)
-    user_id = Column(String)
-    timestamp = Column(BigInteger)
-
-
-class DocumentModel(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
-
-    collection_name: str
-    name: str
-    title: str
-    filename: str
-    content: Optional[str] = None
-    user_id: str
-    timestamp: int  # timestamp in epoch
-
-
-####################
-# Forms
-####################
-
-
-class DocumentResponse(BaseModel):
-    collection_name: str
-    name: str
-    title: str
-    filename: str
-    content: Optional[dict] = None
-    user_id: str
-    timestamp: int  # timestamp in epoch
-
-
-class DocumentUpdateForm(BaseModel):
-    name: str
-    title: str
-
-
-class DocumentForm(DocumentUpdateForm):
-    collection_name: str
-    filename: str
-    content: Optional[str] = None
-
-
-class DocumentsTable:
-    def insert_new_doc(
-        self, user_id: str, form_data: DocumentForm
-    ) -> Optional[DocumentModel]:
-        with get_db() as db:
-            document = DocumentModel(
-                **{
-                    **form_data.model_dump(),
-                    "user_id": user_id,
-                    "timestamp": int(time.time()),
-                }
-            )
-
-            try:
-                result = Document(**document.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return DocumentModel.model_validate(result)
-                else:
-                    return None
-            except Exception:
-                return None
-
-    def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
-        try:
-            with get_db() as db:
-                document = db.query(Document).filter_by(name=name).first()
-                return DocumentModel.model_validate(document) if document else None
-        except Exception:
-            return None
-
-    def get_docs(self) -> list[DocumentModel]:
-        with get_db() as db:
-            return [
-                DocumentModel.model_validate(doc) for doc in db.query(Document).all()
-            ]
-
-    def update_doc_by_name(
-        self, name: str, form_data: DocumentUpdateForm
-    ) -> Optional[DocumentModel]:
-        try:
-            with get_db() as db:
-                db.query(Document).filter_by(name=name).update(
-                    {
-                        "title": form_data.title,
-                        "name": form_data.name,
-                        "timestamp": int(time.time()),
-                    }
-                )
-                db.commit()
-                return self.get_doc_by_name(form_data.name)
-        except Exception as e:
-            log.exception(e)
-            return None
-
-    def update_doc_content_by_name(
-        self, name: str, updated: dict
-    ) -> Optional[DocumentModel]:
-        try:
-            doc = self.get_doc_by_name(name)
-            doc_content = json.loads(doc.content if doc.content else "{}")
-            doc_content = {**doc_content, **updated}
-
-            with get_db() as db:
-                db.query(Document).filter_by(name=name).update(
-                    {
-                        "content": json.dumps(doc_content),
-                        "timestamp": int(time.time()),
-                    }
-                )
-                db.commit()
-                return self.get_doc_by_name(name)
-        except Exception as e:
-            log.exception(e)
-            return None
-
-    def delete_doc_by_name(self, name: str) -> bool:
-        try:
-            with get_db() as db:
-                db.query(Document).filter_by(name=name).delete()
-                db.commit()
-                return True
-        except Exception:
-            return False
-
-
-Documents = DocumentsTable()

+ 0 - 155
backend/open_webui/apps/webui/routers/documents.py

@@ -1,155 +0,0 @@
-import json
-from typing import Optional
-
-from open_webui.apps.webui.models.documents import (
-    DocumentForm,
-    DocumentResponse,
-    Documents,
-    DocumentUpdateForm,
-)
-from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, status
-from pydantic import BaseModel
-from open_webui.utils.utils import get_admin_user, get_verified_user
-
-router = APIRouter()
-
-############################
-# GetDocuments
-############################
-
-
-@router.get("/", response_model=list[DocumentResponse])
-async def get_documents(user=Depends(get_verified_user)):
-    docs = [
-        DocumentResponse(
-            **{
-                **doc.model_dump(),
-                "content": json.loads(doc.content if doc.content else "{}"),
-            }
-        )
-        for doc in Documents.get_docs()
-    ]
-    return docs
-
-
-############################
-# CreateNewDoc
-############################
-
-
-@router.post("/create", response_model=Optional[DocumentResponse])
-async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
-    doc = Documents.get_doc_by_name(form_data.name)
-    if doc is None:
-        doc = Documents.insert_new_doc(user.id, form_data)
-
-        if doc:
-            return DocumentResponse(
-                **{
-                    **doc.model_dump(),
-                    "content": json.loads(doc.content if doc.content else "{}"),
-                }
-            )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.FILE_EXISTS,
-            )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.NAME_TAG_TAKEN,
-        )
-
-
-############################
-# GetDocByName
-############################
-
-
-@router.get("/doc", response_model=Optional[DocumentResponse])
-async def get_doc_by_name(name: str, user=Depends(get_verified_user)):
-    doc = Documents.get_doc_by_name(name)
-
-    if doc:
-        return DocumentResponse(
-            **{
-                **doc.model_dump(),
-                "content": json.loads(doc.content if doc.content else "{}"),
-            }
-        )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.NOT_FOUND,
-        )
-
-
-############################
-# TagDocByName
-############################
-
-
-class TagItem(BaseModel):
-    name: str
-
-
-class TagDocumentForm(BaseModel):
-    name: str
-    tags: list[dict]
-
-
-@router.post("/doc/tags", response_model=Optional[DocumentResponse])
-async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_user)):
-    doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
-
-    if doc:
-        return DocumentResponse(
-            **{
-                **doc.model_dump(),
-                "content": json.loads(doc.content if doc.content else "{}"),
-            }
-        )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.NOT_FOUND,
-        )
-
-
-############################
-# UpdateDocByName
-############################
-
-
-@router.post("/doc/update", response_model=Optional[DocumentResponse])
-async def update_doc_by_name(
-    name: str,
-    form_data: DocumentUpdateForm,
-    user=Depends(get_admin_user),
-):
-    doc = Documents.update_doc_by_name(name, form_data)
-    if doc:
-        return DocumentResponse(
-            **{
-                **doc.model_dump(),
-                "content": json.loads(doc.content if doc.content else "{}"),
-            }
-        )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.NAME_TAG_TAKEN,
-        )
-
-
-############################
-# DeleteDocByName
-############################
-
-
-@router.delete("/doc/delete", response_model=bool)
-async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
-    result = Documents.delete_doc_by_name(name)
-    return result

+ 0 - 381
backend/open_webui/apps/webui/routers/knowledge.py

@@ -1,381 +0,0 @@
-import json
-from typing import Optional, Union
-from pydantic import BaseModel
-from fastapi import APIRouter, Depends, HTTPException, status
-import logging
-
-from open_webui.apps.webui.models.knowledge import (
-    Knowledges,
-    KnowledgeUpdateForm,
-    KnowledgeForm,
-    KnowledgeResponse,
-)
-from open_webui.apps.webui.models.files import Files, FileModel
-from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.apps.retrieval.main import process_file, ProcessFileForm
-
-
-from open_webui.constants import ERROR_MESSAGES
-from open_webui.utils.utils import get_admin_user, get_verified_user
-from open_webui.env import SRC_LOG_LEVELS
-
-
-log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS["MODELS"])
-
-router = APIRouter()
-
-############################
-# GetKnowledgeItems
-############################
-
-
-@router.get(
-    "/", response_model=Optional[Union[list[KnowledgeResponse], KnowledgeResponse]]
-)
-async def get_knowledge_items(
-    id: Optional[str] = None, user=Depends(get_verified_user)
-):
-    if id:
-        knowledge = Knowledges.get_knowledge_by_id(id=id)
-
-        if knowledge:
-            return knowledge
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
-    else:
-        knowledge_bases = []
-
-        for knowledge in Knowledges.get_knowledge_items():
-
-            files = []
-            if knowledge.data:
-                files = Files.get_file_metadatas_by_ids(
-                    knowledge.data.get("file_ids", [])
-                )
-
-                # Check if all files exist
-                if len(files) != len(knowledge.data.get("file_ids", [])):
-                    missing_files = list(
-                        set(knowledge.data.get("file_ids", []))
-                        - set([file.id for file in files])
-                    )
-                    if missing_files:
-                        data = knowledge.data or {}
-                        file_ids = data.get("file_ids", [])
-
-                        for missing_file in missing_files:
-                            file_ids.remove(missing_file)
-
-                        data["file_ids"] = file_ids
-                        Knowledges.update_knowledge_by_id(
-                            id=knowledge.id, form_data=KnowledgeUpdateForm(data=data)
-                        )
-
-                        files = Files.get_file_metadatas_by_ids(file_ids)
-
-            knowledge_bases.append(
-                KnowledgeResponse(
-                    **knowledge.model_dump(),
-                    files=files,
-                )
-            )
-        return knowledge_bases
-
-
-############################
-# CreateNewKnowledge
-############################
-
-
-@router.post("/create", response_model=Optional[KnowledgeResponse])
-async def create_new_knowledge(form_data: KnowledgeForm, user=Depends(get_admin_user)):
-    knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
-
-    if knowledge:
-        return knowledge
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.FILE_EXISTS,
-        )
-
-
-############################
-# GetKnowledgeById
-############################
-
-
-class KnowledgeFilesResponse(KnowledgeResponse):
-    files: list[FileModel]
-
-
-@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
-async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
-    knowledge = Knowledges.get_knowledge_by_id(id=id)
-
-    if knowledge:
-        file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
-        files = Files.get_files_by_ids(file_ids)
-
-        return KnowledgeFilesResponse(
-            **knowledge.model_dump(),
-            files=files,
-        )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.NOT_FOUND,
-        )
-
-
-############################
-# UpdateKnowledgeById
-############################
-
-
-@router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse])
-async def update_knowledge_by_id(
-    id: str,
-    form_data: KnowledgeUpdateForm,
-    user=Depends(get_admin_user),
-):
-    knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
-
-    if knowledge:
-        file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
-        files = Files.get_files_by_ids(file_ids)
-
-        return KnowledgeFilesResponse(
-            **knowledge.model_dump(),
-            files=files,
-        )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.ID_TAKEN,
-        )
-
-
-############################
-# AddFileToKnowledge
-############################
-
-
-class KnowledgeFileIdForm(BaseModel):
-    file_id: str
-
-
-@router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse])
-def add_file_to_knowledge_by_id(
-    id: str,
-    form_data: KnowledgeFileIdForm,
-    user=Depends(get_admin_user),
-):
-    knowledge = Knowledges.get_knowledge_by_id(id=id)
-    file = Files.get_file_by_id(form_data.file_id)
-    if not file:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.NOT_FOUND,
-        )
-    if not file.data:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.FILE_NOT_PROCESSED,
-        )
-
-    # Add content to the vector database
-    try:
-        process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
-    except Exception as e:
-        log.debug(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=str(e),
-        )
-
-    if knowledge:
-        data = knowledge.data or {}
-        file_ids = data.get("file_ids", [])
-
-        if form_data.file_id not in file_ids:
-            file_ids.append(form_data.file_id)
-            data["file_ids"] = file_ids
-
-            knowledge = Knowledges.update_knowledge_by_id(
-                id=id, form_data=KnowledgeUpdateForm(data=data)
-            )
-
-            if knowledge:
-                files = Files.get_files_by_ids(file_ids)
-
-                return KnowledgeFilesResponse(
-                    **knowledge.model_dump(),
-                    files=files,
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    detail=ERROR_MESSAGES.DEFAULT("knowledge"),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT("file_id"),
-            )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.NOT_FOUND,
-        )
-
-
-@router.post("/{id}/file/update", response_model=Optional[KnowledgeFilesResponse])
-def update_file_from_knowledge_by_id(
-    id: str,
-    form_data: KnowledgeFileIdForm,
-    user=Depends(get_admin_user),
-):
-    knowledge = Knowledges.get_knowledge_by_id(id=id)
-    file = Files.get_file_by_id(form_data.file_id)
-    if not file:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.NOT_FOUND,
-        )
-
-    # Remove content from the vector database
-    VECTOR_DB_CLIENT.delete(
-        collection_name=knowledge.id, filter={"file_id": form_data.file_id}
-    )
-
-    # Add content to the vector database
-    try:
-        process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
-    except Exception as e:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=str(e),
-        )
-
-    if knowledge:
-        data = knowledge.data or {}
-        file_ids = data.get("file_ids", [])
-
-        files = Files.get_files_by_ids(file_ids)
-
-        return KnowledgeFilesResponse(
-            **knowledge.model_dump(),
-            files=files,
-        )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.NOT_FOUND,
-        )
-
-
-############################
-# RemoveFileFromKnowledge
-############################
-
-
-@router.post("/{id}/file/remove", response_model=Optional[KnowledgeFilesResponse])
-def remove_file_from_knowledge_by_id(
-    id: str,
-    form_data: KnowledgeFileIdForm,
-    user=Depends(get_admin_user),
-):
-    knowledge = Knowledges.get_knowledge_by_id(id=id)
-    file = Files.get_file_by_id(form_data.file_id)
-    if not file:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.NOT_FOUND,
-        )
-
-    # Remove content from the vector database
-    VECTOR_DB_CLIENT.delete(
-        collection_name=knowledge.id, filter={"file_id": form_data.file_id}
-    )
-
-    result = VECTOR_DB_CLIENT.query(
-        collection_name=knowledge.id,
-        filter={"file_id": form_data.file_id},
-    )
-
-    Files.delete_file_by_id(form_data.file_id)
-
-    if knowledge:
-        data = knowledge.data or {}
-        file_ids = data.get("file_ids", [])
-
-        if form_data.file_id in file_ids:
-            file_ids.remove(form_data.file_id)
-            data["file_ids"] = file_ids
-
-            knowledge = Knowledges.update_knowledge_by_id(
-                id=id, form_data=KnowledgeUpdateForm(data=data)
-            )
-
-            if knowledge:
-                files = Files.get_files_by_ids(file_ids)
-
-                return KnowledgeFilesResponse(
-                    **knowledge.model_dump(),
-                    files=files,
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    detail=ERROR_MESSAGES.DEFAULT("knowledge"),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT("file_id"),
-            )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.NOT_FOUND,
-        )
-
-
-############################
-# ResetKnowledgeById
-############################
-
-
-@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
-async def reset_knowledge_by_id(id: str, user=Depends(get_admin_user)):
-    try:
-        VECTOR_DB_CLIENT.delete_collection(collection_name=id)
-    except Exception as e:
-        log.debug(e)
-        pass
-
-    knowledge = Knowledges.update_knowledge_by_id(
-        id=id, form_data=KnowledgeUpdateForm(data={"file_ids": []})
-    )
-    return knowledge
-
-
-############################
-# DeleteKnowledgeById
-############################
-
-
-@router.delete("/{id}/delete", response_model=bool)
-async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)):
-    try:
-        VECTOR_DB_CLIENT.delete_collection(collection_name=id)
-    except Exception as e:
-        log.debug(e)
-        pass
-    result = Knowledges.delete_knowledge_by_id(id=id)
-    return result

+ 0 - 104
backend/open_webui/apps/webui/routers/models.py

@@ -1,104 +0,0 @@
-from typing import Optional
-
-from open_webui.apps.webui.models.models import (
-    ModelForm,
-    ModelModel,
-    ModelResponse,
-    Models,
-)
-from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, Request, status
-from open_webui.utils.utils import get_admin_user, get_verified_user
-
-router = APIRouter()
-
-###########################
-# getModels
-###########################
-
-
-@router.get("/", response_model=list[ModelResponse])
-async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
-    if id:
-        model = Models.get_model_by_id(id)
-        if model:
-            return [model]
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
-    else:
-        return Models.get_all_models()
-
-
-############################
-# AddNewModel
-############################
-
-
-@router.post("/add", response_model=Optional[ModelModel])
-async def add_new_model(
-    request: Request,
-    form_data: ModelForm,
-    user=Depends(get_admin_user),
-):
-    if form_data.id in request.app.state.MODELS:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
-        )
-    else:
-        model = Models.insert_new_model(form_data, user.id)
-
-        if model:
-            return model
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.DEFAULT(),
-            )
-
-
-############################
-# UpdateModelById
-############################
-
-
-@router.post("/update", response_model=Optional[ModelModel])
-async def update_model_by_id(
-    request: Request,
-    id: str,
-    form_data: ModelForm,
-    user=Depends(get_admin_user),
-):
-    model = Models.get_model_by_id(id)
-    if model:
-        model = Models.update_model_by_id(id, form_data)
-        return model
-    else:
-        if form_data.id in request.app.state.MODELS:
-            model = Models.insert_new_model(form_data, user.id)
-            if model:
-                return model
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.DEFAULT(),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.DEFAULT(),
-            )
-
-
-############################
-# DeleteModelById
-############################
-
-
-@router.delete("/delete", response_model=bool)
-async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
-    result = Models.delete_model_by_id(id)
-    return result

+ 0 - 90
backend/open_webui/apps/webui/routers/prompts.py

@@ -1,90 +0,0 @@
-from typing import Optional
-
-from open_webui.apps.webui.models.prompts import PromptForm, PromptModel, Prompts
-from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, status
-from open_webui.utils.utils import get_admin_user, get_verified_user
-
-router = APIRouter()
-
-############################
-# GetPrompts
-############################
-
-
-@router.get("/", response_model=list[PromptModel])
-async def get_prompts(user=Depends(get_verified_user)):
-    return Prompts.get_prompts()
-
-
-############################
-# CreateNewPrompt
-############################
-
-
-@router.post("/create", response_model=Optional[PromptModel])
-async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
-    prompt = Prompts.get_prompt_by_command(form_data.command)
-    if prompt is None:
-        prompt = Prompts.insert_new_prompt(user.id, form_data)
-
-        if prompt:
-            return prompt
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(),
-        )
-    raise HTTPException(
-        status_code=status.HTTP_400_BAD_REQUEST,
-        detail=ERROR_MESSAGES.COMMAND_TAKEN,
-    )
-
-
-############################
-# GetPromptByCommand
-############################
-
-
-@router.get("/command/{command}", response_model=Optional[PromptModel])
-async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
-    prompt = Prompts.get_prompt_by_command(f"/{command}")
-
-    if prompt:
-        return prompt
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.NOT_FOUND,
-        )
-
-
-############################
-# UpdatePromptByCommand
-############################
-
-
-@router.post("/command/{command}/update", response_model=Optional[PromptModel])
-async def update_prompt_by_command(
-    command: str,
-    form_data: PromptForm,
-    user=Depends(get_admin_user),
-):
-    prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
-    if prompt:
-        return prompt
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
-
-############################
-# DeletePromptByCommand
-############################
-
-
-@router.delete("/command/{command}/delete", response_model=bool)
-async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
-    result = Prompts.delete_prompt_by_command(f"/{command}")
-    return result

+ 422 - 47
backend/open_webui/config.py

@@ -10,7 +10,7 @@ from urllib.parse import urlparse
 import chromadb
 import chromadb
 import requests
 import requests
 import yaml
 import yaml
-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 from open_webui.env import (
 from open_webui.env import (
     OPEN_WEBUI_DIR,
     OPEN_WEBUI_DIR,
     DATA_DIR,
     DATA_DIR,
@@ -20,6 +20,8 @@ from open_webui.env import (
     WEBUI_FAVICON_URL,
     WEBUI_FAVICON_URL,
     WEBUI_NAME,
     WEBUI_NAME,
     log,
     log,
+    DATABASE_URL,
+    OFFLINE_MODE,
 )
 )
 from pydantic import BaseModel
 from pydantic import BaseModel
 from sqlalchemy import JSON, Column, DateTime, Integer, func
 from sqlalchemy import JSON, Column, DateTime, Integer, func
@@ -264,6 +266,13 @@ class AppConfig:
 # WEBUI_AUTH (Required for security)
 # WEBUI_AUTH (Required for security)
 ####################################
 ####################################
 
 
+ENABLE_API_KEY = PersistentConfig(
+    "ENABLE_API_KEY",
+    "auth.api_key.enable",
+    os.environ.get("ENABLE_API_KEY", "True").lower() == "true",
+)
+
+
 JWT_EXPIRES_IN = PersistentConfig(
 JWT_EXPIRES_IN = PersistentConfig(
     "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
     "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
 )
 )
@@ -406,12 +415,24 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
     os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
     os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
 )
 )
 
 
+OAUTH_GROUPS_CLAIM = PersistentConfig(
+    "OAUTH_GROUPS_CLAIM",
+    "oauth.oidc.group_claim",
+    os.environ.get("OAUTH_GROUP_CLAIM", "groups"),
+)
+
 ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
 ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
     "ENABLE_OAUTH_ROLE_MANAGEMENT",
     "ENABLE_OAUTH_ROLE_MANAGEMENT",
     "oauth.enable_role_mapping",
     "oauth.enable_role_mapping",
     os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
     os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
 )
 )
 
 
+ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig(
+    "ENABLE_OAUTH_GROUP_MANAGEMENT",
+    "oauth.enable_group_mapping",
+    os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true",
+)
+
 OAUTH_ROLES_CLAIM = PersistentConfig(
 OAUTH_ROLES_CLAIM = PersistentConfig(
     "OAUTH_ROLES_CLAIM",
     "OAUTH_ROLES_CLAIM",
     "oauth.roles_claim",
     "oauth.roles_claim",
@@ -433,6 +454,15 @@ OAUTH_ADMIN_ROLES = PersistentConfig(
     [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
     [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
 )
 )
 
 
+OAUTH_ALLOWED_DOMAINS = PersistentConfig(
+    "OAUTH_ALLOWED_DOMAINS",
+    "oauth.allowed_domains",
+    [
+        domain.strip()
+        for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",")
+    ],
+)
+
 
 
 def load_oauth_providers():
 def load_oauth_providers():
     OAUTH_PROVIDERS.clear()
     OAUTH_PROVIDERS.clear()
@@ -587,6 +617,12 @@ OLLAMA_API_BASE_URL = os.environ.get(
 )
 )
 
 
 OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
 OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
+if OLLAMA_BASE_URL:
+    # Remove trailing slash
+    OLLAMA_BASE_URL = (
+        OLLAMA_BASE_URL[:-1] if OLLAMA_BASE_URL.endswith("/") else OLLAMA_BASE_URL
+    )
+
 
 
 K8S_FLAG = os.environ.get("K8S_FLAG", "")
 K8S_FLAG = os.environ.get("K8S_FLAG", "")
 USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")
 USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")
@@ -618,6 +654,12 @@ OLLAMA_BASE_URLS = PersistentConfig(
     "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
     "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
 )
 )
 
 
+OLLAMA_API_CONFIGS = PersistentConfig(
+    "OLLAMA_API_CONFIGS",
+    "ollama.api_configs",
+    {},
+)
+
 ####################################
 ####################################
 # OPENAI_API
 # OPENAI_API
 ####################################
 ####################################
@@ -658,15 +700,20 @@ OPENAI_API_BASE_URLS = PersistentConfig(
     "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
     "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
 )
 )
 
 
-OPENAI_API_KEY = ""
+OPENAI_API_CONFIGS = PersistentConfig(
+    "OPENAI_API_CONFIGS",
+    "openai.api_configs",
+    {},
+)
 
 
+# Get the actual OpenAI API key based on the base URL
+OPENAI_API_KEY = ""
 try:
 try:
     OPENAI_API_KEY = OPENAI_API_KEYS.value[
     OPENAI_API_KEY = OPENAI_API_KEYS.value[
         OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
         OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
     ]
     ]
 except Exception:
 except Exception:
     pass
     pass
-
 OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 
 
 ####################################
 ####################################
@@ -689,6 +736,7 @@ ENABLE_LOGIN_FORM = PersistentConfig(
     os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
     os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
 )
 )
 
 
+
 DEFAULT_LOCALE = PersistentConfig(
 DEFAULT_LOCALE = PersistentConfig(
     "DEFAULT_LOCALE",
     "DEFAULT_LOCALE",
     "ui.default_locale",
     "ui.default_locale",
@@ -733,18 +781,47 @@ DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
     ],
     ],
 )
 )
 
 
+MODEL_ORDER_LIST = PersistentConfig(
+    "MODEL_ORDER_LIST",
+    "ui.model_order_list",
+    [],
+)
+
 DEFAULT_USER_ROLE = PersistentConfig(
 DEFAULT_USER_ROLE = PersistentConfig(
     "DEFAULT_USER_ROLE",
     "DEFAULT_USER_ROLE",
     "ui.default_user_role",
     "ui.default_user_role",
     os.getenv("DEFAULT_USER_ROLE", "pending"),
     os.getenv("DEFAULT_USER_ROLE", "pending"),
 )
 )
 
 
-USER_PERMISSIONS_CHAT_DELETION = (
-    os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
+USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower()
+    == "true"
 )
 )
 
 
-USER_PERMISSIONS_CHAT_EDITING = (
-    os.environ.get("USER_PERMISSIONS_CHAT_EDITING", "True").lower() == "true"
+USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower()
+    == "true"
+)
+
+USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower()
+    == "true"
+)
+
+USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = (
+    os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true"
+)
+
+USER_PERMISSIONS_CHAT_FILE_UPLOAD = (
+    os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true"
+)
+
+USER_PERMISSIONS_CHAT_DELETE = (
+    os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true"
+)
+
+USER_PERMISSIONS_CHAT_EDIT = (
+    os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true"
 )
 )
 
 
 USER_PERMISSIONS_CHAT_TEMPORARY = (
 USER_PERMISSIONS_CHAT_TEMPORARY = (
@@ -753,13 +830,20 @@ USER_PERMISSIONS_CHAT_TEMPORARY = (
 
 
 USER_PERMISSIONS = PersistentConfig(
 USER_PERMISSIONS = PersistentConfig(
     "USER_PERMISSIONS",
     "USER_PERMISSIONS",
-    "ui.user_permissions",
+    "user.permissions",
     {
     {
+        "workspace": {
+            "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
+            "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS,
+            "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS,
+            "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS,
+        },
         "chat": {
         "chat": {
-            "deletion": USER_PERMISSIONS_CHAT_DELETION,
-            "editing": USER_PERMISSIONS_CHAT_EDITING,
+            "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
+            "delete": USER_PERMISSIONS_CHAT_DELETE,
+            "edit": USER_PERMISSIONS_CHAT_EDIT,
             "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
             "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
-        }
+        },
     },
     },
 )
 )
 
 
@@ -785,18 +869,6 @@ DEFAULT_ARENA_MODEL = {
     },
     },
 }
 }
 
 
-ENABLE_MODEL_FILTER = PersistentConfig(
-    "ENABLE_MODEL_FILTER",
-    "model_filter.enable",
-    os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
-)
-MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
-MODEL_FILTER_LIST = PersistentConfig(
-    "MODEL_FILTER_LIST",
-    "model_filter.list",
-    [model.strip() for model in MODEL_FILTER_LIST.split(";")],
-)
-
 WEBHOOK_URL = PersistentConfig(
 WEBHOOK_URL = PersistentConfig(
     "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
     "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
 )
 )
@@ -910,25 +982,155 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
     os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
     os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
 )
 )
 
 
+DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
+
+Examples of titles:
+📉 Stock Market Trends
+🍪 Perfect Chocolate Chip Recipe
+Evolution of Music Streaming
+Remote Work Productivity Tips
+Artificial Intelligence in Healthcare
+🎮 Video Game Development Insights
+
+<chat_history>
+{{MESSAGES:END:2}}
+</chat_history>"""
+
+
 TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
 TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
     "TAGS_GENERATION_PROMPT_TEMPLATE",
     "TAGS_GENERATION_PROMPT_TEMPLATE",
     "task.tags.prompt_template",
     "task.tags.prompt_template",
     os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
     os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
 )
 )
 
 
-ENABLE_SEARCH_QUERY = PersistentConfig(
-    "ENABLE_SEARCH_QUERY",
-    "task.search.enable",
-    os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true",
+DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE = """### Task:
+Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.
+
+### Guidelines:
+- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
+- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
+- If content is too short (less than 3 messages) or too diverse, use only ["General"]
+- Use the chat's primary language; default to English if multilingual
+- Prioritize accuracy over specificity
+
+### Output:
+JSON format: { "tags": ["tag1", "tag2", "tag3"] }
+
+### Chat History:
+<chat_history>
+{{MESSAGES:END:6}}
+</chat_history>"""
+
+ENABLE_TAGS_GENERATION = PersistentConfig(
+    "ENABLE_TAGS_GENERATION",
+    "task.tags.enable",
+    os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
+)
+
+
+ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
+    "ENABLE_SEARCH_QUERY_GENERATION",
+    "task.query.search.enable",
+    os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true",
+)
+
+ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig(
+    "ENABLE_RETRIEVAL_QUERY_GENERATION",
+    "task.query.retrieval.enable",
+    os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true",
 )
 )
 
 
 
 
-SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
-    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
-    "task.search.prompt_template",
-    os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""),
+QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+    "QUERY_GENERATION_PROMPT_TEMPLATE",
+    "task.query.prompt_template",
+    os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""),
 )
 )
 
 
+DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task:
+Analyze the chat history to determine the necessity of generating search queries, in the given language. By default, **prioritize generating 1-3 broad and relevant search queries** unless it is absolutely certain that no additional information is required. The aim is to retrieve comprehensive, updated, and valuable information even with minimal uncertainty. If no search is unequivocally needed, return an empty list.
+
+### Guidelines:
+- Respond **EXCLUSIVELY** with a JSON object. Any form of extra commentary, explanation, or additional text is strictly prohibited.
+- When generating search queries, respond in the format: { "queries": ["query1", "query2"] }, ensuring each query is distinct, concise, and relevant to the topic.
+- If and only if it is entirely certain that no useful results can be retrieved by a search, return: { "queries": [] }.
+- Err on the side of suggesting search queries if there is **any chance** they might provide useful or updated information.
+- Be concise and focused on composing high-quality search queries, avoiding unnecessary elaboration, commentary, or assumptions.
+- Today's date is: {{CURRENT_DATE}}.
+- Always prioritize providing actionable and broad queries that maximize informational coverage.
+
+### Output:
+Strictly return in JSON format: 
+{
+  "queries": ["query1", "query2"]
+}
+
+### Chat History:
+<chat_history>
+{{MESSAGES:END:6}}
+</chat_history>
+"""
+
+ENABLE_AUTOCOMPLETE_GENERATION = PersistentConfig(
+    "ENABLE_AUTOCOMPLETE_GENERATION",
+    "task.autocomplete.enable",
+    os.environ.get("ENABLE_AUTOCOMPLETE_GENERATION", "True").lower() == "true",
+)
+
+AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = PersistentConfig(
+    "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
+    "task.autocomplete.input_max_length",
+    int(os.environ.get("AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH", "-1")),
+)
+
+AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+    "AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE",
+    "task.autocomplete.prompt_template",
+    os.environ.get("AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", ""),
+)
+
+
+DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = """### Task:
+You are an autocompletion system. Continue the text in `<text>` based on the **completion type** in `<type>` and the given language.  
+
+### **Instructions**:
+1. Analyze `<text>` for context and meaning.  
+2. Use `<type>` to guide your output:  
+   - **General**: Provide a natural, concise continuation.  
+   - **Search Query**: Complete as if generating a realistic search query.  
+3. Start as if you are directly continuing `<text>`. Do **not** repeat, paraphrase, or respond as a model. Simply complete the text.  
+4. Ensure the continuation:
+   - Flows naturally from `<text>`.  
+   - Avoids repetition, overexplaining, or unrelated ideas.  
+5. If unsure, return: `{ "text": "" }`.  
+
+### **Output Rules**:
+- Respond only in JSON format: `{ "text": "<your_completion>" }`.
+
+### **Examples**:
+#### Example 1:  
+Input:  
+<type>General</type>  
+<text>The sun was setting over the horizon, painting the sky</text>  
+Output:  
+{ "text": "with vibrant shades of orange and pink." }
+
+#### Example 2:  
+Input:  
+<type>Search Query</type>  
+<text>Top-rated restaurants in</text>  
+Output:  
+{ "text": "New York City for Italian cuisine." }  
+
+---
+### Context:
+<chat_history>
+{{MESSAGES:END:6}}
+</chat_history>
+<type>{{TYPE}}</type>  
+<text>{{PROMPT}}</text>  
+#### Output:
+"""
 
 
 TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
 TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
     "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
     "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
@@ -937,6 +1139,19 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
 )
 )
 
 
 
 
+DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""
+
+
+DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
+
+Message: ```{{prompt}}```"""
+
+DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
+
+Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
+
+Responses from models: {{responses}}"""
+
 ####################################
 ####################################
 # Vector Database
 # Vector Database
 ####################################
 ####################################
@@ -949,6 +1164,8 @@ CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
 CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
 CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
 CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
 CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
 CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
 CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
+CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
+CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get("CHROMA_CLIENT_AUTH_CREDENTIALS", "")
 # Comma-separated list of header=value pairs
 # Comma-separated list of header=value pairs
 CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
 CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
 if CHROMA_HTTP_HEADERS:
 if CHROMA_HTTP_HEADERS:
@@ -966,6 +1183,21 @@ MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
 
 
 # Qdrant
 # Qdrant
 QDRANT_URI = os.environ.get("QDRANT_URI", None)
 QDRANT_URI = os.environ.get("QDRANT_URI", None)
+QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
+
+# OpenSearch
+OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
+OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", True)
+OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
+OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
+OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
+
+# Pgvector
+PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
+if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
+    raise ValueError(
+        "Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database."
+    )
 
 
 ####################################
 ####################################
 # Information Retrieval (RAG)
 # Information Retrieval (RAG)
@@ -1045,11 +1277,12 @@ RAG_EMBEDDING_MODEL = PersistentConfig(
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
 
 
 RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
 RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
-    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
+    not OFFLINE_MODE
+    and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
 )
 )
 
 
 RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
 RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
-    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
+    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
 )
 )
 
 
 RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
 RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
@@ -1070,11 +1303,12 @@ if RAG_RERANKING_MODEL.value != "":
     log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
     log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
 
 
 RAG_RERANKING_MODEL_AUTO_UPDATE = (
 RAG_RERANKING_MODEL_AUTO_UPDATE = (
-    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
+    not OFFLINE_MODE
+    and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
 )
 )
 
 
 RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
 RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
-    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
+    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
 )
 )
 
 
 
 
@@ -1102,21 +1336,32 @@ CHUNK_OVERLAP = PersistentConfig(
     int(os.environ.get("CHUNK_OVERLAP", "100")),
     int(os.environ.get("CHUNK_OVERLAP", "100")),
 )
 )
 
 
-DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
+DEFAULT_RAG_TEMPLATE = """### Task:
+Respond to the user query using the provided context, incorporating inline citations in the format [source_id] **only when the <source_id> tag is explicitly provided** in the context.
+
+### Guidelines:
+- If you don't know the answer, clearly state that.
+- If uncertain, ask the user for clarification.
+- Respond in the same language as the user's query.
+- If the context is unreadable or of poor quality, inform the user and provide the best possible answer.
+- If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding.
+- **Only include inline citations using [source_id] when a <source_id> tag is explicitly provided in the context.**  
+- Do not cite if the <source_id> tag is not provided in the context.  
+- Do not use XML tags in your response.
+- Ensure citations are concise and directly related to the information provided.
+
+### Example of Citation:
+If the user asks about a specific topic and the information is found in "whitepaper.pdf" with a provided <source_id>, the response should include the citation like so:  
+* "According to the study, the proposed method increases efficiency by 20% [whitepaper.pdf]."
+If no <source_id> is present, the response should omit the citation.
+
+### Output:
+Provide a clear and direct response to the user's query, including inline citations in the format [source_id] only when the <source_id> tag is present in the context.
 
 
 <context>
 <context>
 {{CONTEXT}}
 {{CONTEXT}}
 </context>
 </context>
 
 
-<rules>
-- If you don't know, just say so.
-- If you are not sure, ask for clarification.
-- Answer in the same language as the user query.
-- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
-- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
-- Answer directly and without using xml tags.
-</rules>
-
 <user_query>
 <user_query>
 {{QUERY}}
 {{QUERY}}
 </user_query>
 </user_query>
@@ -1139,6 +1384,19 @@ RAG_OPENAI_API_KEY = PersistentConfig(
     os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
     os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
 )
 )
 
 
+RAG_OLLAMA_BASE_URL = PersistentConfig(
+    "RAG_OLLAMA_BASE_URL",
+    "rag.ollama.url",
+    os.getenv("RAG_OLLAMA_BASE_URL", OLLAMA_BASE_URL),
+)
+
+RAG_OLLAMA_API_KEY = PersistentConfig(
+    "RAG_OLLAMA_API_KEY",
+    "rag.ollama.key",
+    os.getenv("RAG_OLLAMA_API_KEY", ""),
+)
+
+
 ENABLE_RAG_LOCAL_WEB_FETCH = (
 ENABLE_RAG_LOCAL_WEB_FETCH = (
     os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
     os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
 )
 )
@@ -1149,6 +1407,12 @@ YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
     os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
     os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
 )
 )
 
 
+YOUTUBE_LOADER_PROXY_URL = PersistentConfig(
+    "YOUTUBE_LOADER_PROXY_URL",
+    "rag.youtube_loader_proxy_url",
+    os.getenv("YOUTUBE_LOADER_PROXY_URL", ""),
+)
+
 
 
 ENABLE_RAG_WEB_SEARCH = PersistentConfig(
 ENABLE_RAG_WEB_SEARCH = PersistentConfig(
     "ENABLE_RAG_WEB_SEARCH",
     "ENABLE_RAG_WEB_SEARCH",
@@ -1198,6 +1462,18 @@ BRAVE_SEARCH_API_KEY = PersistentConfig(
     os.getenv("BRAVE_SEARCH_API_KEY", ""),
     os.getenv("BRAVE_SEARCH_API_KEY", ""),
 )
 )
 
 
+KAGI_SEARCH_API_KEY = PersistentConfig(
+    "KAGI_SEARCH_API_KEY",
+    "rag.web.search.kagi_search_api_key",
+    os.getenv("KAGI_SEARCH_API_KEY", ""),
+)
+
+MOJEEK_SEARCH_API_KEY = PersistentConfig(
+    "MOJEEK_SEARCH_API_KEY",
+    "rag.web.search.mojeek_search_api_key",
+    os.getenv("MOJEEK_SEARCH_API_KEY", ""),
+)
+
 SERPSTACK_API_KEY = PersistentConfig(
 SERPSTACK_API_KEY = PersistentConfig(
     "SERPSTACK_API_KEY",
     "SERPSTACK_API_KEY",
     "rag.web.search.serpstack_api_key",
     "rag.web.search.serpstack_api_key",
@@ -1228,6 +1504,12 @@ TAVILY_API_KEY = PersistentConfig(
     os.getenv("TAVILY_API_KEY", ""),
     os.getenv("TAVILY_API_KEY", ""),
 )
 )
 
 
+JINA_API_KEY = PersistentConfig(
+    "JINA_API_KEY",
+    "rag.web.search.jina_api_key",
+    os.getenv("JINA_API_KEY", ""),
+)
+
 SEARCHAPI_API_KEY = PersistentConfig(
 SEARCHAPI_API_KEY = PersistentConfig(
     "SEARCHAPI_API_KEY",
     "SEARCHAPI_API_KEY",
     "rag.web.search.searchapi_api_key",
     "rag.web.search.searchapi_api_key",
@@ -1240,6 +1522,21 @@ SEARCHAPI_ENGINE = PersistentConfig(
     os.getenv("SEARCHAPI_ENGINE", ""),
     os.getenv("SEARCHAPI_ENGINE", ""),
 )
 )
 
 
+BING_SEARCH_V7_ENDPOINT = PersistentConfig(
+    "BING_SEARCH_V7_ENDPOINT",
+    "rag.web.search.bing_search_v7_endpoint",
+    os.environ.get(
+        "BING_SEARCH_V7_ENDPOINT", "https://api.bing.microsoft.com/v7.0/search"
+    ),
+)
+
+BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
+    "BING_SEARCH_V7_SUBSCRIPTION_KEY",
+    "rag.web.search.bing_search_v7_subscription_key",
+    os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
+)
+
+
 RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
 RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
     "RAG_WEB_SEARCH_RESULT_COUNT",
     "RAG_WEB_SEARCH_RESULT_COUNT",
     "rag.web.search.result_count",
     "rag.web.search.result_count",
@@ -1291,7 +1588,7 @@ AUTOMATIC1111_CFG_SCALE = PersistentConfig(
 
 
 
 
 AUTOMATIC1111_SAMPLER = PersistentConfig(
 AUTOMATIC1111_SAMPLER = PersistentConfig(
-    "AUTOMATIC1111_SAMPLERE",
+    "AUTOMATIC1111_SAMPLER",
     "image_generation.automatic1111.sampler",
     "image_generation.automatic1111.sampler",
     (
     (
         os.environ.get("AUTOMATIC1111_SAMPLER")
         os.environ.get("AUTOMATIC1111_SAMPLER")
@@ -1316,6 +1613,12 @@ COMFYUI_BASE_URL = PersistentConfig(
     os.getenv("COMFYUI_BASE_URL", ""),
     os.getenv("COMFYUI_BASE_URL", ""),
 )
 )
 
 
+COMFYUI_API_KEY = PersistentConfig(
+    "COMFYUI_API_KEY",
+    "image_generation.comfyui.api_key",
+    os.getenv("COMFYUI_API_KEY", ""),
+)
+
 COMFYUI_DEFAULT_WORKFLOW = """
 COMFYUI_DEFAULT_WORKFLOW = """
 {
 {
   "3": {
   "3": {
@@ -1477,7 +1780,8 @@ WHISPER_MODEL = PersistentConfig(
 
 
 WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
 WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
 WHISPER_MODEL_AUTO_UPDATE = (
 WHISPER_MODEL_AUTO_UPDATE = (
-    os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
+    not OFFLINE_MODE
+    and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
 )
 )
 
 
 
 
@@ -1560,3 +1864,74 @@ AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig(
         "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3"
         "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3"
     ),
     ),
 )
 )
+
+
+####################################
+# LDAP
+####################################
+
+ENABLE_LDAP = PersistentConfig(
+    "ENABLE_LDAP",
+    "ldap.enable",
+    os.environ.get("ENABLE_LDAP", "false").lower() == "true",
+)
+
+LDAP_SERVER_LABEL = PersistentConfig(
+    "LDAP_SERVER_LABEL",
+    "ldap.server.label",
+    os.environ.get("LDAP_SERVER_LABEL", "LDAP Server"),
+)
+
+LDAP_SERVER_HOST = PersistentConfig(
+    "LDAP_SERVER_HOST",
+    "ldap.server.host",
+    os.environ.get("LDAP_SERVER_HOST", "localhost"),
+)
+
+LDAP_SERVER_PORT = PersistentConfig(
+    "LDAP_SERVER_PORT",
+    "ldap.server.port",
+    int(os.environ.get("LDAP_SERVER_PORT", "389")),
+)
+
+LDAP_ATTRIBUTE_FOR_USERNAME = PersistentConfig(
+    "LDAP_ATTRIBUTE_FOR_USERNAME",
+    "ldap.server.attribute_for_username",
+    os.environ.get("LDAP_ATTRIBUTE_FOR_USERNAME", "uid"),
+)
+
+LDAP_APP_DN = PersistentConfig(
+    "LDAP_APP_DN", "ldap.server.app_dn", os.environ.get("LDAP_APP_DN", "")
+)
+
+LDAP_APP_PASSWORD = PersistentConfig(
+    "LDAP_APP_PASSWORD",
+    "ldap.server.app_password",
+    os.environ.get("LDAP_APP_PASSWORD", ""),
+)
+
+LDAP_SEARCH_BASE = PersistentConfig(
+    "LDAP_SEARCH_BASE", "ldap.server.users_dn", os.environ.get("LDAP_SEARCH_BASE", "")
+)
+
+LDAP_SEARCH_FILTERS = PersistentConfig(
+    "LDAP_SEARCH_FILTER",
+    "ldap.server.search_filter",
+    os.environ.get("LDAP_SEARCH_FILTER", ""),
+)
+
+LDAP_USE_TLS = PersistentConfig(
+    "LDAP_USE_TLS",
+    "ldap.server.use_tls",
+    os.environ.get("LDAP_USE_TLS", "True").lower() == "true",
+)
+
+LDAP_CA_CERT_FILE = PersistentConfig(
+    "LDAP_CA_CERT_FILE",
+    "ldap.server.ca_cert_file",
+    os.environ.get("LDAP_CA_CERT_FILE", ""),
+)
+
+LDAP_CIPHERS = PersistentConfig(
+    "LDAP_CIPHERS", "ldap.server.ciphers", os.environ.get("LDAP_CIPHERS", "ALL")
+)

+ 3 - 0
backend/open_webui/constants.py

@@ -62,6 +62,7 @@ class ERROR_MESSAGES(str, Enum):
     NOT_FOUND = "We could not find what you're looking for :/"
     NOT_FOUND = "We could not find what you're looking for :/"
     USER_NOT_FOUND = "We could not find what you're looking for :/"
     USER_NOT_FOUND = "We could not find what you're looking for :/"
     API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
     API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
+    API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment."
 
 
     MALICIOUS = "Unusual activities detected, please try again in a few minutes."
     MALICIOUS = "Unusual activities detected, please try again in a few minutes."
 
 
@@ -75,6 +76,7 @@ class ERROR_MESSAGES(str, Enum):
     OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
     OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
     OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
     OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
     CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
     CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
+    API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment."
 
 
     EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
     EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
 
 
@@ -111,5 +113,6 @@ class TASKS(str, Enum):
     TAGS_GENERATION = "tags_generation"
     TAGS_GENERATION = "tags_generation"
     EMOJI_GENERATION = "emoji_generation"
     EMOJI_GENERATION = "emoji_generation"
     QUERY_GENERATION = "query_generation"
     QUERY_GENERATION = "query_generation"
+    AUTOCOMPLETE_GENERATION = "autocomplete_generation"
     FUNCTION_CALLING = "function_calling"
     FUNCTION_CALLING = "function_calling"
     MOA_RESPONSE_GENERATION = "moa_response_generation"
     MOA_RESPONSE_GENERATION = "moa_response_generation"

+ 14 - 2
backend/open_webui/env.py

@@ -195,6 +195,15 @@ CHANGELOG = changelog_json
 
 
 SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
 SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
 
 
+####################################
+# ENABLE_FORWARD_USER_INFO_HEADERS
+####################################
+
+ENABLE_FORWARD_USER_INFO_HEADERS = (
+    os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
+)
+
+
 ####################################
 ####################################
 # WEBUI_BUILD_HASH
 # WEBUI_BUILD_HASH
 ####################################
 ####################################
@@ -320,6 +329,9 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
 )
 )
 WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
 WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
 
 
+BYPASS_MODEL_ACCESS_CONTROL = (
+    os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
+)
 
 
 ####################################
 ####################################
 # WEBUI_SECRET_KEY
 # WEBUI_SECRET_KEY
@@ -364,7 +376,7 @@ else:
         AIOHTTP_CLIENT_TIMEOUT = 300
         AIOHTTP_CLIENT_TIMEOUT = 300
 
 
 AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
 AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
-    "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3"
+    "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
 )
 )
 
 
 if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
 if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
@@ -375,7 +387,7 @@ else:
             AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
             AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
         )
         )
     except Exception:
     except Exception:
-        AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3
+        AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5
 
 
 ####################################
 ####################################
 # OFFLINE_MODE
 # OFFLINE_MODE

+ 316 - 0
backend/open_webui/functions.py

@@ -0,0 +1,316 @@
+import logging
+import sys
+import inspect
+import json
+
+from pydantic import BaseModel
+from typing import AsyncGenerator, Generator, Iterator
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    Form,
+    HTTPException,
+    Request,
+    UploadFile,
+    status,
+)
+from starlette.responses import Response, StreamingResponse
+
+
+from open_webui.socket.main import (
+    get_event_call,
+    get_event_emitter,
+)
+
+
+from open_webui.models.functions import Functions
+from open_webui.models.models import Models
+
+from open_webui.utils.plugin import load_function_module_by_id
+from open_webui.utils.tools import get_tools
+from open_webui.utils.access_control import has_access
+
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
+
+from open_webui.utils.misc import (
+    add_or_update_system_message,
+    get_last_user_message,
+    prepend_to_first_user_message_content,
+    openai_chat_chunk_message_template,
+    openai_chat_completion_message_template,
+)
+from open_webui.utils.payload import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
+
+
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+def get_function_module_by_id(request: Request, pipe_id: str):
+    # Check if function is already loaded
+    if pipe_id not in request.app.state.FUNCTIONS:
+        function_module, _, _ = load_function_module_by_id(pipe_id)
+        request.app.state.FUNCTIONS[pipe_id] = function_module
+    else:
+        function_module = request.app.state.FUNCTIONS[pipe_id]
+
+    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+        valves = Functions.get_function_valves_by_id(pipe_id)
+        function_module.valves = function_module.Valves(**(valves if valves else {}))
+    return function_module
+
+
+async def get_function_models(request):
+    pipes = Functions.get_functions_by_type("pipe", active_only=True)
+    pipe_models = []
+
+    for pipe in pipes:
+        function_module = get_function_module_by_id(request, pipe.id)
+
+        # Check if function is a manifold
+        if hasattr(function_module, "pipes"):
+            sub_pipes = []
+
+            # Check if pipes is a function or a list
+
+            try:
+                if callable(function_module.pipes):
+                    sub_pipes = function_module.pipes()
+                else:
+                    sub_pipes = function_module.pipes
+            except Exception as e:
+                log.exception(e)
+                sub_pipes = []
+
+            log.debug(
+                f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
+            )
+
+            for p in sub_pipes:
+                sub_pipe_id = f'{pipe.id}.{p["id"]}'
+                sub_pipe_name = p["name"]
+
+                if hasattr(function_module, "name"):
+                    sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
+
+                pipe_flag = {"type": pipe.type}
+
+                pipe_models.append(
+                    {
+                        "id": sub_pipe_id,
+                        "name": sub_pipe_name,
+                        "object": "model",
+                        "created": pipe.created_at,
+                        "owned_by": "openai",
+                        "pipe": pipe_flag,
+                    }
+                )
+        else:
+            pipe_flag = {"type": "pipe"}
+
+            log.debug(
+                f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
+            )
+
+            pipe_models.append(
+                {
+                    "id": pipe.id,
+                    "name": pipe.name,
+                    "object": "model",
+                    "created": pipe.created_at,
+                    "owned_by": "openai",
+                    "pipe": pipe_flag,
+                }
+            )
+
+    return pipe_models
+
+
+async def generate_function_chat_completion(
+    request, form_data, user, models: dict = {}
+):
+    async def execute_pipe(pipe, params):
+        if inspect.iscoroutinefunction(pipe):
+            return await pipe(**params)
+        else:
+            return pipe(**params)
+
+    async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
+        if isinstance(res, str):
+            return res
+        if isinstance(res, Generator):
+            return "".join(map(str, res))
+        if isinstance(res, AsyncGenerator):
+            return "".join([str(stream) async for stream in res])
+
+    def process_line(form_data: dict, line):
+        if isinstance(line, BaseModel):
+            line = line.model_dump_json()
+            line = f"data: {line}"
+        if isinstance(line, dict):
+            line = f"data: {json.dumps(line)}"
+
+        try:
+            line = line.decode("utf-8")
+        except Exception:
+            pass
+
+        if line.startswith("data:"):
+            return f"{line}\n\n"
+        else:
+            line = openai_chat_chunk_message_template(form_data["model"], line)
+            return f"data: {json.dumps(line)}\n\n"
+
+    def get_pipe_id(form_data: dict) -> str:
+        pipe_id = form_data["model"]
+        if "." in pipe_id:
+            pipe_id, _ = pipe_id.split(".", 1)
+        return pipe_id
+
+    def get_function_params(function_module, form_data, user, extra_params=None):
+        if extra_params is None:
+            extra_params = {}
+
+        pipe_id = get_pipe_id(form_data)
+
+        # Get the signature of the function
+        sig = inspect.signature(function_module.pipe)
+        params = {"body": form_data} | {
+            k: v for k, v in extra_params.items() if k in sig.parameters
+        }
+
+        if "__user__" in params and hasattr(function_module, "UserValves"):
+            user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
+            try:
+                params["__user__"]["valves"] = function_module.UserValves(**user_valves)
+            except Exception as e:
+                log.exception(e)
+                params["__user__"]["valves"] = function_module.UserValves()
+
+        return params
+
+    model_id = form_data.get("model")
+    model_info = Models.get_model_by_id(model_id)
+
+    metadata = form_data.pop("metadata", {})
+
+    files = metadata.get("files", [])
+    tool_ids = metadata.get("tool_ids", [])
+    # Check if tool_ids is None
+    if tool_ids is None:
+        tool_ids = []
+
+    __event_emitter__ = None
+    __event_call__ = None
+    __task__ = None
+    __task_body__ = None
+
+    if metadata:
+        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
+            __event_emitter__ = get_event_emitter(metadata)
+            __event_call__ = get_event_call(metadata)
+        __task__ = metadata.get("task", None)
+        __task_body__ = metadata.get("task_body", None)
+
+    extra_params = {
+        "__event_emitter__": __event_emitter__,
+        "__event_call__": __event_call__,
+        "__task__": __task__,
+        "__task_body__": __task_body__,
+        "__files__": files,
+        "__user__": {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        },
+        "__metadata__": metadata,
+        "__request__": request,
+    }
+    extra_params["__tools__"] = get_tools(
+        request,
+        tool_ids,
+        user,
+        {
+            **extra_params,
+            "__model__": models.get(form_data["model"], None),
+            "__messages__": form_data["messages"],
+            "__files__": files,
+        },
+    )
+
+    if model_info:
+        if model_info.base_model_id:
+            form_data["model"] = model_info.base_model_id
+
+        params = model_info.params.model_dump()
+        form_data = apply_model_params_to_body_openai(params, form_data)
+        form_data = apply_model_system_prompt_to_body(params, form_data, user)
+
+    pipe_id = get_pipe_id(form_data)
+    function_module = get_function_module_by_id(request, pipe_id)
+
+    pipe = function_module.pipe
+    params = get_function_params(function_module, form_data, user, extra_params)
+
+    if form_data.get("stream", False):
+
+        async def stream_content():
+            try:
+                res = await execute_pipe(pipe, params)
+
+                # Directly return if the response is a StreamingResponse
+                if isinstance(res, StreamingResponse):
+                    async for data in res.body_iterator:
+                        yield data
+                    return
+                if isinstance(res, dict):
+                    yield f"data: {json.dumps(res)}\n\n"
+                    return
+
+            except Exception as e:
+                log.error(f"Error: {e}")
+                yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
+                return
+
+            if isinstance(res, str):
+                message = openai_chat_chunk_message_template(form_data["model"], res)
+                yield f"data: {json.dumps(message)}\n\n"
+
+            if isinstance(res, Iterator):
+                for line in res:
+                    yield process_line(form_data, line)
+
+            if isinstance(res, AsyncGenerator):
+                async for line in res:
+                    yield process_line(form_data, line)
+
+            if isinstance(res, str) or isinstance(res, Generator):
+                finish_message = openai_chat_chunk_message_template(
+                    form_data["model"], ""
+                )
+                finish_message["choices"][0]["finish_reason"] = "stop"
+                yield f"data: {json.dumps(finish_message)}\n\n"
+                yield "data: [DONE]"
+
+        return StreamingResponse(stream_content(), media_type="text/event-stream")
+    else:
+        try:
+            res = await execute_pipe(pipe, params)
+
+        except Exception as e:
+            log.error(f"Error: {e}")
+            return {"error": {"detail": str(e)}}
+
+        if isinstance(res, StreamingResponse) or isinstance(res, dict):
+            return res
+        if isinstance(res, BaseModel):
+            return res.model_dump()
+
+        message = await get_message_content(res)
+        return openai_chat_completion_message_template(form_data["model"], message)

+ 1 - 1
backend/open_webui/apps/webui/internal/db.py → backend/open_webui/internal/db.py

@@ -3,7 +3,7 @@ import logging
 from contextlib import contextmanager
 from contextlib import contextmanager
 from typing import Any, Optional
 from typing import Any, Optional
 
 
-from open_webui.apps.webui.internal.wrappers import register_connection
+from open_webui.internal.wrappers import register_connection
 from open_webui.env import (
 from open_webui.env import (
     OPEN_WEBUI_DIR,
     OPEN_WEBUI_DIR,
     DATABASE_URL,
     DATABASE_URL,

+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/001_initial_schema.py → backend/open_webui/internal/migrations/001_initial_schema.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/002_add_local_sharing.py → backend/open_webui/internal/migrations/002_add_local_sharing.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/003_add_auth_api_key.py → backend/open_webui/internal/migrations/003_add_auth_api_key.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/004_add_archived.py → backend/open_webui/internal/migrations/004_add_archived.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/005_add_updated_at.py → backend/open_webui/internal/migrations/005_add_updated_at.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py → backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/007_add_user_last_active_at.py → backend/open_webui/internal/migrations/007_add_user_last_active_at.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/008_add_memory.py → backend/open_webui/internal/migrations/008_add_memory.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/009_add_models.py → backend/open_webui/internal/migrations/009_add_models.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py → backend/open_webui/internal/migrations/010_migrate_modelfiles_to_models.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/011_add_user_settings.py → backend/open_webui/internal/migrations/011_add_user_settings.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/012_add_tools.py → backend/open_webui/internal/migrations/012_add_tools.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/013_add_user_info.py → backend/open_webui/internal/migrations/013_add_user_info.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/014_add_files.py → backend/open_webui/internal/migrations/014_add_files.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/015_add_functions.py → backend/open_webui/internal/migrations/015_add_functions.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/016_add_valves_and_is_active.py → backend/open_webui/internal/migrations/016_add_valves_and_is_active.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/017_add_user_oauth_sub.py → backend/open_webui/internal/migrations/017_add_user_oauth_sub.py


+ 0 - 0
backend/open_webui/apps/webui/internal/migrations/018_add_function_is_global.py → backend/open_webui/internal/migrations/018_add_function_is_global.py


+ 0 - 0
backend/open_webui/apps/webui/internal/wrappers.py → backend/open_webui/internal/wrappers.py


文件差异内容过多而无法显示
+ 642 - 1958
backend/open_webui/main.py


+ 1 - 1
backend/open_webui/migrations/env.py

@@ -1,7 +1,7 @@
 from logging.config import fileConfig
 from logging.config import fileConfig
 
 
 from alembic import context
 from alembic import context
-from open_webui.apps.webui.models.auths import Auth
+from open_webui.models.auths import Auth
 from open_webui.env import DATABASE_URL
 from open_webui.env import DATABASE_URL
 from sqlalchemy import engine_from_config, pool
 from sqlalchemy import engine_from_config, pool
 
 

+ 1 - 1
backend/open_webui/migrations/script.py.mako

@@ -9,7 +9,7 @@ from typing import Sequence, Union
 
 
 from alembic import op
 from alembic import op
 import sqlalchemy as sa
 import sqlalchemy as sa
-import open_webui.apps.webui.internal.db
+import open_webui.internal.db
 ${imports if imports else ""}
 ${imports if imports else ""}
 
 
 # revision identifiers, used by Alembic.
 # revision identifiers, used by Alembic.

+ 2 - 2
backend/open_webui/migrations/versions/7e5b5dc7342b_init.py

@@ -11,8 +11,8 @@ from typing import Sequence, Union
 import sqlalchemy as sa
 import sqlalchemy as sa
 from alembic import op
 from alembic import op
 
 
-import open_webui.apps.webui.internal.db
-from open_webui.apps.webui.internal.db import JSONField
+import open_webui.internal.db
+from open_webui.internal.db import JSONField
 from open_webui.migrations.util import get_existing_tables
 from open_webui.migrations.util import get_existing_tables
 
 
 # revision identifiers, used by Alembic.
 # revision identifiers, used by Alembic.

+ 85 - 0
backend/open_webui/migrations/versions/922e7a387820_add_group_table.py

@@ -0,0 +1,85 @@
+"""Add group table
+
+Revision ID: 922e7a387820
+Revises: 4ace53fd72c8
+Create Date: 2024-11-14 03:00:00.000000
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+revision = "922e7a387820"
+down_revision = "4ace53fd72c8"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.create_table(
+        "group",
+        sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
+        sa.Column("user_id", sa.Text(), nullable=True),
+        sa.Column("name", sa.Text(), nullable=True),
+        sa.Column("description", sa.Text(), nullable=True),
+        sa.Column("data", sa.JSON(), nullable=True),
+        sa.Column("meta", sa.JSON(), nullable=True),
+        sa.Column("permissions", sa.JSON(), nullable=True),
+        sa.Column("user_ids", sa.JSON(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+    )
+
+    # Add 'access_control' column to 'model' table
+    op.add_column(
+        "model",
+        sa.Column("access_control", sa.JSON(), nullable=True),
+    )
+
+    # Add 'is_active' column to 'model' table
+    op.add_column(
+        "model",
+        sa.Column(
+            "is_active",
+            sa.Boolean(),
+            nullable=False,
+            server_default=sa.sql.expression.true(),
+        ),
+    )
+
+    # Add 'access_control' column to 'knowledge' table
+    op.add_column(
+        "knowledge",
+        sa.Column("access_control", sa.JSON(), nullable=True),
+    )
+
+    # Add 'access_control' column to 'prompt' table
+    op.add_column(
+        "prompt",
+        sa.Column("access_control", sa.JSON(), nullable=True),
+    )
+
+    # Add 'access_control' column to 'tools' table
+    op.add_column(
+        "tool",
+        sa.Column("access_control", sa.JSON(), nullable=True),
+    )
+
+
+def downgrade():
+    op.drop_table("group")
+
+    # Drop 'access_control' column from 'model' table
+    op.drop_column("model", "access_control")
+
+    # Drop 'is_active' column from 'model' table
+    op.drop_column("model", "is_active")
+
+    # Drop 'access_control' column from 'knowledge' table
+    op.drop_column("knowledge", "access_control")
+
+    # Drop 'access_control' column from 'prompt' table
+    op.drop_column("prompt", "access_control")
+
+    # Drop 'access_control' column from 'tools' table
+    op.drop_column("tool", "access_control")

+ 8 - 3
backend/open_webui/apps/webui/models/auths.py → backend/open_webui/models/auths.py

@@ -2,12 +2,12 @@ import logging
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.users import UserModel, Users
+from open_webui.internal.db import Base, get_db
+from open_webui.models.users import UserModel, Users
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel
 from pydantic import BaseModel
 from sqlalchemy import Boolean, Column, String, Text
 from sqlalchemy import Boolean, Column, String, Text
-from open_webui.utils.utils import verify_password
+from open_webui.utils.auth import verify_password
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -64,6 +64,11 @@ class SigninForm(BaseModel):
     password: str
     password: str
 
 
 
 
+class LdapForm(BaseModel):
+    user: str
+    password: str
+
+
 class ProfileImageUrlForm(BaseModel):
 class ProfileImageUrlForm(BaseModel):
     profile_image_url: str
     profile_image_url: str
 
 

+ 15 - 8
backend/open_webui/apps/webui/models/chats.py → backend/open_webui/models/chats.py

@@ -3,8 +3,8 @@ import time
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
+from open_webui.internal.db import Base, get_db
+from open_webui.models.tags import TagModel, Tag, Tags
 
 
 
 
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
@@ -203,15 +203,22 @@ class ChatTable:
     def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
     def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         try:
         try:
             with get_db() as db:
             with get_db() as db:
-                print("update_shared_chat_by_id")
                 chat = db.get(Chat, chat_id)
                 chat = db.get(Chat, chat_id)
-                print(chat)
-                chat.title = chat.title
-                chat.chat = chat.chat
+                shared_chat = (
+                    db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first()
+                )
+
+                if shared_chat is None:
+                    return self.insert_shared_chat_by_chat_id(chat_id)
+
+                shared_chat.title = chat.title
+                shared_chat.chat = chat.chat
+
+                shared_chat.updated_at = int(time.time())
                 db.commit()
                 db.commit()
-                db.refresh(chat)
+                db.refresh(shared_chat)
 
 
-                return self.get_chat_by_id(chat.share_id)
+                return ChatModel.model_validate(shared_chat)
         except Exception:
         except Exception:
             return None
             return None
 
 

+ 2 - 2
backend/open_webui/apps/webui/models/feedbacks.py → backend/open_webui/models/feedbacks.py

@@ -3,8 +3,8 @@ import time
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.internal.db import Base, get_db
+from open_webui.models.chats import Chats
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict

+ 1 - 1
backend/open_webui/apps/webui/models/files.py → backend/open_webui/models/files.py

@@ -2,7 +2,7 @@ import logging
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
+from open_webui.internal.db import Base, JSONField, get_db
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text, JSON
 from sqlalchemy import BigInteger, Column, String, Text, JSON

+ 2 - 2
backend/open_webui/apps/webui/models/folders.py → backend/open_webui/models/folders.py

@@ -3,8 +3,8 @@ import time
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.internal.db import Base, get_db
+from open_webui.models.chats import Chats
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict

+ 2 - 2
backend/open_webui/apps/webui/models/functions.py → backend/open_webui/models/functions.py

@@ -2,8 +2,8 @@ import logging
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
-from open_webui.apps.webui.models.users import Users
+from open_webui.internal.db import Base, JSONField, get_db
+from open_webui.models.users import Users
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Boolean, Column, String, Text
 from sqlalchemy import BigInteger, Boolean, Column, String, Text

+ 186 - 0
backend/open_webui/models/groups.py

@@ -0,0 +1,186 @@
+import json
+import logging
+import time
+from typing import Optional
+import uuid
+
+from open_webui.internal.db import Base, get_db
+from open_webui.env import SRC_LOG_LEVELS
+
+from open_webui.models.files import FileMetadataResponse
+
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text, JSON, func
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+####################
+# UserGroup DB Schema
+####################
+
+
+class Group(Base):
+    __tablename__ = "group"
+
+    id = Column(Text, unique=True, primary_key=True)
+    user_id = Column(Text)
+
+    name = Column(Text)
+    description = Column(Text)
+
+    data = Column(JSON, nullable=True)
+    meta = Column(JSON, nullable=True)
+
+    permissions = Column(JSON, nullable=True)
+    user_ids = Column(JSON, nullable=True)
+
+    created_at = Column(BigInteger)
+    updated_at = Column(BigInteger)
+
+
+class GroupModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+    id: str
+    user_id: str
+
+    name: str
+    description: str
+
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+
+    permissions: Optional[dict] = None
+    user_ids: list[str] = []
+
+    created_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class GroupResponse(BaseModel):
+    id: str
+    user_id: str
+    name: str
+    description: str
+    permissions: Optional[dict] = None
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+    user_ids: list[str] = []
+    created_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+
+
+class GroupForm(BaseModel):
+    name: str
+    description: str
+
+
+class GroupUpdateForm(GroupForm):
+    permissions: Optional[dict] = None
+    user_ids: Optional[list[str]] = None
+    admin_ids: Optional[list[str]] = None
+
+
+class GroupTable:
+    def insert_new_group(
+        self, user_id: str, form_data: GroupForm
+    ) -> Optional[GroupModel]:
+        with get_db() as db:
+            group = GroupModel(
+                **{
+                    **form_data.model_dump(),
+                    "id": str(uuid.uuid4()),
+                    "user_id": user_id,
+                    "created_at": int(time.time()),
+                    "updated_at": int(time.time()),
+                }
+            )
+
+            try:
+                result = Group(**group.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return GroupModel.model_validate(result)
+                else:
+                    return None
+
+            except Exception:
+                return None
+
+    def get_groups(self) -> list[GroupModel]:
+        with get_db() as db:
+            return [
+                GroupModel.model_validate(group)
+                for group in db.query(Group).order_by(Group.updated_at.desc()).all()
+            ]
+
+    def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
+        with get_db() as db:
+            return [
+                GroupModel.model_validate(group)
+                for group in db.query(Group)
+                .filter(
+                    func.json_array_length(Group.user_ids) > 0
+                )  # Ensure array exists
+                .filter(
+                    Group.user_ids.cast(String).like(f'%"{user_id}"%')
+                )  # String-based check
+                .order_by(Group.updated_at.desc())
+                .all()
+            ]
+
+    def get_group_by_id(self, id: str) -> Optional[GroupModel]:
+        try:
+            with get_db() as db:
+                group = db.query(Group).filter_by(id=id).first()
+                return GroupModel.model_validate(group) if group else None
+        except Exception:
+            return None
+
+    def update_group_by_id(
+        self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
+    ) -> Optional[GroupModel]:
+        try:
+            with get_db() as db:
+                db.query(Group).filter_by(id=id).update(
+                    {
+                        **form_data.model_dump(exclude_none=True),
+                        "updated_at": int(time.time()),
+                    }
+                )
+                db.commit()
+                return self.get_group_by_id(id=id)
+        except Exception as e:
+            log.exception(e)
+            return None
+
+    def delete_group_by_id(self, id: str) -> bool:
+        try:
+            with get_db() as db:
+                db.query(Group).filter_by(id=id).delete()
+                db.commit()
+                return True
+        except Exception:
+            return False
+
+    def delete_all_groups(self) -> bool:
+        with get_db() as db:
+            try:
+                db.query(Group).delete()
+                db.commit()
+
+                return True
+            except Exception:
+                return False
+
+
+Groups = GroupTable()

+ 78 - 25
backend/open_webui/apps/webui/models/knowledge.py → backend/open_webui/models/knowledge.py

@@ -4,15 +4,17 @@ import time
 from typing import Optional
 from typing import Optional
 import uuid
 import uuid
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
-from open_webui.apps.webui.models.files import FileMetadataResponse
+from open_webui.models.files import FileMetadataResponse
+from open_webui.models.users import Users, UserResponse
 
 
 
 
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text, JSON
 from sqlalchemy import BigInteger, Column, String, Text, JSON
 
 
+from open_webui.utils.access_control import has_access
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -34,6 +36,23 @@ class Knowledge(Base):
     data = Column(JSON, nullable=True)
     data = Column(JSON, nullable=True)
     meta = Column(JSON, nullable=True)
     meta = Column(JSON, nullable=True)
 
 
+    access_control = Column(JSON, nullable=True)  # Controls data access levels.
+    # Defines access control rules for this entry.
+    # - `None`: Public access, available to all users with the "user" role.
+    # - `{}`: Private access, restricted exclusively to the owner.
+    # - Custom permissions: Specific access control for reading and writing;
+    #   Can specify group or user-level restrictions:
+    #   {
+    #      "read": {
+    #          "group_ids": ["group_id1", "group_id2"],
+    #          "user_ids":  ["user_id1", "user_id2"]
+    #      },
+    #      "write": {
+    #          "group_ids": ["group_id1", "group_id2"],
+    #          "user_ids":  ["user_id1", "user_id2"]
+    #      }
+    #   }
+
     created_at = Column(BigInteger)
     created_at = Column(BigInteger)
     updated_at = Column(BigInteger)
     updated_at = Column(BigInteger)
 
 
@@ -50,6 +69,8 @@ class KnowledgeModel(BaseModel):
     data: Optional[dict] = None
     data: Optional[dict] = None
     meta: Optional[dict] = None
     meta: Optional[dict] = None
 
 
+    access_control: Optional[dict] = None
+
     created_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
 
 
@@ -59,15 +80,15 @@ class KnowledgeModel(BaseModel):
 ####################
 ####################
 
 
 
 
-class KnowledgeResponse(BaseModel):
-    id: str
-    name: str
-    description: str
-    data: Optional[dict] = None
-    meta: Optional[dict] = None
-    created_at: int  # timestamp in epoch
-    updated_at: int  # timestamp in epoch
+class KnowledgeUserModel(KnowledgeModel):
+    user: Optional[UserResponse] = None
+
+
+class KnowledgeResponse(KnowledgeModel):
+    files: Optional[list[FileMetadataResponse | dict]] = None
 
 
+
+class KnowledgeUserResponse(KnowledgeUserModel):
     files: Optional[list[FileMetadataResponse | dict]] = None
     files: Optional[list[FileMetadataResponse | dict]] = None
 
 
 
 
@@ -75,12 +96,7 @@ class KnowledgeForm(BaseModel):
     name: str
     name: str
     description: str
     description: str
     data: Optional[dict] = None
     data: Optional[dict] = None
-
-
-class KnowledgeUpdateForm(BaseModel):
-    name: Optional[str] = None
-    description: Optional[str] = None
-    data: Optional[dict] = None
+    access_control: Optional[dict] = None
 
 
 
 
 class KnowledgeTable:
 class KnowledgeTable:
@@ -110,14 +126,33 @@ class KnowledgeTable:
             except Exception:
             except Exception:
                 return None
                 return None
 
 
-    def get_knowledge_items(self) -> list[KnowledgeModel]:
+    def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
         with get_db() as db:
         with get_db() as db:
-            return [
-                KnowledgeModel.model_validate(knowledge)
-                for knowledge in db.query(Knowledge)
-                .order_by(Knowledge.updated_at.desc())
-                .all()
-            ]
+            knowledge_bases = []
+            for knowledge in (
+                db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
+            ):
+                user = Users.get_user_by_id(knowledge.user_id)
+                knowledge_bases.append(
+                    KnowledgeUserModel.model_validate(
+                        {
+                            **KnowledgeModel.model_validate(knowledge).model_dump(),
+                            "user": user.model_dump() if user else None,
+                        }
+                    )
+                )
+            return knowledge_bases
+
+    def get_knowledge_bases_by_user_id(
+        self, user_id: str, permission: str = "write"
+    ) -> list[KnowledgeUserModel]:
+        knowledge_bases = self.get_knowledge_bases()
+        return [
+            knowledge_base
+            for knowledge_base in knowledge_bases
+            if knowledge_base.user_id == user_id
+            or has_access(user_id, permission, knowledge_base.access_control)
+        ]
 
 
     def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
     def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
         try:
         try:
@@ -128,14 +163,32 @@ class KnowledgeTable:
             return None
             return None
 
 
     def update_knowledge_by_id(
     def update_knowledge_by_id(
-        self, id: str, form_data: KnowledgeUpdateForm, overwrite: bool = False
+        self, id: str, form_data: KnowledgeForm, overwrite: bool = False
+    ) -> Optional[KnowledgeModel]:
+        try:
+            with get_db() as db:
+                knowledge = self.get_knowledge_by_id(id=id)
+                db.query(Knowledge).filter_by(id=id).update(
+                    {
+                        **form_data.model_dump(),
+                        "updated_at": int(time.time()),
+                    }
+                )
+                db.commit()
+                return self.get_knowledge_by_id(id=id)
+        except Exception as e:
+            log.exception(e)
+            return None
+
+    def update_knowledge_data_by_id(
+        self, id: str, data: dict
     ) -> Optional[KnowledgeModel]:
     ) -> Optional[KnowledgeModel]:
         try:
         try:
             with get_db() as db:
             with get_db() as db:
                 knowledge = self.get_knowledge_by_id(id=id)
                 knowledge = self.get_knowledge_by_id(id=id)
                 db.query(Knowledge).filter_by(id=id).update(
                 db.query(Knowledge).filter_by(id=id).update(
                     {
                     {
-                        **form_data.model_dump(exclude_none=True),
+                        "data": data,
                         "updated_at": int(time.time()),
                         "updated_at": int(time.time()),
                     }
                     }
                 )
                 )

+ 1 - 1
backend/open_webui/apps/webui/models/memories.py → backend/open_webui/models/memories.py

@@ -2,7 +2,7 @@ import time
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text
 from sqlalchemy import BigInteger, Column, String, Text
 
 

+ 104 - 9
backend/open_webui/apps/webui/models/models.py → backend/open_webui/models/models.py

@@ -2,10 +2,21 @@ import logging
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
+from open_webui.internal.db import Base, JSONField, get_db
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
+
+from open_webui.models.users import Users, UserResponse
+
+
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, Text
+
+from sqlalchemy import or_, and_, func
+from sqlalchemy.dialects import postgresql, sqlite
+from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
+
+
+from open_webui.utils.access_control import has_access
+
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -67,6 +78,25 @@ class Model(Base):
         Holds a JSON encoded blob of metadata, see `ModelMeta`.
         Holds a JSON encoded blob of metadata, see `ModelMeta`.
     """
     """
 
 
+    access_control = Column(JSON, nullable=True)  # Controls data access levels.
+    # Defines access control rules for this entry.
+    # - `None`: Public access, available to all users with the "user" role.
+    # - `{}`: Private access, restricted exclusively to the owner.
+    # - Custom permissions: Specific access control for reading and writing;
+    #   Can specify group or user-level restrictions:
+    #   {
+    #      "read": {
+    #          "group_ids": ["group_id1", "group_id2"],
+    #          "user_ids":  ["user_id1", "user_id2"]
+    #      },
+    #      "write": {
+    #          "group_ids": ["group_id1", "group_id2"],
+    #          "user_ids":  ["user_id1", "user_id2"]
+    #      }
+    #   }
+
+    is_active = Column(Boolean, default=True)
+
     updated_at = Column(BigInteger)
     updated_at = Column(BigInteger)
     created_at = Column(BigInteger)
     created_at = Column(BigInteger)
 
 
@@ -80,6 +110,9 @@ class ModelModel(BaseModel):
     params: ModelParams
     params: ModelParams
     meta: ModelMeta
     meta: ModelMeta
 
 
+    access_control: Optional[dict] = None
+
+    is_active: bool
     updated_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
 
@@ -91,12 +124,12 @@ class ModelModel(BaseModel):
 ####################
 ####################
 
 
 
 
-class ModelResponse(BaseModel):
-    id: str
-    name: str
-    meta: ModelMeta
-    updated_at: int  # timestamp in epoch
-    created_at: int  # timestamp in epoch
+class ModelUserResponse(ModelModel):
+    user: Optional[UserResponse] = None
+
+
+class ModelResponse(ModelModel):
+    pass
 
 
 
 
 class ModelForm(BaseModel):
 class ModelForm(BaseModel):
@@ -105,6 +138,8 @@ class ModelForm(BaseModel):
     name: str
     name: str
     meta: ModelMeta
     meta: ModelMeta
     params: ModelParams
     params: ModelParams
+    access_control: Optional[dict] = None
+    is_active: bool = True
 
 
 
 
 class ModelsTable:
 class ModelsTable:
@@ -138,6 +173,39 @@ class ModelsTable:
         with get_db() as db:
         with get_db() as db:
             return [ModelModel.model_validate(model) for model in db.query(Model).all()]
             return [ModelModel.model_validate(model) for model in db.query(Model).all()]
 
 
+    def get_models(self) -> list[ModelUserResponse]:
+        with get_db() as db:
+            models = []
+            for model in db.query(Model).filter(Model.base_model_id != None).all():
+                user = Users.get_user_by_id(model.user_id)
+                models.append(
+                    ModelUserResponse.model_validate(
+                        {
+                            **ModelModel.model_validate(model).model_dump(),
+                            "user": user.model_dump() if user else None,
+                        }
+                    )
+                )
+            return models
+
+    def get_base_models(self) -> list[ModelModel]:
+        with get_db() as db:
+            return [
+                ModelModel.model_validate(model)
+                for model in db.query(Model).filter(Model.base_model_id == None).all()
+            ]
+
+    def get_models_by_user_id(
+        self, user_id: str, permission: str = "write"
+    ) -> list[ModelUserResponse]:
+        models = self.get_models()
+        return [
+            model
+            for model in models
+            if model.user_id == user_id
+            or has_access(user_id, permission, model.access_control)
+        ]
+
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
         try:
             with get_db() as db:
             with get_db() as db:
@@ -146,6 +214,23 @@ class ModelsTable:
         except Exception:
         except Exception:
             return None
             return None
 
 
+    def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
+        with get_db() as db:
+            try:
+                is_active = db.query(Model).filter_by(id=id).first().is_active
+
+                db.query(Model).filter_by(id=id).update(
+                    {
+                        "is_active": not is_active,
+                        "updated_at": int(time.time()),
+                    }
+                )
+                db.commit()
+
+                return self.get_model_by_id(id)
+            except Exception:
+                return None
+
     def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
     def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
         try:
         try:
             with get_db() as db:
             with get_db() as db:
@@ -153,7 +238,7 @@ class ModelsTable:
                 result = (
                 result = (
                     db.query(Model)
                     db.query(Model)
                     .filter_by(id=id)
                     .filter_by(id=id)
-                    .update(model.model_dump(exclude={"id"}, exclude_none=True))
+                    .update(model.model_dump(exclude={"id"}))
                 )
                 )
                 db.commit()
                 db.commit()
 
 
@@ -175,5 +260,15 @@ class ModelsTable:
         except Exception:
         except Exception:
             return False
             return False
 
 
+    def delete_all_models(self) -> bool:
+        try:
+            with get_db() as db:
+                db.query(Model).delete()
+                db.commit()
+
+                return True
+        except Exception:
+            return False
+
 
 
 Models = ModelsTable()
 Models = ModelsTable()

+ 59 - 10
backend/open_webui/apps/webui/models/prompts.py → backend/open_webui/models/prompts.py

@@ -1,9 +1,13 @@
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
+from open_webui.models.users import Users, UserResponse
+
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, String, Text
+from sqlalchemy import BigInteger, Column, String, Text, JSON
+
+from open_webui.utils.access_control import has_access
 
 
 ####################
 ####################
 # Prompts DB Schema
 # Prompts DB Schema
@@ -19,6 +23,23 @@ class Prompt(Base):
     content = Column(Text)
     content = Column(Text)
     timestamp = Column(BigInteger)
     timestamp = Column(BigInteger)
 
 
+    access_control = Column(JSON, nullable=True)  # Controls data access levels.
+    # Defines access control rules for this entry.
+    # - `None`: Public access, available to all users with the "user" role.
+    # - `{}`: Private access, restricted exclusively to the owner.
+    # - Custom permissions: Specific access control for reading and writing;
+    #   Can specify group or user-level restrictions:
+    #   {
+    #      "read": {
+    #          "group_ids": ["group_id1", "group_id2"],
+    #          "user_ids":  ["user_id1", "user_id2"]
+    #      },
+    #      "write": {
+    #          "group_ids": ["group_id1", "group_id2"],
+    #          "user_ids":  ["user_id1", "user_id2"]
+    #      }
+    #   }
+
 
 
 class PromptModel(BaseModel):
 class PromptModel(BaseModel):
     command: str
     command: str
@@ -27,6 +48,7 @@ class PromptModel(BaseModel):
     content: str
     content: str
     timestamp: int  # timestamp in epoch
     timestamp: int  # timestamp in epoch
 
 
+    access_control: Optional[dict] = None
     model_config = ConfigDict(from_attributes=True)
     model_config = ConfigDict(from_attributes=True)
 
 
 
 
@@ -35,10 +57,15 @@ class PromptModel(BaseModel):
 ####################
 ####################
 
 
 
 
+class PromptUserResponse(PromptModel):
+    user: Optional[UserResponse] = None
+
+
 class PromptForm(BaseModel):
 class PromptForm(BaseModel):
     command: str
     command: str
     title: str
     title: str
     content: str
     content: str
+    access_control: Optional[dict] = None
 
 
 
 
 class PromptsTable:
 class PromptsTable:
@@ -48,16 +75,14 @@ class PromptsTable:
         prompt = PromptModel(
         prompt = PromptModel(
             **{
             **{
                 "user_id": user_id,
                 "user_id": user_id,
-                "command": form_data.command,
-                "title": form_data.title,
-                "content": form_data.content,
+                **form_data.model_dump(),
                 "timestamp": int(time.time()),
                 "timestamp": int(time.time()),
             }
             }
         )
         )
 
 
         try:
         try:
             with get_db() as db:
             with get_db() as db:
-                result = Prompt(**prompt.dict())
+                result = Prompt(**prompt.model_dump())
                 db.add(result)
                 db.add(result)
                 db.commit()
                 db.commit()
                 db.refresh(result)
                 db.refresh(result)
@@ -76,11 +101,34 @@ class PromptsTable:
         except Exception:
         except Exception:
             return None
             return None
 
 
-    def get_prompts(self) -> list[PromptModel]:
+    def get_prompts(self) -> list[PromptUserResponse]:
         with get_db() as db:
         with get_db() as db:
-            return [
-                PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
-            ]
+            prompts = []
+
+            for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
+                user = Users.get_user_by_id(prompt.user_id)
+                prompts.append(
+                    PromptUserResponse.model_validate(
+                        {
+                            **PromptModel.model_validate(prompt).model_dump(),
+                            "user": user.model_dump() if user else None,
+                        }
+                    )
+                )
+
+            return prompts
+
+    def get_prompts_by_user_id(
+        self, user_id: str, permission: str = "write"
+    ) -> list[PromptUserResponse]:
+        prompts = self.get_prompts()
+
+        return [
+            prompt
+            for prompt in prompts
+            if prompt.user_id == user_id
+            or has_access(user_id, permission, prompt.access_control)
+        ]
 
 
     def update_prompt_by_command(
     def update_prompt_by_command(
         self, command: str, form_data: PromptForm
         self, command: str, form_data: PromptForm
@@ -90,6 +138,7 @@ class PromptsTable:
                 prompt = db.query(Prompt).filter_by(command=command).first()
                 prompt = db.query(Prompt).filter_by(command=command).first()
                 prompt.title = form_data.title
                 prompt.title = form_data.title
                 prompt.content = form_data.content
                 prompt.content = form_data.content
+                prompt.access_control = form_data.access_control
                 prompt.timestamp = int(time.time())
                 prompt.timestamp = int(time.time())
                 db.commit()
                 db.commit()
                 return PromptModel.model_validate(prompt)
                 return PromptModel.model_validate(prompt)

+ 1 - 1
backend/open_webui/apps/webui/models/tags.py → backend/open_webui/models/tags.py

@@ -3,7 +3,7 @@ import time
 import uuid
 import uuid
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 
 
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS

+ 61 - 5
backend/open_webui/apps/webui/models/tools.py → backend/open_webui/models/tools.py

@@ -2,11 +2,14 @@ import logging
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
-from open_webui.apps.webui.models.users import Users
+from open_webui.internal.db import Base, JSONField, get_db
+from open_webui.models.users import Users, UserResponse
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, String, Text
+from sqlalchemy import BigInteger, Column, String, Text, JSON
+
+from open_webui.utils.access_control import has_access
+
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -26,6 +29,24 @@ class Tool(Base):
     specs = Column(JSONField)
     specs = Column(JSONField)
     meta = Column(JSONField)
     meta = Column(JSONField)
     valves = Column(JSONField)
     valves = Column(JSONField)
+
+    access_control = Column(JSON, nullable=True)  # Controls data access levels.
+    # Defines access control rules for this entry.
+    # - `None`: Public access, available to all users with the "user" role.
+    # - `{}`: Private access, restricted exclusively to the owner.
+    # - Custom permissions: Specific access control for reading and writing;
+    #   Can specify group or user-level restrictions:
+    #   {
+    #      "read": {
+    #          "group_ids": ["group_id1", "group_id2"],
+    #          "user_ids":  ["user_id1", "user_id2"]
+    #      },
+    #      "write": {
+    #          "group_ids": ["group_id1", "group_id2"],
+    #          "user_ids":  ["user_id1", "user_id2"]
+    #      }
+    #   }
+
     updated_at = Column(BigInteger)
     updated_at = Column(BigInteger)
     created_at = Column(BigInteger)
     created_at = Column(BigInteger)
 
 
@@ -42,6 +63,8 @@ class ToolModel(BaseModel):
     content: str
     content: str
     specs: list[dict]
     specs: list[dict]
     meta: ToolMeta
     meta: ToolMeta
+    access_control: Optional[dict] = None
+
     updated_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
 
@@ -53,20 +76,30 @@ class ToolModel(BaseModel):
 ####################
 ####################
 
 
 
 
+class ToolUserModel(ToolModel):
+    user: Optional[UserResponse] = None
+
+
 class ToolResponse(BaseModel):
 class ToolResponse(BaseModel):
     id: str
     id: str
     user_id: str
     user_id: str
     name: str
     name: str
     meta: ToolMeta
     meta: ToolMeta
+    access_control: Optional[dict] = None
     updated_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
 
 
 
+class ToolUserResponse(ToolResponse):
+    user: Optional[UserResponse] = None
+
+
 class ToolForm(BaseModel):
 class ToolForm(BaseModel):
     id: str
     id: str
     name: str
     name: str
     content: str
     content: str
     meta: ToolMeta
     meta: ToolMeta
+    access_control: Optional[dict] = None
 
 
 
 
 class ToolValves(BaseModel):
 class ToolValves(BaseModel):
@@ -109,9 +142,32 @@ class ToolsTable:
         except Exception:
         except Exception:
             return None
             return None
 
 
-    def get_tools(self) -> list[ToolModel]:
+    def get_tools(self) -> list[ToolUserModel]:
         with get_db() as db:
         with get_db() as db:
-            return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
+            tools = []
+            for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
+                user = Users.get_user_by_id(tool.user_id)
+                tools.append(
+                    ToolUserModel.model_validate(
+                        {
+                            **ToolModel.model_validate(tool).model_dump(),
+                            "user": user.model_dump() if user else None,
+                        }
+                    )
+                )
+            return tools
+
+    def get_tools_by_user_id(
+        self, user_id: str, permission: str = "write"
+    ) -> list[ToolUserModel]:
+        tools = self.get_tools()
+
+        return [
+            tool
+            for tool in tools
+            if tool.user_id == user_id
+            or has_access(user_id, permission, tool.access_control)
+        ]
 
 
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
         try:
         try:

+ 10 - 2
backend/open_webui/apps/webui/models/users.py → backend/open_webui/models/users.py

@@ -1,8 +1,8 @@
 import time
 import time
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.internal.db import Base, JSONField, get_db
+from open_webui.models.chats import Chats
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text
 from sqlalchemy import BigInteger, Column, String, Text
 
 
@@ -62,6 +62,14 @@ class UserModel(BaseModel):
 ####################
 ####################
 
 
 
 
+class UserResponse(BaseModel):
+    id: str
+    name: str
+    email: str
+    role: str
+    profile_image_url: str
+
+
 class UserRoleUpdateForm(BaseModel):
 class UserRoleUpdateForm(BaseModel):
     id: str
     id: str
     role: str
     role: str

+ 5 - 3
backend/open_webui/apps/retrieval/loaders/main.py → backend/open_webui/retrieval/loaders/main.py

@@ -1,6 +1,7 @@
 import requests
 import requests
 import logging
 import logging
 import ftfy
 import ftfy
+import sys
 
 
 from langchain_community.document_loaders import (
 from langchain_community.document_loaders import (
     BSHTMLLoader,
     BSHTMLLoader,
@@ -18,8 +19,9 @@ from langchain_community.document_loaders import (
     YoutubeLoader,
     YoutubeLoader,
 )
 )
 from langchain_core.documents import Document
 from langchain_core.documents import Document
-from open_webui.env import SRC_LOG_LEVELS
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
 
 
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
@@ -106,7 +108,7 @@ class TikaLoader:
             if "Content-Type" in raw_metadata:
             if "Content-Type" in raw_metadata:
                 headers["Content-Type"] = raw_metadata["Content-Type"]
                 headers["Content-Type"] = raw_metadata["Content-Type"]
 
 
-            log.info("Tika extracted text: %s", text)
+            log.debug("Tika extracted text: %s", text)
 
 
             return [Document(page_content=text, metadata=headers)]
             return [Document(page_content=text, metadata=headers)]
         else:
         else:
@@ -159,7 +161,7 @@ class Loader:
             elif file_ext in ["htm", "html"]:
             elif file_ext in ["htm", "html"]:
                 loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
                 loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
             elif file_ext == "md":
             elif file_ext == "md":
-                loader = UnstructuredMarkdownLoader(file_path)
+                loader = TextLoader(file_path, autodetect_encoding=True)
             elif file_content_type == "application/epub+zip":
             elif file_content_type == "application/epub+zip":
                 loader = UnstructuredEPubLoader(file_path)
                 loader = UnstructuredEPubLoader(file_path)
             elif (
             elif (

+ 117 - 0
backend/open_webui/retrieval/loaders/youtube.py

@@ -0,0 +1,117 @@
+import logging
+
+from typing import Any, Dict, Generator, List, Optional, Sequence, Union
+from urllib.parse import parse_qs, urlparse
+from langchain_core.documents import Document
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+ALLOWED_SCHEMES = {"http", "https"}
+ALLOWED_NETLOCS = {
+    "youtu.be",
+    "m.youtube.com",
+    "youtube.com",
+    "www.youtube.com",
+    "www.youtube-nocookie.com",
+    "vid.plus",
+}
+
+
+def _parse_video_id(url: str) -> Optional[str]:
+    """Parse a YouTube URL and return the video ID if valid, otherwise None."""
+    parsed_url = urlparse(url)
+
+    if parsed_url.scheme not in ALLOWED_SCHEMES:
+        return None
+
+    if parsed_url.netloc not in ALLOWED_NETLOCS:
+        return None
+
+    path = parsed_url.path
+
+    if path.endswith("/watch"):
+        query = parsed_url.query
+        parsed_query = parse_qs(query)
+        if "v" in parsed_query:
+            ids = parsed_query["v"]
+            video_id = ids if isinstance(ids, str) else ids[0]
+        else:
+            return None
+    else:
+        path = parsed_url.path.lstrip("/")
+        video_id = path.split("/")[-1]
+
+    if len(video_id) != 11:  # Video IDs are 11 characters long
+        return None
+
+    return video_id
+
+
+class YoutubeLoader:
+    """Load `YouTube` video transcripts."""
+
+    def __init__(
+        self,
+        video_id: str,
+        language: Union[str, Sequence[str]] = "en",
+        proxy_url: Optional[str] = None,
+    ):
+        """Initialize with YouTube video ID."""
+        _video_id = _parse_video_id(video_id)
+        self.video_id = _video_id if _video_id is not None else video_id
+        self._metadata = {"source": video_id}
+        self.language = language
+        self.proxy_url = proxy_url
+        if isinstance(language, str):
+            self.language = [language]
+        else:
+            self.language = language
+
+    def load(self) -> List[Document]:
+        """Load YouTube transcripts into `Document` objects."""
+        try:
+            from youtube_transcript_api import (
+                NoTranscriptFound,
+                TranscriptsDisabled,
+                YouTubeTranscriptApi,
+            )
+        except ImportError:
+            raise ImportError(
+                'Could not import "youtube_transcript_api" Python package. '
+                "Please install it with `pip install youtube-transcript-api`."
+            )
+
+        if self.proxy_url:
+            youtube_proxies = {
+                "http": self.proxy_url,
+                "https": self.proxy_url,
+            }
+            # Don't log complete URL because it might contain secrets
+            log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
+        else:
+            youtube_proxies = None
+
+        try:
+            transcript_list = YouTubeTranscriptApi.list_transcripts(
+                self.video_id, proxies=youtube_proxies
+            )
+        except Exception as e:
+            log.exception("Loading YouTube transcript failed")
+            return []
+
+        try:
+            transcript = transcript_list.find_transcript(self.language)
+        except NoTranscriptFound:
+            transcript = transcript_list.find_transcript(["en"])
+
+        transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
+
+        transcript = " ".join(
+            map(
+                lambda transcript_piece: transcript_piece["text"].strip(" "),
+                transcript_pieces,
+            )
+        )
+        return [Document(page_content=transcript, metadata=self._metadata)]

+ 0 - 0
backend/open_webui/apps/retrieval/models/colbert.py → backend/open_webui/retrieval/models/colbert.py


+ 83 - 124
backend/open_webui/apps/retrieval/utils.py → backend/open_webui/retrieval/utils.py

@@ -3,6 +3,7 @@ import os
 import uuid
 import uuid
 from typing import Optional, Union
 from typing import Optional, Union
 
 
+import asyncio
 import requests
 import requests
 
 
 from huggingface_hub import snapshot_download
 from huggingface_hub import snapshot_download
@@ -10,17 +11,10 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
 from langchain_community.retrievers import BM25Retriever
 from langchain_community.retrievers import BM25Retriever
 from langchain_core.documents import Document
 from langchain_core.documents import Document
 
 
-
-from open_webui.apps.ollama.main import (
-    GenerateEmbedForm,
-    generate_ollama_batch_embeddings,
-)
-from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
 from open_webui.utils.misc import get_last_user_message
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
-from open_webui.config import DEFAULT_RAG_TEMPLATE
-
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -76,7 +70,7 @@ def query_doc(
             limit=k,
             limit=k,
         )
         )
 
 
-        log.info(f"query_doc:result {result}")
+        log.info(f"query_doc:result {result.ids} {result.metadatas}")
         return result
         return result
     except Exception as e:
     except Exception as e:
         print(e)
         print(e)
@@ -127,7 +121,10 @@ def query_doc_with_hybrid_search(
             "metadatas": [[d.metadata for d in result]],
             "metadatas": [[d.metadata for d in result]],
         }
         }
 
 
-        log.info(f"query_doc_with_hybrid_search:result {result}")
+        log.info(
+            "query_doc_with_hybrid_search:result "
+            + f'{result["metadatas"]} {result["distances"]}'
+        )
         return result
         return result
     except Exception as e:
     except Exception as e:
         raise e
         raise e
@@ -178,35 +175,34 @@ def merge_and_sort_query_results(
 
 
 def query_collection(
 def query_collection(
     collection_names: list[str],
     collection_names: list[str],
-    query: str,
+    queries: list[str],
     embedding_function,
     embedding_function,
     k: int,
     k: int,
 ) -> dict:
 ) -> dict:
-
     results = []
     results = []
-    query_embedding = embedding_function(query)
-
-    for collection_name in collection_names:
-        if collection_name:
-            try:
-                result = query_doc(
-                    collection_name=collection_name,
-                    k=k,
-                    query_embedding=query_embedding,
-                )
-                if result is not None:
-                    results.append(result.model_dump())
-            except Exception as e:
-                log.exception(f"Error when querying the collection: {e}")
-        else:
-            pass
+    for query in queries:
+        query_embedding = embedding_function(query)
+        for collection_name in collection_names:
+            if collection_name:
+                try:
+                    result = query_doc(
+                        collection_name=collection_name,
+                        k=k,
+                        query_embedding=query_embedding,
+                    )
+                    if result is not None:
+                        results.append(result.model_dump())
+                except Exception as e:
+                    log.exception(f"Error when querying the collection: {e}")
+            else:
+                pass
 
 
     return merge_and_sort_query_results(results, k=k)
     return merge_and_sort_query_results(results, k=k)
 
 
 
 
 def query_collection_with_hybrid_search(
 def query_collection_with_hybrid_search(
     collection_names: list[str],
     collection_names: list[str],
-    query: str,
+    queries: list[str],
     embedding_function,
     embedding_function,
     k: int,
     k: int,
     reranking_function,
     reranking_function,
@@ -216,15 +212,16 @@ def query_collection_with_hybrid_search(
     error = False
     error = False
     for collection_name in collection_names:
     for collection_name in collection_names:
         try:
         try:
-            result = query_doc_with_hybrid_search(
-                collection_name=collection_name,
-                query=query,
-                embedding_function=embedding_function,
-                k=k,
-                reranking_function=reranking_function,
-                r=r,
-            )
-            results.append(result)
+            for query in queries:
+                result = query_doc_with_hybrid_search(
+                    collection_name=collection_name,
+                    query=query,
+                    embedding_function=embedding_function,
+                    k=k,
+                    reranking_function=reranking_function,
+                    r=r,
+                )
+                results.append(result)
         except Exception as e:
         except Exception as e:
             log.exception(
             log.exception(
                 "Error when querying the collection with " f"hybrid_search: {e}"
                 "Error when querying the collection with " f"hybrid_search: {e}"
@@ -239,50 +236,12 @@ def query_collection_with_hybrid_search(
     return merge_and_sort_query_results(results, k=k, reverse=True)
     return merge_and_sort_query_results(results, k=k, reverse=True)
 
 
 
 
-def rag_template(template: str, context: str, query: str):
-    if template == "":
-        template = DEFAULT_RAG_TEMPLATE
-
-    if "[context]" not in template and "{{CONTEXT}}" not in template:
-        log.debug(
-            "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
-        )
-
-    if "<context>" in context and "</context>" in context:
-        log.debug(
-            "WARNING: Potential prompt injection attack: the RAG "
-            "context contains '<context>' and '</context>'. This might be "
-            "nothing, or the user might be trying to hack something."
-        )
-
-    query_placeholders = []
-    if "[query]" in context:
-        query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
-        template = template.replace("[query]", query_placeholder)
-        query_placeholders.append(query_placeholder)
-
-    if "{{QUERY}}" in context:
-        query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
-        template = template.replace("{{QUERY}}", query_placeholder)
-        query_placeholders.append(query_placeholder)
-
-    template = template.replace("[context]", context)
-    template = template.replace("{{CONTEXT}}", context)
-    template = template.replace("[query]", query)
-    template = template.replace("{{QUERY}}", query)
-
-    for query_placeholder in query_placeholders:
-        template = template.replace(query_placeholder, query)
-
-    return template
-
-
 def get_embedding_function(
 def get_embedding_function(
     embedding_engine,
     embedding_engine,
     embedding_model,
     embedding_model,
     embedding_function,
     embedding_function,
-    openai_key,
-    openai_url,
+    url,
+    key,
     embedding_batch_size,
     embedding_batch_size,
 ):
 ):
     if embedding_engine == "":
     if embedding_engine == "":
@@ -292,8 +251,8 @@ def get_embedding_function(
             engine=embedding_engine,
             engine=embedding_engine,
             model=embedding_model,
             model=embedding_model,
             text=query,
             text=query,
-            key=openai_key if embedding_engine == "openai" else "",
-            url=openai_url if embedding_engine == "openai" else "",
+            url=url,
+            key=key,
         )
         )
 
 
         def generate_multiple(query, func):
         def generate_multiple(query, func):
@@ -308,17 +267,16 @@ def get_embedding_function(
         return lambda query: generate_multiple(query, func)
         return lambda query: generate_multiple(query, func)
 
 
 
 
-def get_rag_context(
+def get_sources_from_files(
     files,
     files,
-    messages,
+    queries,
     embedding_function,
     embedding_function,
     k,
     k,
     reranking_function,
     reranking_function,
     r,
     r,
     hybrid_search,
     hybrid_search,
 ):
 ):
-    log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
-    query = get_last_user_message(messages)
+    log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
 
 
     extracted_collections = []
     extracted_collections = []
     relevant_contexts = []
     relevant_contexts = []
@@ -360,7 +318,7 @@ def get_rag_context(
                         try:
                         try:
                             context = query_collection_with_hybrid_search(
                             context = query_collection_with_hybrid_search(
                                 collection_names=collection_names,
                                 collection_names=collection_names,
-                                query=query,
+                                queries=queries,
                                 embedding_function=embedding_function,
                                 embedding_function=embedding_function,
                                 k=k,
                                 k=k,
                                 reranking_function=reranking_function,
                                 reranking_function=reranking_function,
@@ -375,7 +333,7 @@ def get_rag_context(
                     if (not hybrid_search) or (context is None):
                     if (not hybrid_search) or (context is None):
                         context = query_collection(
                         context = query_collection(
                             collection_names=collection_names,
                             collection_names=collection_names,
-                            query=query,
+                            queries=queries,
                             embedding_function=embedding_function,
                             embedding_function=embedding_function,
                             k=k,
                             k=k,
                         )
                         )
@@ -389,43 +347,24 @@ def get_rag_context(
                 del file["data"]
                 del file["data"]
             relevant_contexts.append({**context, "file": file})
             relevant_contexts.append({**context, "file": file})
 
 
-    contexts = []
-    citations = []
+    sources = []
     for context in relevant_contexts:
     for context in relevant_contexts:
         try:
         try:
             if "documents" in context:
             if "documents" in context:
-                file_names = list(
-                    set(
-                        [
-                            metadata["name"]
-                            for metadata in context["metadatas"][0]
-                            if metadata is not None and "name" in metadata
-                        ]
-                    )
-                )
-                contexts.append(
-                    ((", ".join(file_names) + ":\n\n") if file_names else "")
-                    + "\n\n".join(
-                        [text for text in context["documents"][0] if text is not None]
-                    )
-                )
-
                 if "metadatas" in context:
                 if "metadatas" in context:
-                    citation = {
+                    source = {
                         "source": context["file"],
                         "source": context["file"],
                         "document": context["documents"][0],
                         "document": context["documents"][0],
                         "metadata": context["metadatas"][0],
                         "metadata": context["metadatas"][0],
                     }
                     }
                     if "distances" in context and context["distances"]:
                     if "distances" in context and context["distances"]:
-                        citation["distances"] = context["distances"][0]
-                    citations.append(citation)
+                        source["distances"] = context["distances"][0]
+
+                    sources.append(source)
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
 
 
-    print("contexts", contexts)
-    print("citations", citations)
-
-    return contexts, citations
+    return sources
 
 
 
 
 def get_model_path(model: str, update_model: bool = False):
 def get_model_path(model: str, update_model: bool = False):
@@ -467,7 +406,7 @@ def get_model_path(model: str, update_model: bool = False):
 
 
 
 
 def generate_openai_batch_embeddings(
 def generate_openai_batch_embeddings(
-    model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
+    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
 ) -> Optional[list[list[float]]]:
 ) -> Optional[list[list[float]]]:
     try:
     try:
         r = requests.post(
         r = requests.post(
@@ -489,29 +428,49 @@ def generate_openai_batch_embeddings(
         return None
         return None
 
 
 
 
+def generate_ollama_batch_embeddings(
+    model: str, texts: list[str], url: str, key: str = ""
+) -> Optional[list[list[float]]]:
+    try:
+        r = requests.post(
+            f"{url}/api/embed",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {key}",
+            },
+            json={"input": texts, "model": model},
+        )
+        r.raise_for_status()
+        data = r.json()
+
+        if "embeddings" in data:
+            return data["embeddings"]
+        else:
+            raise "Something went wrong :/"
+    except Exception as e:
+        print(e)
+        return None
+
+
 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
+    url = kwargs.get("url", "")
+    key = kwargs.get("key", "")
+
     if engine == "ollama":
     if engine == "ollama":
         if isinstance(text, list):
         if isinstance(text, list):
             embeddings = generate_ollama_batch_embeddings(
             embeddings = generate_ollama_batch_embeddings(
-                GenerateEmbedForm(**{"model": model, "input": text})
+                **{"model": model, "texts": text, "url": url, "key": key}
             )
             )
         else:
         else:
             embeddings = generate_ollama_batch_embeddings(
             embeddings = generate_ollama_batch_embeddings(
-                GenerateEmbedForm(**{"model": model, "input": [text]})
+                **{"model": model, "texts": [text], "url": url, "key": key}
             )
             )
-        return (
-            embeddings["embeddings"][0]
-            if isinstance(text, str)
-            else embeddings["embeddings"]
-        )
+        return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
     elif engine == "openai":
-        key = kwargs.get("key", "")
-        url = kwargs.get("url", "https://api.openai.com/v1")
-
         if isinstance(text, list):
         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, key, url)
+            embeddings = generate_openai_batch_embeddings(model, text, url, key)
         else:
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], key, url)
+            embeddings = generate_openai_batch_embeddings(model, [text], url, key)
 
 
         return embeddings[0] if isinstance(text, str) else embeddings
         return embeddings[0] if isinstance(text, str) else embeddings
 
 

+ 22 - 0
backend/open_webui/retrieval/vector/connector.py

@@ -0,0 +1,22 @@
+from open_webui.config import VECTOR_DB
+
+if VECTOR_DB == "milvus":
+    from open_webui.retrieval.vector.dbs.milvus import MilvusClient
+
+    VECTOR_DB_CLIENT = MilvusClient()
+elif VECTOR_DB == "qdrant":
+    from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
+
+    VECTOR_DB_CLIENT = QdrantClient()
+elif VECTOR_DB == "opensearch":
+    from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
+
+    VECTOR_DB_CLIENT = OpenSearchClient()
+elif VECTOR_DB == "pgvector":
+    from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
+
+    VECTOR_DB_CLIENT = PgvectorClient()
+else:
+    from open_webui.retrieval.vector.dbs.chroma import ChromaClient
+
+    VECTOR_DB_CLIENT = ChromaClient()

+ 16 - 3
backend/open_webui/apps/retrieval/vector/dbs/chroma.py → backend/open_webui/retrieval/vector/dbs/chroma.py

@@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches
 
 
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,
     CHROMA_HTTP_HOST,
@@ -13,11 +13,24 @@ from open_webui.config import (
     CHROMA_HTTP_SSL,
     CHROMA_HTTP_SSL,
     CHROMA_TENANT,
     CHROMA_TENANT,
     CHROMA_DATABASE,
     CHROMA_DATABASE,
+    CHROMA_CLIENT_AUTH_PROVIDER,
+    CHROMA_CLIENT_AUTH_CREDENTIALS,
 )
 )
 
 
 
 
 class ChromaClient:
 class ChromaClient:
     def __init__(self):
     def __init__(self):
+        settings_dict = {
+            "allow_reset": True,
+            "anonymized_telemetry": False,
+        }
+        if CHROMA_CLIENT_AUTH_PROVIDER is not None:
+            settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
+        if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
+            settings_dict["chroma_client_auth_credentials"] = (
+                CHROMA_CLIENT_AUTH_CREDENTIALS
+            )
+
         if CHROMA_HTTP_HOST != "":
         if CHROMA_HTTP_HOST != "":
             self.client = chromadb.HttpClient(
             self.client = chromadb.HttpClient(
                 host=CHROMA_HTTP_HOST,
                 host=CHROMA_HTTP_HOST,
@@ -26,12 +39,12 @@ class ChromaClient:
                 ssl=CHROMA_HTTP_SSL,
                 ssl=CHROMA_HTTP_SSL,
                 tenant=CHROMA_TENANT,
                 tenant=CHROMA_TENANT,
                 database=CHROMA_DATABASE,
                 database=CHROMA_DATABASE,
-                settings=Settings(allow_reset=True, anonymized_telemetry=False),
+                settings=Settings(**settings_dict),
             )
             )
         else:
         else:
             self.client = chromadb.PersistentClient(
             self.client = chromadb.PersistentClient(
                 path=CHROMA_DATA_PATH,
                 path=CHROMA_DATA_PATH,
-                settings=Settings(allow_reset=True, anonymized_telemetry=False),
+                settings=Settings(**settings_dict),
                 tenant=CHROMA_TENANT,
                 tenant=CHROMA_TENANT,
                 database=CHROMA_DATABASE,
                 database=CHROMA_DATABASE,
             )
             )

+ 1 - 1
backend/open_webui/apps/retrieval/vector/dbs/milvus.py → backend/open_webui/retrieval/vector/dbs/milvus.py

@@ -4,7 +4,7 @@ import json
 
 
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
 from open_webui.config import (
     MILVUS_URI,
     MILVUS_URI,
 )
 )

+ 178 - 0
backend/open_webui/retrieval/vector/dbs/opensearch.py

@@ -0,0 +1,178 @@
+from opensearchpy import OpenSearch
+from typing import Optional
+
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.config import (
+    OPENSEARCH_URI,
+    OPENSEARCH_SSL,
+    OPENSEARCH_CERT_VERIFY,
+    OPENSEARCH_USERNAME,
+    OPENSEARCH_PASSWORD,
+)
+
+
+class OpenSearchClient:
+    def __init__(self):
+        self.index_prefix = "open_webui"
+        self.client = OpenSearch(
+            hosts=[OPENSEARCH_URI],
+            use_ssl=OPENSEARCH_SSL,
+            verify_certs=OPENSEARCH_CERT_VERIFY,
+            http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
+        )
+
+    def _result_to_get_result(self, result) -> GetResult:
+        ids = []
+        documents = []
+        metadatas = []
+
+        for hit in result["hits"]["hits"]:
+            ids.append(hit["_id"])
+            documents.append(hit["_source"].get("text"))
+            metadatas.append(hit["_source"].get("metadata"))
+
+        return GetResult(ids=ids, documents=documents, metadatas=metadatas)
+
+    def _result_to_search_result(self, result) -> SearchResult:
+        ids = []
+        distances = []
+        documents = []
+        metadatas = []
+
+        for hit in result["hits"]["hits"]:
+            ids.append(hit["_id"])
+            distances.append(hit["_score"])
+            documents.append(hit["_source"].get("text"))
+            metadatas.append(hit["_source"].get("metadata"))
+
+        return SearchResult(
+            ids=ids, distances=distances, documents=documents, metadatas=metadatas
+        )
+
+    def _create_index(self, index_name: str, dimension: int):
+        body = {
+            "mappings": {
+                "properties": {
+                    "id": {"type": "keyword"},
+                    "vector": {
+                        "type": "dense_vector",
+                        "dims": dimension,  # Adjust based on your vector dimensions
+                        "index": true,
+                        "similarity": "faiss",
+                        "method": {
+                            "name": "hnsw",
+                            "space_type": "ip",  # Use inner product to approximate cosine similarity
+                            "engine": "faiss",
+                            "ef_construction": 128,
+                            "m": 16,
+                        },
+                    },
+                    "text": {"type": "text"},
+                    "metadata": {"type": "object"},
+                }
+            }
+        }
+        self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
+
+    def _create_batches(self, items: list[VectorItem], batch_size=100):
+        for i in range(0, len(items), batch_size):
+            yield items[i : i + batch_size]
+
+    def has_collection(self, index_name: str) -> bool:
+        # has_collection here means has index.
+        # We are simply adapting to the norms of the other DBs.
+        return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
+
+    def delete_colleciton(self, index_name: str):
+        # delete_collection here means delete index.
+        # We are simply adapting to the norms of the other DBs.
+        self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
+
+    def search(
+        self, index_name: str, vectors: list[list[float]], limit: int
+    ) -> Optional[SearchResult]:
+        query = {
+            "size": limit,
+            "_source": ["text", "metadata"],
+            "query": {
+                "script_score": {
+                    "query": {"match_all": {}},
+                    "script": {
+                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
+                        "params": {
+                            "vector": vectors[0]
+                        },  # Assuming single query vector
+                    },
+                }
+            },
+        }
+
+        result = self.client.search(
+            index=f"{self.index_prefix}_{index_name}", body=query
+        )
+
+        return self._result_to_search_result(result)
+
+    def get_or_create_index(self, index_name: str, dimension: int):
+        if not self.has_index(index_name):
+            self._create_index(index_name, dimension)
+
+    def get(self, index_name: str) -> Optional[GetResult]:
+        query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
+
+        result = self.client.search(
+            index=f"{self.index_prefix}_{index_name}", body=query
+        )
+        return self._result_to_get_result(result)
+
+    def insert(self, index_name: str, items: list[VectorItem]):
+        if not self.has_index(index_name):
+            self._create_index(index_name, dimension=len(items[0]["vector"]))
+
+        for batch in self._create_batches(items):
+            actions = [
+                {
+                    "index": {
+                        "_id": item["id"],
+                        "_source": {
+                            "vector": item["vector"],
+                            "text": item["text"],
+                            "metadata": item["metadata"],
+                        },
+                    }
+                }
+                for item in batch
+            ]
+            self.client.bulk(actions)
+
+    def upsert(self, index_name: str, items: list[VectorItem]):
+        if not self.has_index(index_name):
+            self._create_index(index_name, dimension=len(items[0]["vector"]))
+
+        for batch in self._create_batches(items):
+            actions = [
+                {
+                    "index": {
+                        "_id": item["id"],
+                        "_source": {
+                            "vector": item["vector"],
+                            "text": item["text"],
+                            "metadata": item["metadata"],
+                        },
+                    }
+                }
+                for item in batch
+            ]
+            self.client.bulk(actions)
+
+    def delete(self, index_name: str, ids: list[str]):
+        actions = [
+            {"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
+            for id in ids
+        ]
+        self.client.bulk(body=actions)
+
+    def reset(self):
+        indices = self.client.indices.get(index=f"{self.index_prefix}_*")
+        for index in indices:
+            self.client.indices.delete(index=index)

+ 354 - 0
backend/open_webui/retrieval/vector/dbs/pgvector.py

@@ -0,0 +1,354 @@
+from typing import Optional, List, Dict, Any
+from sqlalchemy import (
+    cast,
+    column,
+    create_engine,
+    Column,
+    Integer,
+    select,
+    text,
+    Text,
+    values,
+)
+from sqlalchemy.sql import true
+from sqlalchemy.pool import NullPool
+
+from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
+from sqlalchemy.dialects.postgresql import JSONB, array
+from pgvector.sqlalchemy import Vector
+from sqlalchemy.ext.mutable import MutableDict
+
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.config import PGVECTOR_DB_URL
+
+VECTOR_LENGTH = 1536
+Base = declarative_base()
+
+
+class DocumentChunk(Base):
+    __tablename__ = "document_chunk"
+
+    id = Column(Text, primary_key=True)
+    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
+    collection_name = Column(Text, nullable=False)
+    text = Column(Text, nullable=True)
+    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
+
+
+class PgvectorClient:
+    def __init__(self) -> None:
+
+        # if no pgvector uri, use the existing database connection
+        if not PGVECTOR_DB_URL:
+            from open_webui.internal.db import Session
+
+            self.session = Session
+        else:
+            engine = create_engine(
+                PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
+            )
+            SessionLocal = sessionmaker(
+                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
+            )
+            self.session = scoped_session(SessionLocal)
+
+        try:
+            # Ensure the pgvector extension is available
+            self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+
+            # Create the tables if they do not exist
+            # Base.metadata.create_all requires a bind (engine or connection)
+            # Get the connection from the session
+            connection = self.session.connection()
+            Base.metadata.create_all(bind=connection)
+
+            # Create an index on the vector column if it doesn't exist
+            self.session.execute(
+                text(
+                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
+                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
+                )
+            )
+            self.session.execute(
+                text(
+                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
+                    "ON document_chunk (collection_name);"
+                )
+            )
+            self.session.commit()
+            print("Initialization complete.")
+        except Exception as e:
+            self.session.rollback()
+            print(f"Error during initialization: {e}")
+            raise
+
+    def adjust_vector_length(self, vector: List[float]) -> List[float]:
+        # Adjust vector to have length VECTOR_LENGTH
+        current_length = len(vector)
+        if current_length < VECTOR_LENGTH:
+            # Pad the vector with zeros
+            vector += [0.0] * (VECTOR_LENGTH - current_length)
+        elif current_length > VECTOR_LENGTH:
+            raise Exception(
+                f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
+            )
+        return vector
+
+    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
+        try:
+            new_items = []
+            for item in items:
+                vector = self.adjust_vector_length(item["vector"])
+                new_chunk = DocumentChunk(
+                    id=item["id"],
+                    vector=vector,
+                    collection_name=collection_name,
+                    text=item["text"],
+                    vmetadata=item["metadata"],
+                )
+                new_items.append(new_chunk)
+            self.session.bulk_save_objects(new_items)
+            self.session.commit()
+            print(
+                f"Inserted {len(new_items)} items into collection '{collection_name}'."
+            )
+        except Exception as e:
+            self.session.rollback()
+            print(f"Error during insert: {e}")
+            raise
+
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        try:
+            for item in items:
+                vector = self.adjust_vector_length(item["vector"])
+                existing = (
+                    self.session.query(DocumentChunk)
+                    .filter(DocumentChunk.id == item["id"])
+                    .first()
+                )
+                if existing:
+                    existing.vector = vector
+                    existing.text = item["text"]
+                    existing.vmetadata = item["metadata"]
+                    existing.collection_name = (
+                        collection_name  # Update collection_name if necessary
+                    )
+                else:
+                    new_chunk = DocumentChunk(
+                        id=item["id"],
+                        vector=vector,
+                        collection_name=collection_name,
+                        text=item["text"],
+                        vmetadata=item["metadata"],
+                    )
+                    self.session.add(new_chunk)
+            self.session.commit()
+            print(f"Upserted {len(items)} items into collection '{collection_name}'.")
+        except Exception as e:
+            self.session.rollback()
+            print(f"Error during upsert: {e}")
+            raise
+
+    def search(
+        self,
+        collection_name: str,
+        vectors: List[List[float]],
+        limit: Optional[int] = None,
+    ) -> Optional[SearchResult]:
+        try:
+            if not vectors:
+                return None
+
+            # Adjust query vectors to VECTOR_LENGTH
+            vectors = [self.adjust_vector_length(vector) for vector in vectors]
+            num_queries = len(vectors)
+
+            def vector_expr(vector):
+                return cast(array(vector), Vector(VECTOR_LENGTH))
+
+            # Create the values for query vectors
+            qid_col = column("qid", Integer)
+            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
+            query_vectors = (
+                values(qid_col, q_vector_col)
+                .data(
+                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
+                )
+                .alias("query_vectors")
+            )
+
+            # Build the lateral subquery for each query vector
+            subq = (
+                select(
+                    DocumentChunk.id,
+                    DocumentChunk.text,
+                    DocumentChunk.vmetadata,
+                    (
+                        DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
+                    ).label("distance"),
+                )
+                .where(DocumentChunk.collection_name == collection_name)
+                .order_by(
+                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
+                )
+            )
+            if limit is not None:
+                subq = subq.limit(limit)
+            subq = subq.lateral("result")
+
+            # Build the main query by joining query_vectors and the lateral subquery
+            stmt = (
+                select(
+                    query_vectors.c.qid,
+                    subq.c.id,
+                    subq.c.text,
+                    subq.c.vmetadata,
+                    subq.c.distance,
+                )
+                .select_from(query_vectors)
+                .join(subq, true())
+                .order_by(query_vectors.c.qid, subq.c.distance)
+            )
+
+            result_proxy = self.session.execute(stmt)
+            results = result_proxy.all()
+
+            ids = [[] for _ in range(num_queries)]
+            distances = [[] for _ in range(num_queries)]
+            documents = [[] for _ in range(num_queries)]
+            metadatas = [[] for _ in range(num_queries)]
+
+            if not results:
+                return SearchResult(
+                    ids=ids,
+                    distances=distances,
+                    documents=documents,
+                    metadatas=metadatas,
+                )
+
+            for row in results:
+                qid = int(row.qid)
+                ids[qid].append(row.id)
+                distances[qid].append(row.distance)
+                documents[qid].append(row.text)
+                metadatas[qid].append(row.vmetadata)
+
+            return SearchResult(
+                ids=ids, distances=distances, documents=documents, metadatas=metadatas
+            )
+        except Exception as e:
+            print(f"Error during search: {e}")
+            return None
+
+    def query(
+        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        try:
+            query = self.session.query(DocumentChunk).filter(
+                DocumentChunk.collection_name == collection_name
+            )
+
+            for key, value in filter.items():
+                query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
+
+            if limit is not None:
+                query = query.limit(limit)
+
+            results = query.all()
+
+            if not results:
+                return None
+
+            ids = [[result.id for result in results]]
+            documents = [[result.text for result in results]]
+            metadatas = [[result.vmetadata for result in results]]
+
+            return GetResult(
+                ids=ids,
+                documents=documents,
+                metadatas=metadatas,
+            )
+        except Exception as e:
+            print(f"Error during query: {e}")
+            return None
+
+    def get(
+        self, collection_name: str, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        try:
+            query = self.session.query(DocumentChunk).filter(
+                DocumentChunk.collection_name == collection_name
+            )
+            if limit is not None:
+                query = query.limit(limit)
+
+            results = query.all()
+
+            if not results:
+                return None
+
+            ids = [[result.id for result in results]]
+            documents = [[result.text for result in results]]
+            metadatas = [[result.vmetadata for result in results]]
+
+            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
+        except Exception as e:
+            print(f"Error during get: {e}")
+            return None
+
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        try:
+            query = self.session.query(DocumentChunk).filter(
+                DocumentChunk.collection_name == collection_name
+            )
+            if ids:
+                query = query.filter(DocumentChunk.id.in_(ids))
+            if filter:
+                for key, value in filter.items():
+                    query = query.filter(
+                        DocumentChunk.vmetadata[key].astext == str(value)
+                    )
+            deleted = query.delete(synchronize_session=False)
+            self.session.commit()
+            print(f"Deleted {deleted} items from collection '{collection_name}'.")
+        except Exception as e:
+            self.session.rollback()
+            print(f"Error during delete: {e}")
+            raise
+
+    def reset(self) -> None:
+        try:
+            deleted = self.session.query(DocumentChunk).delete()
+            self.session.commit()
+            print(
+                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
+            )
+        except Exception as e:
+            self.session.rollback()
+            print(f"Error during reset: {e}")
+            raise
+
+    def close(self) -> None:
+        pass
+
+    def has_collection(self, collection_name: str) -> bool:
+        try:
+            exists = (
+                self.session.query(DocumentChunk)
+                .filter(DocumentChunk.collection_name == collection_name)
+                .first()
+                is not None
+            )
+            return exists
+        except Exception as e:
+            print(f"Error checking collection existence: {e}")
+            return False
+
+    def delete_collection(self, collection_name: str) -> None:
+        self.delete(collection_name)
+        print(f"Collection '{collection_name}' deleted.")

+ 8 - 3
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py → backend/open_webui/retrieval/vector/dbs/qdrant.py

@@ -4,8 +4,8 @@ from qdrant_client import QdrantClient as Qclient
 from qdrant_client.http.models import PointStruct
 from qdrant_client.http.models import PointStruct
 from qdrant_client.models import models
 from qdrant_client.models import models
 
 
-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
-from open_webui.config import QDRANT_URI
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.config import QDRANT_URI, QDRANT_API_KEY
 
 
 NO_LIMIT = 999999999
 NO_LIMIT = 999999999
 
 
@@ -14,7 +14,12 @@ class QdrantClient:
     def __init__(self):
     def __init__(self):
         self.collection_prefix = "open-webui"
         self.collection_prefix = "open-webui"
         self.QDRANT_URI = QDRANT_URI
         self.QDRANT_URI = QDRANT_URI
-        self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
+        self.QDRANT_API_KEY = QDRANT_API_KEY
+        self.client = (
+            Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
+            if self.QDRANT_URI
+            else None
+        )
 
 
     def _result_to_get_result(self, points) -> GetResult:
     def _result_to_get_result(self, points) -> GetResult:
         ids = []
         ids = []

+ 0 - 0
backend/open_webui/apps/retrieval/vector/main.py → backend/open_webui/retrieval/vector/main.py


+ 73 - 0
backend/open_webui/retrieval/web/bing.py

@@ -0,0 +1,73 @@
+import logging
+import os
+from pprint import pprint
+from typing import Optional
+import requests
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
+import argparse
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+"""
+Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
+"""
+
+
+def search_bing(
+    subscription_key: str,
+    endpoint: str,
+    locale: str,
+    query: str,
+    count: int,
+    filter_list: Optional[list[str]] = None,
+) -> list[SearchResult]:
+    mkt = locale
+    params = {"q": query, "mkt": mkt, "answerCount": count}
+    headers = {"Ocp-Apim-Subscription-Key": subscription_key}
+
+    try:
+        response = requests.get(endpoint, headers=headers, params=params)
+        response.raise_for_status()
+        json_response = response.json()
+        results = json_response.get("webPages", {}).get("value", [])
+        if filter_list:
+            results = get_filtered_results(results, filter_list)
+        return [
+            SearchResult(
+                link=result["url"],
+                title=result.get("name"),
+                snippet=result.get("snippet"),
+            )
+            for result in results
+        ]
+    except Exception as ex:
+        log.error(f"Error: {ex}")
+        raise ex
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Search Bing from the command line.")
+    parser.add_argument(
+        "query",
+        type=str,
+        default="Top 10 international news today",
+        help="The search query.",
+    )
+    parser.add_argument(
+        "--count", type=int, default=10, help="Number of search results to return."
+    )
+    parser.add_argument(
+        "--filter", nargs="*", help="List of filters to apply to the search results."
+    )
+    parser.add_argument(
+        "--locale",
+        type=str,
+        default="en-US",
+        help="The locale to use for the search, maps to market in api",
+    )
+
+    args = parser.parse_args()
+
+    results = search_bing(args.locale, args.query, args.count, args.filter)
+    pprint(results)

+ 1 - 1
backend/open_webui/apps/retrieval/web/brave.py → backend/open_webui/retrieval/web/brave.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/duckduckgo.py → backend/open_webui/retrieval/web/duckduckgo.py

@@ -1,7 +1,7 @@
 import logging
 import logging
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from duckduckgo_search import DDGS
 from duckduckgo_search import DDGS
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 

+ 1 - 1
backend/open_webui/apps/retrieval/web/google_pse.py → backend/open_webui/retrieval/web/google_pse.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 3 - 5
backend/open_webui/apps/retrieval/web/jina_search.py → backend/open_webui/retrieval/web/jina_search.py

@@ -1,7 +1,7 @@
 import logging
 import logging
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult
+from open_webui.retrieval.web.main import SearchResult
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from yarl import URL
 from yarl import URL
 
 
@@ -9,7 +9,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 
 
-def search_jina(query: str, count: int) -> list[SearchResult]:
+def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
     """
     """
     Search using Jina's Search API and return the results as a list of SearchResult objects.
     Search using Jina's Search API and return the results as a list of SearchResult objects.
     Args:
     Args:
@@ -20,9 +20,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]:
         list[SearchResult]: A list of search results
         list[SearchResult]: A list of search results
     """
     """
     jina_search_endpoint = "https://s.jina.ai/"
     jina_search_endpoint = "https://s.jina.ai/"
-    headers = {
-        "Accept": "application/json",
-    }
+    headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
     url = str(URL(jina_search_endpoint + query))
     url = str(URL(jina_search_endpoint + query))
     response = requests.get(url, headers=headers)
     response = requests.get(url, headers=headers)
     response.raise_for_status()
     response.raise_for_status()

+ 48 - 0
backend/open_webui/retrieval/web/kagi.py

@@ -0,0 +1,48 @@
+import logging
+from typing import Optional
+
+import requests
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def search_kagi(
+    api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
+) -> list[SearchResult]:
+    """Search using Kagi's Search API and return the results as a list of SearchResult objects.
+
+    The Search API will inherit the settings in your account, including results personalization and snippet length.
+
+    Args:
+        api_key (str): A Kagi Search API key
+        query (str): The query to search for
+        count (int): The number of results to return
+    """
+    url = "https://kagi.com/api/v0/search"
+    headers = {
+        "Authorization": f"Bot {api_key}",
+    }
+    params = {"q": query, "limit": count}
+
+    response = requests.get(url, headers=headers, params=params)
+    response.raise_for_status()
+    json_response = response.json()
+    search_results = json_response.get("data", [])
+
+    results = [
+        SearchResult(
+            link=result["url"], title=result["title"], snippet=result.get("snippet")
+        )
+        for result in search_results
+        if result["t"] == 0
+    ]
+
+    print(results)
+
+    if filter_list:
+        results = get_filtered_results(results, filter_list)
+
+    return results

+ 0 - 0
backend/open_webui/apps/retrieval/web/main.py → backend/open_webui/retrieval/web/main.py


+ 40 - 0
backend/open_webui/retrieval/web/mojeek.py

@@ -0,0 +1,40 @@
+import logging
+from typing import Optional
+
+import requests
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def search_mojeek(
+    api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
+) -> list[SearchResult]:
+    """Search using Mojeek's Search API and return the results as a list of SearchResult objects.
+
+    Args:
+        api_key (str): A Mojeek Search API key
+        query (str): The query to search for
+    """
+    url = "https://api.mojeek.com/search"
+    headers = {
+        "Accept": "application/json",
+    }
+    params = {"q": query, "api_key": api_key, "fmt": "json", "t": count}
+
+    response = requests.get(url, headers=headers, params=params)
+    response.raise_for_status()
+    json_response = response.json()
+    results = json_response.get("response", {}).get("results", [])
+    print(results)
+    if filter_list:
+        results = get_filtered_results(results, filter_list)
+
+    return [
+        SearchResult(
+            link=result["url"], title=result.get("title"), snippet=result.get("desc")
+        )
+        for result in results
+    ]

+ 1 - 1
backend/open_webui/apps/retrieval/web/searchapi.py → backend/open_webui/retrieval/web/searchapi.py

@@ -3,7 +3,7 @@ from typing import Optional
 from urllib.parse import urlencode
 from urllib.parse import urlencode
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/searxng.py → backend/open_webui/retrieval/web/searxng.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/serper.py → backend/open_webui/retrieval/web/serper.py

@@ -3,7 +3,7 @@ import logging
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/serply.py → backend/open_webui/retrieval/web/serply.py

@@ -3,7 +3,7 @@ from typing import Optional
 from urllib.parse import urlencode
 from urllib.parse import urlencode
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/serpstack.py → backend/open_webui/retrieval/web/serpstack.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 from typing import Optional
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/web/tavily.py → backend/open_webui/retrieval/web/tavily.py

@@ -1,7 +1,7 @@
 import logging
 import logging
 
 
 import requests
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult
+from open_webui.retrieval.web.main import SearchResult
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)

+ 58 - 0
backend/open_webui/retrieval/web/testdata/bing.json

@@ -0,0 +1,58 @@
+{
+	"_type": "SearchResponse",
+	"queryContext": {
+		"originalQuery": "Top 10 international results"
+	},
+	"webPages": {
+		"webSearchUrl": "https://www.bing.com/search?q=Top+10+international+results",
+		"totalEstimatedMatches": 687,
+		"value": [
+			{
+				"id": "https://api.bing.microsoft.com/api/v7/#WebPages.0",
+				"name": "2024 Mexican Grand Prix - F1 results and latest standings ... - PlanetF1",
+				"url": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings",
+				"datePublished": "2024-10-27T00:00:00.0000000",
+				"datePublishedFreshnessText": "1 day ago",
+				"isFamilyFriendly": true,
+				"displayUrl": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings",
+				"snippet": "Nico Hulkenberg and Pierre Gasly completed the top 10. A full report of the Mexican Grand Prix is available at the bottom of this article. F1 results – 2024 Mexican Grand Prix",
+				"dateLastCrawled": "2024-10-28T07:15:00.0000000Z",
+				"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=916492551782&mkt=en-US&setlang=en-US&w=zBsfaAPyF2tUrHFHr_vFFdUm8sng4g34",
+				"language": "en",
+				"isNavigational": false,
+				"noCache": false
+			},
+			{
+				"id": "https://api.bing.microsoft.com/api/v7/#WebPages.1",
+				"name": "F1 Results Today: HUGE Verstappen penalties cause major title change",
+				"url": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max-verstappen-penalties-cause-major-title-change/",
+				"datePublished": "2024-10-27T00:00:00.0000000",
+				"datePublishedFreshnessText": "1 day ago",
+				"isFamilyFriendly": true,
+				"displayUrl": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max...",
+				"snippet": "Elsewhere, Mercedes duo Lewis Hamilton and George Russell came home in P4 and P5 respectively. Meanwhile, the surprise package of the day were Haas, with both Kevin Magnussen and Nico Hulkenberg finishing inside the points.. READ MORE: RB star issues apology after red flag CRASH at Mexican GP Mexican Grand Prix 2024 results. 1. Carlos Sainz [Ferrari] 2. Lando Norris [McLaren] - +4.705",
+				"dateLastCrawled": "2024-10-28T06:06:00.0000000Z",
+				"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=2840656522642&mkt=en-US&setlang=en-US&w=-Tbkwxnq52jZCvG7l3CtgcwT1vwAjIUD",
+				"language": "en",
+				"isNavigational": false,
+				"noCache": false
+			},
+			{
+				"id": "https://api.bing.microsoft.com/api/v7/#WebPages.2",
+				"name": "International Power Rankings: England flying, Kangaroos cruising, Fiji rise",
+				"url": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying-kangaroos-cruising-fiji-rise",
+				"datePublished": "2024-10-28T00:00:00.0000000",
+				"datePublishedFreshnessText": "7 hours ago",
+				"isFamilyFriendly": true,
+				"displayUrl": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying...",
+				"snippet": "LRL RECOMMENDS: England player ratings from first Test against Samoa as omnificent George Williams scores perfect 10. 2. Australia (Men) – SAME. The Kangaroos remain 2nd in our Power Rankings after their 22-10 win against New Zealand in Christchurch on Sunday. As was the case in their win against Tonga last week, Mal Meninga’s side weren ...",
+				"dateLastCrawled": "2024-10-28T07:09:00.0000000Z",
+				"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=1535008462672&mkt=en-US&setlang=en-US&w=82ujhH4Kp0iuhCS7wh1xLUFYUeetaVVm",
+				"language": "en",
+				"isNavigational": false,
+				"noCache": false
+			}
+		],
+		"someResultsRemoved": true
+	}
+}

+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/brave.json → backend/open_webui/retrieval/web/testdata/brave.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/google_pse.json → backend/open_webui/retrieval/web/testdata/google_pse.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/searchapi.json → backend/open_webui/retrieval/web/testdata/searchapi.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/searxng.json → backend/open_webui/retrieval/web/testdata/searxng.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/serper.json → backend/open_webui/retrieval/web/testdata/serper.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/serply.json → backend/open_webui/retrieval/web/testdata/serply.json


+ 0 - 0
backend/open_webui/apps/retrieval/web/testdata/serpstack.json → backend/open_webui/retrieval/web/testdata/serpstack.json


+ 3 - 3
backend/open_webui/apps/retrieval/web/utils.py → backend/open_webui/retrieval/web/utils.py

@@ -82,15 +82,15 @@ class SafeWebBaseLoader(WebBaseLoader):
 
 
 
 
 def get_web_loader(
 def get_web_loader(
-    url: Union[str, Sequence[str]],
+    urls: Union[str, Sequence[str]],
     verify_ssl: bool = True,
     verify_ssl: bool = True,
     requests_per_second: int = 2,
     requests_per_second: int = 2,
 ):
 ):
     # Check if the URL is valid
     # Check if the URL is valid
-    if not validate_url(url):
+    if not validate_url(urls):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
     return SafeWebBaseLoader(
     return SafeWebBaseLoader(
-        url,
+        urls,
         verify_ssl=verify_ssl,
         verify_ssl=verify_ssl,
         requests_per_second=requests_per_second,
         requests_per_second=requests_per_second,
         continue_on_failure=True,
         continue_on_failure=True,

+ 703 - 0
backend/open_webui/routers/audio.py

@@ -0,0 +1,703 @@
+import hashlib
+import json
+import logging
+import os
+import uuid
+from functools import lru_cache
+from pathlib import Path
+from pydub import AudioSegment
+from pydub.silence import split_on_silence
+
+import aiohttp
+import aiofiles
+import requests
+
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    HTTPException,
+    Request,
+    UploadFile,
+    status,
+    APIRouter,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from pydantic import BaseModel
+
+
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.config import (
+    WHISPER_MODEL_AUTO_UPDATE,
+    WHISPER_MODEL_DIR,
+    CACHE_DIR,
+)
+
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import (
+    ENV,
+    SRC_LOG_LEVELS,
+    DEVICE_TYPE,
+    ENABLE_FORWARD_USER_INFO_HEADERS,
+)
+
+
+router = APIRouter()
+
+# Constants
+MAX_FILE_SIZE_MB = 25
+MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["AUDIO"])
+
+SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
+SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+
+##########################################
+#
+# Utility functions
+#
+##########################################
+
+from pydub import AudioSegment
+from pydub.utils import mediainfo
+
+
+def is_mp4_audio(file_path):
+    """Check if the given file is an MP4 audio file."""
+    if not os.path.isfile(file_path):
+        print(f"File not found: {file_path}")
+        return False
+
+    info = mediainfo(file_path)
+    if (
+        info.get("codec_name") == "aac"
+        and info.get("codec_type") == "audio"
+        and info.get("codec_tag_string") == "mp4a"
+    ):
+        return True
+    return False
+
+
+def convert_mp4_to_wav(file_path, output_path):
+    """Convert MP4 audio file to WAV format."""
+    audio = AudioSegment.from_file(file_path, format="mp4")
+    audio.export(output_path, format="wav")
+    print(f"Converted {file_path} to {output_path}")
+
+
+def set_faster_whisper_model(model: str, auto_update: bool = False):
+    whisper_model = None
+    if model:
+        from faster_whisper import WhisperModel
+
+        faster_whisper_kwargs = {
+            "model_size_or_path": model,
+            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
+            "compute_type": "int8",
+            "download_root": WHISPER_MODEL_DIR,
+            "local_files_only": not auto_update,
+        }
+
+        try:
+            whisper_model = WhisperModel(**faster_whisper_kwargs)
+        except Exception:
+            log.warning(
+                "WhisperModel initialization failed, attempting download with local_files_only=False"
+            )
+            faster_whisper_kwargs["local_files_only"] = False
+            whisper_model = WhisperModel(**faster_whisper_kwargs)
+    return whisper_model
+
+
+##########################################
+#
+# Audio API
+#
+##########################################
+
+
+class TTSConfigForm(BaseModel):
+    OPENAI_API_BASE_URL: str
+    OPENAI_API_KEY: str
+    API_KEY: str
+    ENGINE: str
+    MODEL: str
+    VOICE: str
+    SPLIT_ON: str
+    AZURE_SPEECH_REGION: str
+    AZURE_SPEECH_OUTPUT_FORMAT: str
+
+
+class STTConfigForm(BaseModel):
+    OPENAI_API_BASE_URL: str
+    OPENAI_API_KEY: str
+    ENGINE: str
+    MODEL: str
+    WHISPER_MODEL: str
+
+
+class AudioConfigUpdateForm(BaseModel):
+    tts: TTSConfigForm
+    stt: STTConfigForm
+
+
+@router.get("/config")
+async def get_audio_config(request: Request, user=Depends(get_admin_user)):
+    return {
+        "tts": {
+            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "API_KEY": request.app.state.config.TTS_API_KEY,
+            "ENGINE": request.app.state.config.TTS_ENGINE,
+            "MODEL": request.app.state.config.TTS_MODEL,
+            "VOICE": request.app.state.config.TTS_VOICE,
+            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
+            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
+            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+        },
+        "stt": {
+            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": request.app.state.config.STT_ENGINE,
+            "MODEL": request.app.state.config.STT_MODEL,
+            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
+        },
+    }
+
+
+@router.post("/config/update")
+async def update_audio_config(
+    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
+):
+    request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
+    request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
+    request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
+    request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
+    request.app.state.config.TTS_MODEL = form_data.tts.MODEL
+    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
+    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
+    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
+    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
+        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
+    )
+
+    request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
+    request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
+    request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
+    request.app.state.config.STT_MODEL = form_data.stt.MODEL
+    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
+
+    if request.app.state.config.STT_ENGINE == "":
+        request.app.state.faster_whisper_model = set_faster_whisper_model(
+            form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
+        )
+
+    return {
+        "tts": {
+            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "API_KEY": request.app.state.config.TTS_API_KEY,
+            "ENGINE": request.app.state.config.TTS_ENGINE,
+            "MODEL": request.app.state.config.TTS_MODEL,
+            "VOICE": request.app.state.config.TTS_VOICE,
+            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
+            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
+            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+        },
+        "stt": {
+            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": request.app.state.config.STT_ENGINE,
+            "MODEL": request.app.state.config.STT_MODEL,
+            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
+        },
+    }
+
+
+def load_speech_pipeline():
+    from transformers import pipeline
+    from datasets import load_dataset
+
+    if request.app.state.speech_synthesiser is None:
+        request.app.state.speech_synthesiser = pipeline(
+            "text-to-speech", "microsoft/speecht5_tts"
+        )
+
+    if request.app.state.speech_speaker_embeddings_dataset is None:
+        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
+            "Matthijs/cmu-arctic-xvectors", split="validation"
+        )
+
+
+@router.post("/speech")
+async def speech(request: Request, user=Depends(get_verified_user)):
+    body = await request.body()
+    name = hashlib.sha256(body).hexdigest()
+
+    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
+    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
+
+    # Check if the file already exists in the cache
+    if file_path.is_file():
+        return FileResponse(file_path)
+
+    payload = None
+    try:
+        payload = json.loads(body.decode("utf-8"))
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(status_code=400, detail="Invalid JSON payload")
+
+    if request.app.state.config.TTS_ENGINE == "openai":
+        payload["model"] = request.app.state.config.TTS_MODEL
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
+                    data=payload,
+                    headers={
+                        "Content-Type": "application/json",
+                        "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
+                        **(
+                            {
+                                "X-OpenWebUI-User-Name": user.name,
+                                "X-OpenWebUI-User-Id": user.id,
+                                "X-OpenWebUI-User-Email": user.email,
+                                "X-OpenWebUI-User-Role": user.role,
+                            }
+                            if ENABLE_FORWARD_USER_INFO_HEADERS
+                            else {}
+                        ),
+                    },
+                ) as r:
+                    r.raise_for_status()
+
+                    async with aiofiles.open(file_path, "wb") as f:
+                        await f.write(await r.read())
+
+                    async with aiofiles.open(file_body_path, "w") as f:
+                        await f.write(json.dumps(json.loads(body.decode("utf-8"))))
+
+            return FileResponse(file_path)
+
+        except Exception as e:
+            log.exception(e)
+            detail = None
+
+            try:
+                if r.status != 200:
+                    res = await r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+            except Exception:
+                detail = f"External: {e}"
+
+            raise HTTPException(
+                status_code=getattr(r, "status", 500),
+                detail=detail if detail else "Open WebUI: Server Connection Error",
+            )
+
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
+        voice_id = payload.get("voice", "")
+
+        if voice_id not in get_available_voices():
+            raise HTTPException(
+                status_code=400,
+                detail="Invalid voice id",
+            )
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
+                    json={
+                        "text": payload["input"],
+                        "model_id": request.app.state.config.TTS_MODEL,
+                        "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
+                    },
+                    headers={
+                        "Accept": "audio/mpeg",
+                        "Content-Type": "application/json",
+                        "xi-api-key": request.app.state.config.TTS_API_KEY,
+                    },
+                ) as r:
+                    r.raise_for_status()
+
+                    async with aiofiles.open(file_path, "wb") as f:
+                        await f.write(await r.read())
+
+                    async with aiofiles.open(file_body_path, "w") as f:
+                        await f.write(json.dumps(json.loads(body.decode("utf-8"))))
+
+            return FileResponse(file_path)
+
+        except Exception as e:
+            log.exception(e)
+            detail = None
+
+            try:
+                if r.status != 200:
+                    res = await r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+            except Exception:
+                detail = f"External: {e}"
+
+            raise HTTPException(
+                status_code=getattr(r, "status", 500),
+                detail=detail if detail else "Open WebUI: Server Connection Error",
+            )
+
+    elif request.app.state.config.TTS_ENGINE == "azure":
+        try:
+            payload = json.loads(body.decode("utf-8"))
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(status_code=400, detail="Invalid JSON payload")
+
+        region = request.app.state.config.TTS_AZURE_SPEECH_REGION
+        language = request.app.state.config.TTS_VOICE
+        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
+        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
+
+        try:
+            data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
+                <voice name="{language}">{payload["input"]}</voice>
+            </speak>"""
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
+                    headers={
+                        "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
+                        "Content-Type": "application/ssml+xml",
+                        "X-Microsoft-OutputFormat": output_format,
+                    },
+                    data=data,
+                ) as r:
+                    r.raise_for_status()
+
+                    async with aiofiles.open(file_path, "wb") as f:
+                        await f.write(await r.read())
+
+                    return FileResponse(file_path)
+
+        except Exception as e:
+            log.exception(e)
+            detail = None
+
+            try:
+                if r.status != 200:
+                    res = await r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+            except Exception:
+                detail = f"External: {e}"
+
+            raise HTTPException(
+                status_code=getattr(r, "status", 500),
+                detail=detail if detail else "Open WebUI: Server Connection Error",
+            )
+
+    elif request.app.state.config.TTS_ENGINE == "transformers":
+        payload = None
+        try:
+            payload = json.loads(body.decode("utf-8"))
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(status_code=400, detail="Invalid JSON payload")
+
+        import torch
+        import soundfile as sf
+
+        load_speech_pipeline()
+
+        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
+
+        speaker_index = 6799
+        try:
+            speaker_index = embeddings_dataset["filename"].index(
+                request.app.state.config.TTS_MODEL
+            )
+        except Exception:
+            pass
+
+        speaker_embedding = torch.tensor(
+            embeddings_dataset[speaker_index]["xvector"]
+        ).unsqueeze(0)
+
+        speech = request.app.state.speech_synthesiser(
+            payload["input"],
+            forward_params={"speaker_embeddings": speaker_embedding},
+        )
+
+        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
+        with open(file_body_path, "w") as f:
+            json.dump(json.loads(body.decode("utf-8")), f)
+
+        return FileResponse(file_path)
+
+
+def transcribe(request: Request, file_path):
+    print("transcribe", file_path)
+    filename = os.path.basename(file_path)
+    file_dir = os.path.dirname(file_path)
+    id = filename.split(".")[0]
+
+    if request.app.state.config.STT_ENGINE == "":
+        if request.app.state.faster_whisper_model is None:
+            request.app.state.faster_whisper_model = set_faster_whisper_model(
+                request.app.state.config.WHISPER_MODEL
+            )
+
+        model = request.app.state.faster_whisper_model
+        segments, info = model.transcribe(file_path, beam_size=5)
+        log.info(
+            "Detected language '%s' with probability %f"
+            % (info.language, info.language_probability)
+        )
+
+        transcript = "".join([segment.text for segment in list(segments)])
+        data = {"text": transcript.strip()}
+
+        # save the transcript to a json file
+        transcript_file = f"{file_dir}/{id}.json"
+        with open(transcript_file, "w") as f:
+            json.dump(data, f)
+
+        log.debug(data)
+        return data
+    elif request.app.state.config.STT_ENGINE == "openai":
+        if is_mp4_audio(file_path):
+            os.rename(file_path, file_path.replace(".wav", ".mp4"))
+            # Convert MP4 audio file to WAV format
+            convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
+
+        r = None
+        try:
+            r = requests.post(
+                url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
+                headers={
+                    "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
+                },
+                files={"file": (filename, open(file_path, "rb"))},
+                data={"model": request.app.state.config.STT_MODEL},
+            )
+
+            r.raise_for_status()
+            data = r.json()
+
+            # save the transcript to a json file
+            transcript_file = f"{file_dir}/{id}.json"
+            with open(transcript_file, "w") as f:
+                json.dump(data, f)
+
+            return data
+        except Exception as e:
+            log.exception(e)
+
+            detail = None
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+                except Exception:
+                    detail = f"External: {e}"
+
+            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
+
+
+def compress_audio(file_path):
+    if os.path.getsize(file_path) > MAX_FILE_SIZE:
+        file_dir = os.path.dirname(file_path)
+        audio = AudioSegment.from_file(file_path)
+        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
+        compressed_path = f"{file_dir}/{id}_compressed.opus"
+        audio.export(compressed_path, format="opus", bitrate="32k")
+        log.debug(f"Compressed audio to {compressed_path}")
+
+        if (
+            os.path.getsize(compressed_path) > MAX_FILE_SIZE
+        ):  # Still larger than MAX_FILE_SIZE after compression
+            raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
+        return compressed_path
+    else:
+        return file_path
+
+
+@router.post("/transcriptions")
+def transcription(
+    request: Request,
+    file: UploadFile = File(...),
+    user=Depends(get_verified_user),
+):
+    log.info(f"file.content_type: {file.content_type}")
+
+    if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
+        )
+
+    try:
+        ext = file.filename.split(".")[-1]
+        id = uuid.uuid4()
+
+        filename = f"{id}.{ext}"
+        contents = file.file.read()
+
+        file_dir = f"{CACHE_DIR}/audio/transcriptions"
+        os.makedirs(file_dir, exist_ok=True)
+        file_path = f"{file_dir}/{filename}"
+
+        with open(file_path, "wb") as f:
+            f.write(contents)
+
+        try:
+            try:
+                file_path = compress_audio(file_path)
+            except Exception as e:
+                log.exception(e)
+
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(e),
+                )
+
+            data = transcribe(request, file_path)
+            file_path = file_path.split("/")[-1]
+            return {**data, "filename": file_path}
+        except Exception as e:
+            log.exception(e)
+
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+
+    except Exception as e:
+        log.exception(e)
+
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+def get_available_models(request: Request) -> list[dict]:
+    available_models = []
+    if request.app.state.config.TTS_ENGINE == "openai":
+        available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
+        try:
+            response = requests.get(
+                "https://api.elevenlabs.io/v1/models",
+                headers={
+                    "xi-api-key": request.app.state.config.TTS_API_KEY,
+                    "Content-Type": "application/json",
+                },
+                timeout=5,
+            )
+            response.raise_for_status()
+            models = response.json()
+
+            available_models = [
+                {"name": model["name"], "id": model["model_id"]} for model in models
+            ]
+        except requests.RequestException as e:
+            log.error(f"Error fetching voices: {str(e)}")
+    return available_models
+
+
+@router.get("/models")
+async def get_models(request: Request, user=Depends(get_verified_user)):
+    return {"models": get_available_models(request)}
+
+
+def get_available_voices(request) -> dict:
+    """Returns {voice_id: voice_name} dict"""
+    available_voices = {}
+    if request.app.state.config.TTS_ENGINE == "openai":
+        available_voices = {
+            "alloy": "alloy",
+            "echo": "echo",
+            "fable": "fable",
+            "onyx": "onyx",
+            "nova": "nova",
+            "shimmer": "shimmer",
+        }
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
+        try:
+            available_voices = get_elevenlabs_voices(
+                api_key=request.app.state.config.TTS_API_KEY
+            )
+        except Exception:
+            # Avoided @lru_cache with exception
+            pass
+    elif request.app.state.config.TTS_ENGINE == "azure":
+        try:
+            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
+            url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
+            headers = {
+                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
+            }
+
+            response = requests.get(url, headers=headers)
+            response.raise_for_status()
+            voices = response.json()
+
+            for voice in voices:
+                available_voices[voice["ShortName"]] = (
+                    f"{voice['DisplayName']} ({voice['ShortName']})"
+                )
+        except requests.RequestException as e:
+            log.error(f"Error fetching voices: {str(e)}")
+
+    return available_voices
+
+
+@lru_cache
+def get_elevenlabs_voices(api_key: str) -> dict:
+    """
+    Note, set the following in your .env file to use Elevenlabs:
+    AUDIO_TTS_ENGINE=elevenlabs
+    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
+    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
+    AUDIO_TTS_MODEL=eleven_multilingual_v2
+    """
+
+    try:
+        # TODO: Add retries
+        response = requests.get(
+            "https://api.elevenlabs.io/v1/voices",
+            headers={
+                "xi-api-key": api_key,
+                "Content-Type": "application/json",
+            },
+        )
+        response.raise_for_status()
+        voices_data = response.json()
+
+        voices = {}
+        for voice in voices_data.get("voices", []):
+            voices[voice["voice_id"]] = voice["name"]
+    except requests.RequestException as e:
+        # Avoid @lru_cache with exception
+        log.error(f"Error fetching voices: {str(e)}")
+        raise RuntimeError(f"Error fetching voices: {str(e)}")
+
+    return voices
+
+
+@router.get("/voices")
+async def get_voices(request: Request, user=Depends(get_verified_user)):
+    return {
+        "voices": [
+            {"id": k, "name": v} for k, v in get_available_voices(request).items()
+        ]
+    }

+ 320 - 8
backend/open_webui/apps/webui/routers/auths.py → backend/open_webui/routers/auths.py

@@ -2,12 +2,15 @@ import re
 import uuid
 import uuid
 import time
 import time
 import datetime
 import datetime
+import logging
+from aiohttp import ClientSession
 
 
-from open_webui.apps.webui.models.auths import (
+from open_webui.models.auths import (
     AddUserForm,
     AddUserForm,
     ApiKey,
     ApiKey,
     Auths,
     Auths,
     Token,
     Token,
+    LdapForm,
     SigninForm,
     SigninForm,
     SigninResponse,
     SigninResponse,
     SignupForm,
     SignupForm,
@@ -15,20 +18,26 @@ from open_webui.apps.webui.models.auths import (
     UpdateProfileForm,
     UpdateProfileForm,
     UserResponse,
     UserResponse,
 )
 )
-from open_webui.apps.webui.models.users import Users
-from open_webui.config import WEBUI_AUTH
+from open_webui.models.users import Users
+
 from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from open_webui.env import (
 from open_webui.env import (
+    WEBUI_AUTH,
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     WEBUI_AUTH_TRUSTED_NAME_HEADER,
     WEBUI_AUTH_TRUSTED_NAME_HEADER,
     WEBUI_SESSION_COOKIE_SAME_SITE,
     WEBUI_SESSION_COOKIE_SAME_SITE,
     WEBUI_SESSION_COOKIE_SECURE,
     WEBUI_SESSION_COOKIE_SECURE,
+    SRC_LOG_LEVELS,
 )
 )
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from fastapi.responses import Response
+from fastapi.responses import RedirectResponse, Response
+from open_webui.config import (
+    OPENID_PROVIDER_URL,
+    ENABLE_OAUTH_SIGNUP,
+)
 from pydantic import BaseModel
 from pydantic import BaseModel
 from open_webui.utils.misc import parse_duration, validate_email_format
 from open_webui.utils.misc import parse_duration, validate_email_format
-from open_webui.utils.utils import (
+from open_webui.utils.auth import (
     create_api_key,
     create_api_key,
     create_token,
     create_token,
     get_admin_user,
     get_admin_user,
@@ -37,10 +46,19 @@ from open_webui.utils.utils import (
     get_password_hash,
     get_password_hash,
 )
 )
 from open_webui.utils.webhook import post_webhook
 from open_webui.utils.webhook import post_webhook
-from typing import Optional
+from open_webui.utils.access_control import get_permissions
+
+from typing import Optional, List
+
+from ssl import CERT_REQUIRED, PROTOCOL_TLS
+from ldap3 import Server, Connection, ALL, Tls
+from ldap3.utils.conv import escape_filter_chars
 
 
 router = APIRouter()
 router = APIRouter()
 
 
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
 ############################
 ############################
 # GetSessionUser
 # GetSessionUser
 ############################
 ############################
@@ -48,6 +66,7 @@ router = APIRouter()
 
 
 class SessionUserResponse(Token, UserResponse):
 class SessionUserResponse(Token, UserResponse):
     expires_at: Optional[int] = None
     expires_at: Optional[int] = None
+    permissions: Optional[dict] = None
 
 
 
 
 @router.get("/", response_model=SessionUserResponse)
 @router.get("/", response_model=SessionUserResponse)
@@ -80,6 +99,10 @@ async def get_session_user(
         secure=WEBUI_SESSION_COOKIE_SECURE,
         secure=WEBUI_SESSION_COOKIE_SECURE,
     )
     )
 
 
+    user_permissions = get_permissions(
+        user.id, request.app.state.config.USER_PERMISSIONS
+    )
+
     return {
     return {
         "token": token,
         "token": token,
         "token_type": "Bearer",
         "token_type": "Bearer",
@@ -89,6 +112,7 @@ async def get_session_user(
         "name": user.name,
         "name": user.name,
         "role": user.role,
         "role": user.role,
         "profile_image_url": user.profile_image_url,
         "profile_image_url": user.profile_image_url,
+        "permissions": user_permissions,
     }
     }
 
 
 
 
@@ -137,6 +161,146 @@ async def update_password(
         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 
 
 
 
+############################
+# LDAP Authentication
+############################
+@router.post("/ldap", response_model=SigninResponse)
+async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
+    ENABLE_LDAP = request.app.state.config.ENABLE_LDAP
+    LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
+    LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST
+    LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT
+    LDAP_ATTRIBUTE_FOR_USERNAME = request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME
+    LDAP_SEARCH_BASE = request.app.state.config.LDAP_SEARCH_BASE
+    LDAP_SEARCH_FILTERS = request.app.state.config.LDAP_SEARCH_FILTERS
+    LDAP_APP_DN = request.app.state.config.LDAP_APP_DN
+    LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
+    LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
+    LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
+    LDAP_CIPHERS = (
+        request.app.state.config.LDAP_CIPHERS
+        if request.app.state.config.LDAP_CIPHERS
+        else "ALL"
+    )
+
+    if not ENABLE_LDAP:
+        raise HTTPException(400, detail="LDAP authentication is not enabled")
+
+    try:
+        tls = Tls(
+            validate=CERT_REQUIRED,
+            version=PROTOCOL_TLS,
+            ca_certs_file=LDAP_CA_CERT_FILE,
+            ciphers=LDAP_CIPHERS,
+        )
+    except Exception as e:
+        log.error(f"An error occurred on TLS: {str(e)}")
+        raise HTTPException(400, detail=str(e))
+
+    try:
+        server = Server(
+            host=LDAP_SERVER_HOST,
+            port=LDAP_SERVER_PORT,
+            get_info=ALL,
+            use_ssl=LDAP_USE_TLS,
+            tls=tls,
+        )
+        connection_app = Connection(
+            server,
+            LDAP_APP_DN,
+            LDAP_APP_PASSWORD,
+            auto_bind="NONE",
+            authentication="SIMPLE",
+        )
+        if not connection_app.bind():
+            raise HTTPException(400, detail="Application account bind failed")
+
+        search_success = connection_app.search(
+            search_base=LDAP_SEARCH_BASE,
+            search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})",
+            attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"],
+        )
+
+        if not search_success:
+            raise HTTPException(400, detail="User not found in the LDAP server")
+
+        entry = connection_app.entries[0]
+        username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
+        mail = str(entry["mail"])
+        cn = str(entry["cn"])
+        user_dn = entry.entry_dn
+
+        if username == form_data.user.lower():
+            connection_user = Connection(
+                server,
+                user_dn,
+                form_data.password,
+                auto_bind="NONE",
+                authentication="SIMPLE",
+            )
+            if not connection_user.bind():
+                raise HTTPException(400, f"Authentication failed for {form_data.user}")
+
+            user = Users.get_user_by_email(mail)
+            if not user:
+                try:
+                    role = (
+                        "admin"
+                        if Users.get_num_users() == 0
+                        else request.app.state.config.DEFAULT_USER_ROLE
+                    )
+
+                    user = Auths.insert_new_auth(
+                        email=mail, password=str(uuid.uuid4()), name=cn, role=role
+                    )
+
+                    if not user:
+                        raise HTTPException(
+                            500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
+                        )
+
+                except HTTPException:
+                    raise
+                except Exception as err:
+                    raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
+
+            user = Auths.authenticate_user_by_trusted_header(mail)
+
+            if user:
+                token = create_token(
+                    data={"id": user.id},
+                    expires_delta=parse_duration(
+                        request.app.state.config.JWT_EXPIRES_IN
+                    ),
+                )
+
+                # Set the cookie token
+                response.set_cookie(
+                    key="token",
+                    value=token,
+                    httponly=True,  # Ensures the cookie is not accessible via JavaScript
+                )
+
+                return {
+                    "token": token,
+                    "token_type": "Bearer",
+                    "id": user.id,
+                    "email": user.email,
+                    "name": user.name,
+                    "role": user.role,
+                    "profile_image_url": user.profile_image_url,
+                }
+            else:
+                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+        else:
+            raise HTTPException(
+                400,
+                f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
+            )
+    except Exception as e:
+        raise HTTPException(400, detail=str(e))
+
+
 ############################
 ############################
 # SignIn
 # SignIn
 ############################
 ############################
@@ -211,6 +375,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
             secure=WEBUI_SESSION_COOKIE_SECURE,
             secure=WEBUI_SESSION_COOKIE_SECURE,
         )
         )
 
 
+        user_permissions = get_permissions(
+            user.id, request.app.state.config.USER_PERMISSIONS
+        )
+
         return {
         return {
             "token": token,
             "token": token,
             "token_type": "Bearer",
             "token_type": "Bearer",
@@ -220,6 +388,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
             "name": user.name,
             "name": user.name,
             "role": user.role,
             "role": user.role,
             "profile_image_url": user.profile_image_url,
             "profile_image_url": user.profile_image_url,
+            "permissions": user_permissions,
         }
         }
     else:
     else:
         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
@@ -260,6 +429,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
             if Users.get_num_users() == 0
             if Users.get_num_users() == 0
             else request.app.state.config.DEFAULT_USER_ROLE
             else request.app.state.config.DEFAULT_USER_ROLE
         )
         )
+
+        if Users.get_num_users() == 0:
+            # Disable signup after the first user is created
+            request.app.state.config.ENABLE_SIGNUP = False
+
         hashed = get_password_hash(form_data.password)
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
         user = Auths.insert_new_auth(
             form_data.email.lower(),
             form_data.email.lower(),
@@ -307,6 +481,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
                     },
                     },
                 )
                 )
 
 
+            user_permissions = get_permissions(
+                user.id, request.app.state.config.USER_PERMISSIONS
+            )
+
             return {
             return {
                 "token": token,
                 "token": token,
                 "token_type": "Bearer",
                 "token_type": "Bearer",
@@ -316,6 +494,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
                 "name": user.name,
                 "name": user.name,
                 "role": user.role,
                 "role": user.role,
                 "profile_image_url": user.profile_image_url,
                 "profile_image_url": user.profile_image_url,
+                "permissions": user_permissions,
             }
             }
         else:
         else:
             raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
             raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
@@ -324,8 +503,31 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
 
 
 
 
 @router.get("/signout")
 @router.get("/signout")
-async def signout(response: Response):
+async def signout(request: Request, response: Response):
     response.delete_cookie("token")
     response.delete_cookie("token")
+
+    if ENABLE_OAUTH_SIGNUP.value:
+        oauth_id_token = request.cookies.get("oauth_id_token")
+        if oauth_id_token:
+            try:
+                async with ClientSession() as session:
+                    async with session.get(OPENID_PROVIDER_URL.value) as resp:
+                        if resp.status == 200:
+                            openid_data = await resp.json()
+                            logout_url = openid_data.get("end_session_endpoint")
+                            if logout_url:
+                                response.delete_cookie("oauth_id_token")
+                                return RedirectResponse(
+                                    url=f"{logout_url}?id_token_hint={oauth_id_token}"
+                                )
+                        else:
+                            raise HTTPException(
+                                status_code=resp.status,
+                                detail="Failed to fetch OpenID configuration",
+                            )
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
     return {"status": True}
     return {"status": True}
 
 
 
 
@@ -413,6 +615,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
     return {
     return {
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
+        "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@@ -423,6 +626,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
 class AdminConfig(BaseModel):
 class AdminConfig(BaseModel):
     SHOW_ADMIN_DETAILS: bool
     SHOW_ADMIN_DETAILS: bool
     ENABLE_SIGNUP: bool
     ENABLE_SIGNUP: bool
+    ENABLE_API_KEY: bool
     DEFAULT_USER_ROLE: str
     DEFAULT_USER_ROLE: str
     JWT_EXPIRES_IN: str
     JWT_EXPIRES_IN: str
     ENABLE_COMMUNITY_SHARING: bool
     ENABLE_COMMUNITY_SHARING: bool
@@ -435,6 +639,7 @@ async def update_admin_config(
 ):
 ):
     request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
     request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
     request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
     request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
+    request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
 
 
     if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
     if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
         request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
         request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@@ -453,6 +658,7 @@ async def update_admin_config(
     return {
     return {
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
+        "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@@ -460,6 +666,105 @@ async def update_admin_config(
     }
     }
 
 
 
 
+class LdapServerConfig(BaseModel):
+    label: str
+    host: str
+    port: Optional[int] = None
+    attribute_for_username: str = "uid"
+    app_dn: str
+    app_dn_password: str
+    search_base: str
+    search_filters: str = ""
+    use_tls: bool = True
+    certificate_path: Optional[str] = None
+    ciphers: Optional[str] = "ALL"
+
+
+@router.get("/admin/config/ldap/server", response_model=LdapServerConfig)
+async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
+    return {
+        "label": request.app.state.config.LDAP_SERVER_LABEL,
+        "host": request.app.state.config.LDAP_SERVER_HOST,
+        "port": request.app.state.config.LDAP_SERVER_PORT,
+        "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME,
+        "app_dn": request.app.state.config.LDAP_APP_DN,
+        "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD,
+        "search_base": request.app.state.config.LDAP_SEARCH_BASE,
+        "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
+        "use_tls": request.app.state.config.LDAP_USE_TLS,
+        "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
+        "ciphers": request.app.state.config.LDAP_CIPHERS,
+    }
+
+
+@router.post("/admin/config/ldap/server")
+async def update_ldap_server(
+    request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user)
+):
+    required_fields = [
+        "label",
+        "host",
+        "attribute_for_username",
+        "app_dn",
+        "app_dn_password",
+        "search_base",
+    ]
+    for key in required_fields:
+        value = getattr(form_data, key)
+        if not value:
+            raise HTTPException(400, detail=f"Required field {key} is empty")
+
+    if form_data.use_tls and not form_data.certificate_path:
+        raise HTTPException(
+            400, detail="TLS is enabled but certificate file path is missing"
+        )
+
+    request.app.state.config.LDAP_SERVER_LABEL = form_data.label
+    request.app.state.config.LDAP_SERVER_HOST = form_data.host
+    request.app.state.config.LDAP_SERVER_PORT = form_data.port
+    request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = (
+        form_data.attribute_for_username
+    )
+    request.app.state.config.LDAP_APP_DN = form_data.app_dn
+    request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password
+    request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base
+    request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters
+    request.app.state.config.LDAP_USE_TLS = form_data.use_tls
+    request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path
+    request.app.state.config.LDAP_CIPHERS = form_data.ciphers
+
+    return {
+        "label": request.app.state.config.LDAP_SERVER_LABEL,
+        "host": request.app.state.config.LDAP_SERVER_HOST,
+        "port": request.app.state.config.LDAP_SERVER_PORT,
+        "attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME,
+        "app_dn": request.app.state.config.LDAP_APP_DN,
+        "app_dn_password": request.app.state.config.LDAP_APP_PASSWORD,
+        "search_base": request.app.state.config.LDAP_SEARCH_BASE,
+        "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
+        "use_tls": request.app.state.config.LDAP_USE_TLS,
+        "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
+        "ciphers": request.app.state.config.LDAP_CIPHERS,
+    }
+
+
+@router.get("/admin/config/ldap")
+async def get_ldap_config(request: Request, user=Depends(get_admin_user)):
+    return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
+
+
+class LdapConfigForm(BaseModel):
+    enable_ldap: Optional[bool] = None
+
+
+@router.post("/admin/config/ldap")
+async def update_ldap_config(
+    request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)
+):
+    request.app.state.config.ENABLE_LDAP = form_data.enable_ldap
+    return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
+
+
 ############################
 ############################
 # API Key
 # API Key
 ############################
 ############################
@@ -467,9 +772,16 @@ async def update_admin_config(
 
 
 # create api key
 # create api key
 @router.post("/api_key", response_model=ApiKey)
 @router.post("/api_key", response_model=ApiKey)
-async def create_api_key_(user=Depends(get_current_user)):
+async def generate_api_key(request: Request, user=Depends(get_current_user)):
+    if not request.app.state.config.ENABLE_API_KEY:
+        raise HTTPException(
+            status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED,
+        )
+
     api_key = create_api_key()
     api_key = create_api_key()
     success = Users.update_user_api_key_by_id(user.id, api_key)
     success = Users.update_user_api_key_by_id(user.id, api_key)
+
     if success:
     if success:
         return {
         return {
             "api_key": api_key,
             "api_key": api_key,

+ 13 - 10
backend/open_webui/apps/webui/routers/chats.py → backend/open_webui/routers/chats.py

@@ -2,22 +2,25 @@ import json
 import logging
 import logging
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.apps.webui.models.chats import (
+from open_webui.models.chats import (
     ChatForm,
     ChatForm,
     ChatImportForm,
     ChatImportForm,
     ChatResponse,
     ChatResponse,
     Chats,
     Chats,
     ChatTitleIdResponse,
     ChatTitleIdResponse,
 )
 )
-from open_webui.apps.webui.models.tags import TagModel, Tags
-from open_webui.apps.webui.models.folders import Folders
+from open_webui.models.tags import TagModel, Tags
+from open_webui.models.folders import Folders
 
 
 from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
 from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from pydantic import BaseModel
 from pydantic import BaseModel
-from open_webui.utils.utils import get_admin_user, get_verified_user
+
+
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.access_control import has_permission
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -50,9 +53,10 @@ async def get_session_user_chat_list(
 
 
 @router.delete("/", response_model=bool)
 @router.delete("/", response_model=bool)
 async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
 async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
-    if user.role == "user" and not request.app.state.config.USER_PERMISSIONS.get(
-        "chat", {}
-    ).get("deletion", {}):
+
+    if user.role == "user" and not has_permission(
+        user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
+    ):
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
@@ -385,8 +389,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
 
 
         return result
         return result
     else:
     else:
-        if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get(
-            "deletion", {}
+        if not has_permission(
+            user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
         ):
         ):
             raise HTTPException(
             raise HTTPException(
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 status_code=status.HTTP_401_UNAUTHORIZED,
@@ -603,7 +607,6 @@ async def add_tag_by_id_and_tag_name(
                 detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
                 detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
             )
             )
 
 
-        print(tags, tag_id)
         if tag_id not in tags:
         if tag_id not in tags:
             Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
             Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
                 id, user.id, form_data.name
                 id, user.id, form_data.name

+ 32 - 19
backend/open_webui/apps/webui/routers/configs.py → backend/open_webui/routers/configs.py

@@ -1,10 +1,12 @@
-from open_webui.config import BannerModel
 from fastapi import APIRouter, Depends, Request
 from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel
 from pydantic import BaseModel
-from open_webui.utils.utils import get_admin_user, get_verified_user
 
 
+from typing import Optional
 
 
+from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.config import get_config, save_config
 from open_webui.config import get_config, save_config
+from open_webui.config import BannerModel
+
 
 
 router = APIRouter()
 router = APIRouter()
 
 
@@ -34,8 +36,32 @@ async def export_config(user=Depends(get_admin_user)):
     return get_config()
     return get_config()
 
 
 
 
-class SetDefaultModelsForm(BaseModel):
-    models: str
+############################
+# SetDefaultModels
+############################
+class ModelsConfigForm(BaseModel):
+    DEFAULT_MODELS: Optional[str]
+    MODEL_ORDER_LIST: Optional[list[str]]
+
+
+@router.get("/models", response_model=ModelsConfigForm)
+async def get_models_config(request: Request, user=Depends(get_admin_user)):
+    return {
+        "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
+        "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
+    }
+
+
+@router.post("/models", response_model=ModelsConfigForm)
+async def set_models_config(
+    request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)
+):
+    request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS
+    request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST
+    return {
+        "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
+        "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
+    }
 
 
 
 
 class PromptSuggestion(BaseModel):
 class PromptSuggestion(BaseModel):
@@ -47,21 +73,8 @@ class SetDefaultSuggestionsForm(BaseModel):
     suggestions: list[PromptSuggestion]
     suggestions: list[PromptSuggestion]
 
 
 
 
-############################
-# SetDefaultModels
-############################
-
-
-@router.post("/default/models", response_model=str)
-async def set_global_default_models(
-    request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
-):
-    request.app.state.config.DEFAULT_MODELS = form_data.models
-    return request.app.state.config.DEFAULT_MODELS
-
-
-@router.post("/default/suggestions", response_model=list[PromptSuggestion])
-async def set_global_default_suggestions(
+@router.post("/suggestions", response_model=list[PromptSuggestion])
+async def set_default_suggestions(
     request: Request,
     request: Request,
     form_data: SetDefaultSuggestionsForm,
     form_data: SetDefaultSuggestionsForm,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),

+ 3 - 3
backend/open_webui/apps/webui/routers/evaluations.py → backend/open_webui/routers/evaluations.py

@@ -2,8 +2,8 @@ from typing import Optional
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 from pydantic import BaseModel
 from pydantic import BaseModel
 
 
-from open_webui.apps.webui.models.users import Users, UserModel
-from open_webui.apps.webui.models.feedbacks import (
+from open_webui.models.users import Users, UserModel
+from open_webui.models.feedbacks import (
     FeedbackModel,
     FeedbackModel,
     FeedbackResponse,
     FeedbackResponse,
     FeedbackForm,
     FeedbackForm,
@@ -11,7 +11,7 @@ from open_webui.apps.webui.models.feedbacks import (
 )
 )
 
 
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 
 
 router = APIRouter()
 router = APIRouter()
 
 

+ 29 - 16
backend/open_webui/apps/webui/routers/files.py → backend/open_webui/routers/files.py

@@ -5,27 +5,28 @@ from pathlib import Path
 from typing import Optional
 from typing import Optional
 from pydantic import BaseModel
 from pydantic import BaseModel
 import mimetypes
 import mimetypes
+from urllib.parse import quote
 
 
 from open_webui.storage.provider import Storage
 from open_webui.storage.provider import Storage
 
 
-from open_webui.apps.webui.models.files import (
+from open_webui.models.files import (
     FileForm,
     FileForm,
     FileModel,
     FileModel,
     FileModelResponse,
     FileModelResponse,
     Files,
     Files,
 )
 )
-from open_webui.apps.retrieval.main import process_file, ProcessFileForm
+from open_webui.routers.retrieval import process_file, ProcessFileForm
 
 
 from open_webui.config import UPLOAD_DIR
 from open_webui.config import UPLOAD_DIR
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
 
 
 
 
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
+from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
 from fastapi.responses import FileResponse, StreamingResponse
 from fastapi.responses import FileResponse, StreamingResponse
 
 
 
 
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -39,7 +40,9 @@ router = APIRouter()
 
 
 
 
 @router.post("/", response_model=FileModelResponse)
 @router.post("/", response_model=FileModelResponse)
-def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
+def upload_file(
+    request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
+):
     log.info(f"file.content_type: {file.content_type}")
     log.info(f"file.content_type: {file.content_type}")
     try:
     try:
         unsanitized_filename = file.filename
         unsanitized_filename = file.filename
@@ -56,7 +59,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
             FileForm(
             FileForm(
                 **{
                 **{
                     "id": id,
                     "id": id,
-                    "filename": filename,
+                    "filename": name,
                     "path": file_path,
                     "path": file_path,
                     "meta": {
                     "meta": {
                         "name": name,
                         "name": name,
@@ -68,7 +71,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
         )
         )
 
 
         try:
         try:
-            process_file(ProcessFileForm(file_id=id))
+            process_file(request, ProcessFileForm(file_id=id))
             file_item = Files.get_file_by_id(id=id)
             file_item = Files.get_file_by_id(id=id)
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
@@ -183,13 +186,15 @@ class ContentForm(BaseModel):
 
 
 @router.post("/{id}/data/content/update")
 @router.post("/{id}/data/content/update")
 async def update_file_data_content_by_id(
 async def update_file_data_content_by_id(
-    id: str, form_data: ContentForm, user=Depends(get_verified_user)
+    request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user)
 ):
 ):
     file = Files.get_file_by_id(id)
     file = Files.get_file_by_id(id)
 
 
     if file and (file.user_id == user.id or user.role == "admin"):
     if file and (file.user_id == user.id or user.role == "admin"):
         try:
         try:
-            process_file(ProcessFileForm(file_id=id, content=form_data.content))
+            process_file(
+                request, ProcessFileForm(file_id=id, content=form_data.content)
+            )
             file = Files.get_file_by_id(id=id)
             file = Files.get_file_by_id(id=id)
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
@@ -218,11 +223,15 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 
 
             # Check if the file already exists in the cache
             # Check if the file already exists in the cache
             if file_path.is_file():
             if file_path.is_file():
-                print(f"file_path: {file_path}")
+                # Handle Unicode filenames
+                filename = file.meta.get("name", file.filename)
+                encoded_filename = quote(filename)  # RFC5987 encoding
                 headers = {
                 headers = {
-                    "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
+                    "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
                 }
                 }
+
                 return FileResponse(file_path, headers=headers)
                 return FileResponse(file_path, headers=headers)
+
             else:
             else:
                 raise HTTPException(
                 raise HTTPException(
                     status_code=status.HTTP_404_NOT_FOUND,
                     status_code=status.HTTP_404_NOT_FOUND,
@@ -279,16 +288,20 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 
 
     if file and (file.user_id == user.id or user.role == "admin"):
     if file and (file.user_id == user.id or user.role == "admin"):
         file_path = file.path
         file_path = file.path
+
+        # Handle Unicode filenames
+        filename = file.meta.get("name", file.filename)
+        encoded_filename = quote(filename)  # RFC5987 encoding
+        headers = {
+            "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
+        }
+
         if file_path:
         if file_path:
             file_path = Storage.get_file(file_path)
             file_path = Storage.get_file(file_path)
             file_path = Path(file_path)
             file_path = Path(file_path)
 
 
             # Check if the file already exists in the cache
             # Check if the file already exists in the cache
             if file_path.is_file():
             if file_path.is_file():
-                print(f"file_path: {file_path}")
-                headers = {
-                    "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
-                }
                 return FileResponse(file_path, headers=headers)
                 return FileResponse(file_path, headers=headers)
             else:
             else:
                 raise HTTPException(
                 raise HTTPException(
@@ -307,7 +320,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
             return StreamingResponse(
             return StreamingResponse(
                 generator(),
                 generator(),
                 media_type="text/plain",
                 media_type="text/plain",
-                headers={"Content-Disposition": f"attachment; filename={file_name}"},
+                headers=headers,
             )
             )
     else:
     else:
         raise HTTPException(
         raise HTTPException(

+ 3 - 3
backend/open_webui/apps/webui/routers/folders.py → backend/open_webui/routers/folders.py

@@ -8,12 +8,12 @@ from pydantic import BaseModel
 import mimetypes
 import mimetypes
 
 
 
 
-from open_webui.apps.webui.models.folders import (
+from open_webui.models.folders import (
     FolderForm,
     FolderForm,
     FolderModel,
     FolderModel,
     Folders,
     Folders,
 )
 )
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.models.chats import Chats
 
 
 from open_webui.config import UPLOAD_DIR
 from open_webui.config import UPLOAD_DIR
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
@@ -24,7 +24,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
 from fastapi.responses import FileResponse, StreamingResponse
 from fastapi.responses import FileResponse, StreamingResponse
 
 
 
 
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

部分文件因为文件数量过多而无法显示