Browse Source

Merge pull request #1107 from open-webui/dev

0.1.111
Timothy Jaeryang Baek, 1 year ago
Parent commit: 89634046e7
37 changed files with 1554 additions and 765 deletions
  1. +5 -5      .github/workflows/build-release.yml
  2. +18 -0     CHANGELOG.md
  3. +2 -2      Dockerfile
  4. +10 -2     README.md
  5. +190 -53   backend/apps/images/main.py
  6. +41 -0     backend/apps/litellm/main.py
  7. +17 -3     backend/apps/ollama/main.py
  8. +22 -4     backend/apps/openai/main.py
  9. +38 -88    backend/apps/rag/main.py
  10. +97 -0    backend/apps/rag/utils.py
  11. +6 -1     backend/config.py
  12. +1 -0     backend/data/config.json
  13. +175 -38  backend/main.py
  14. +2 -1     backend/requirements.txt
  15. +1 -1     package.json
  16. +83 -8    src/lib/apis/images/index.ts
  17. +62 -0    src/lib/apis/index.ts
  18. +3 -1     src/lib/apis/litellm/index.ts
  19. +16 -7    src/lib/apis/rag/index.ts
  20. +113 -0   src/lib/components/admin/Settings/Users.svelte
  21. +3 -3     src/lib/components/chat/MessageInput.svelte
  22. +14 -11   src/lib/components/chat/Messages/ResponseMessage.svelte
  23. +3 -3     src/lib/components/chat/Messages/UserMessage.svelte
  24. +1 -1     src/lib/components/chat/Settings/Account.svelte
  25. +1 -1     src/lib/components/chat/Settings/Audio.svelte
  26. +1 -1     src/lib/components/chat/Settings/Connections.svelte
  27. +2 -2     src/lib/components/chat/Settings/General.svelte
  28. +135 -76  src/lib/components/chat/Settings/Images.svelte
  29. +5 -27    src/lib/components/chat/Settings/Interface.svelte
  30. +277 -221 src/lib/components/chat/Settings/Models.svelte
  31. +6 -6     src/lib/components/chat/SettingsModal.svelte
  32. +7 -2     src/lib/components/common/Image.svelte
  33. +99 -72   src/lib/components/documents/Settings/General.svelte
  34. +12 -7    src/lib/components/layout/Sidebar.svelte
  35. +45 -59   src/routes/(app)/+page.svelte
  36. +40 -58   src/routes/(app)/c/[id]/+page.svelte
  37. +1 -1     src/routes/(app)/playground/+page.svelte

+ 5 - 5
.github/workflows/build-release.yml

@@ -29,11 +29,11 @@ jobs:
     - name: Extract latest CHANGELOG entry
       id: changelog
       run: |
-        CHANGELOG_CONTENT=$(awk '/^## \[/{n++} n==1' CHANGELOG.md)
-        echo "CHANGELOG_CONTENT<<EOF"
-        echo "$CHANGELOG_CONTENT"
-        echo "EOF"
-        echo "::set-output name=content::${CHANGELOG_CONTENT}"
+        CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md)
+        CHANGELOG_ESCAPED=$(echo "$CHANGELOG_CONTENT" | sed ':a;N;$!ba;s/\n/%0A/g')
+        echo "Extracted latest release notes from CHANGELOG.md:"
+        echo -e "$CHANGELOG_CONTENT"
+        echo "::set-output name=content::$CHANGELOG_ESCAPED"

     - name: Create GitHub release
       uses: actions/github-script@v5
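
For readers who don't parse awk fluently: the reworked one-liner prints everything from the first "## [" release heading in CHANGELOG.md up to (but not including) the next one, and the sed pass escapes newlines as %0A because ::set-output values must fit on a single line. A minimal Python sketch of the same extraction:

    # Collect the first "## [...]" section of CHANGELOG.md, mirroring the
    # awk program in the workflow step above.
    def latest_changelog_entry(path="CHANGELOG.md"):
        lines, in_section = [], False
        for line in open(path, encoding="utf-8"):
            if line.startswith("## ["):
                if in_section:
                    break  # reached the next release heading: stop
                in_section = True  # first release heading: start collecting
            if in_section:
                lines.append(line)
        return "".join(lines)

    print(latest_changelog_entry())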

+ 18 - 0
CHANGELOG.md

@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## [0.1.111] - 2024-03-10
+
+### Added
+
+- 🛡️ **Model Whitelisting**: Admins now have the ability to whitelist models for users with the 'user' role.
+- 🔄 **Update All Models**: Added a convenient button to update all models at once.
+- 📄 **Toggle PDF OCR**: Users can now toggle PDF OCR option for improved parsing performance.
+- 🎨 **DALL-E Integration**: Introduced DALL-E integration for image generation alongside automatic1111.
+- 🛠️ **RAG API Refactoring**: Refactored RAG logic and exposed its API, with additional documentation to follow.
+
+### Fixed
+
+- 🔒 **Max Token Settings**: Added max token settings for anthropic/claude-3-sonnet-20240229 (Issue #1094).
+- 🔧 **Misalignment Issue**: Corrected misalignment of Edit and Delete Icons when Chat Title is Empty (Issue #1104).
+- 🔄 **Context Loss Fix**: Resolved RAG losing context on model response regeneration with Groq models via API key (Issue #1105).
+- 📁 **File Handling Bug**: Addressed File Not Found Notification when Dropping a Conversation Element (Issue #1098).
+- 🖱️ **Dragged File Styling**: Fixed dragged file layover styling issue.
+
 ## [0.1.110] - 2024-03-06

 ### Added

+ 2 - 2
Dockerfile

@@ -41,7 +41,7 @@ ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
 # for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
 # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
 ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
-# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
+# device type for whisper tts and embbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
 ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
 ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
 ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
@@ -81,4 +81,4 @@ COPY --from=build /app/package.json /app/package.json
 # copy backend files
 COPY ./backend .

-CMD [ "bash", "start.sh"]
+CMD [ "bash", "start.sh"]

+ 10 - 2
README.md

@@ -53,8 +53,6 @@ User-friendly WebUI for LLMs, supported LLM runners include Ollama and OpenAI-co

 - 💬 **Collaborative Chat**: Harness the collective intelligence of multiple models by seamlessly orchestrating group conversations. Use the `@` command to specify the model, enabling dynamic and diverse dialogues within your chat interface. Immerse yourself in the collective intelligence woven into your chat environment.

-- 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
-
 - 🔄 **Regeneration History Access**: Easily revisit and explore your entire regeneration history.

 - 📜 **Chat History**: Effortlessly access and manage your conversation history.
@@ -65,8 +63,18 @@ User-friendly WebUI for LLMs, supported LLM runners include Ollama and OpenAI-co

 - ⚙️ **Fine-Tuned Control with Advanced Parameters**: Gain a deeper level of control by adjusting parameters such as temperature and defining your system prompts to tailor the conversation to your specific preferences and needs.

+- 🎨🤖 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using AUTOMATIC1111 API (local) and DALL-E, enriching your chat experience with dynamic visual content.
+
+- 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
+
+- ✨ **Multiple OpenAI-Compatible API Support**: Seamlessly integrate and customize various OpenAI-compatible APIs, enhancing the versatility of your chat interactions.
+
 - 🔗 **External Ollama Server Connection**: Seamlessly link to an external Ollama server hosted on a different address by configuring the environment variable.

+- 🔀 **Multiple Ollama Instance Load Balancing**: Effortlessly distribute chat requests across multiple Ollama instances for enhanced performance and reliability.
+
+- 👥 **Multi-User Management**: Easily oversee and administer users via our intuitive admin panel, streamlining user management processes.
+
 - 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.

 - 🔒 **Backend Reverse Proxy Support**: Bolster security through direct communication between Open WebUI backend and Ollama. This key feature eliminates the need to expose Ollama over LAN. Requests made to the '/ollama/api' route from the web UI are seamlessly redirected to Ollama from the backend, enhancing overall system security.

+ 190 - 53
backend/apps/images/main.py

@@ -21,7 +21,16 @@ from utils.utils import (
 from utils.misc import calculate_sha256
 from typing import Optional
 from pydantic import BaseModel
-from config import AUTOMATIC1111_BASE_URL
+from pathlib import Path
+import uuid
+import base64
+import json
+
+from config import CACHE_DIR, AUTOMATIC1111_BASE_URL
+
+
+IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
+IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

 app = FastAPI()
 app.add_middleware(
@@ -32,25 +41,34 @@ app.add_middleware(
     allow_headers=["*"],
 )

+app.state.ENGINE = ""
+app.state.ENABLED = False
+
+app.state.OPENAI_API_KEY = ""
+app.state.MODEL = ""
+
+
 app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
-app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
+
 app.state.IMAGE_SIZE = "512x512"
 app.state.IMAGE_STEPS = 50


-@app.get("/enabled", response_model=bool)
-async def get_enable_status(request: Request, user=Depends(get_admin_user)):
-    return app.state.ENABLED
+@app.get("/config")
+async def get_config(request: Request, user=Depends(get_admin_user)):
+    return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}


-@app.get("/enabled/toggle", response_model=bool)
-async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
-    try:
-        r = requests.head(app.state.AUTOMATIC1111_BASE_URL)
-        app.state.ENABLED = not app.state.ENABLED
-        return app.state.ENABLED
-    except Exception as e:
-        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
+class ConfigUpdateForm(BaseModel):
+    engine: str
+    enabled: bool
+
+
+@app.post("/config/update")
+async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
+    app.state.ENGINE = form_data.engine
+    app.state.ENABLED = form_data.enabled
+    return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}


 class UrlUpdateForm(BaseModel):
@@ -58,17 +76,24 @@ class UrlUpdateForm(BaseModel):


 @app.get("/url")
-async def get_openai_url(user=Depends(get_admin_user)):
+async def get_automatic1111_url(user=Depends(get_admin_user)):
     return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}


 @app.post("/url/update")
-async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
+async def update_automatic1111_url(
+    form_data: UrlUpdateForm, user=Depends(get_admin_user)
+):

     if form_data.url == "":
         app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
     else:
-        app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
+        url = form_data.url.strip("/")
+        try:
+            r = requests.head(url)
+            app.state.AUTOMATIC1111_BASE_URL = url
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

     return {
         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
@@ -76,6 +101,30 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_use
     }


+class OpenAIKeyUpdateForm(BaseModel):
+    key: str
+
+
+@app.get("/key")
+async def get_openai_key(user=Depends(get_admin_user)):
+    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
+
+
+@app.post("/key/update")
+async def update_openai_key(
+    form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
+):
+
+    if form_data.key == "":
+        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+
+    app.state.OPENAI_API_KEY = form_data.key
+    return {
+        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
+        "status": True,
+    }
+
+
 class ImageSizeUpdateForm(BaseModel):
     size: str

@@ -132,9 +181,22 @@ async def update_image_size(
 @app.get("/models")
 def get_models(user=Depends(get_current_user)):
     try:
-        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
-        models = r.json()
-        return models
+        if app.state.ENGINE == "openai":
+            return [
+                {"id": "dall-e-2", "name": "DALL·E 2"},
+                {"id": "dall-e-3", "name": "DALL·E 3"},
+            ]
+        else:
+            r = requests.get(
+                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
+            )
+            models = r.json()
+            return list(
+                map(
+                    lambda model: {"id": model["title"], "name": model["model_name"]},
+                    models,
+                )
+            )
     except Exception as e:
         app.state.ENABLED = False
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@@ -143,10 +205,12 @@ def get_models(user=Depends(get_current_user)):
 @app.get("/models/default")
 async def get_default_model(user=Depends(get_admin_user)):
     try:
-        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
-        options = r.json()
-
-        return {"model": options["sd_model_checkpoint"]}
+        if app.state.ENGINE == "openai":
+            return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
+        else:
+            r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
+            options = r.json()
+            return {"model": options["sd_model_checkpoint"]}
     except Exception as e:
         app.state.ENABLED = False
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@@ -157,16 +221,21 @@ class UpdateModelForm(BaseModel):


 def set_model_handler(model: str):
-    r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
-    options = r.json()

-    if model != options["sd_model_checkpoint"]:
-        options["sd_model_checkpoint"] = model
-        r = requests.post(
-            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
-        )
+    if app.state.ENGINE == "openai":
+        app.state.MODEL = model
+        return app.state.MODEL
+    else:
+        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
+        options = r.json()
+
+        if model != options["sd_model_checkpoint"]:
+            options["sd_model_checkpoint"] = model
+            r = requests.post(
+                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
+            )

-    return options
+        return options


 @app.post("/models/default/update")
@@ -181,45 +250,113 @@ class GenerateImageForm(BaseModel):
     model: Optional[str] = None
     prompt: str
     n: int = 1
-    size: str = "512x512"
+    size: Optional[str] = None
     negative_prompt: Optional[str] = None


+def save_b64_image(b64_str):
+    image_id = str(uuid.uuid4())
+    file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
+
+    try:
+        # Split the base64 string to get the actual image data
+        img_data = base64.b64decode(b64_str)
+
+        # Write the image data to a file
+        with open(file_path, "wb") as f:
+            f.write(img_data)
+
+        return image_id
+    except Exception as e:
+        print(f"Error saving image: {e}")
+        return None
+
+
 @app.post("/generations")
 def generate_image(
     form_data: GenerateImageForm,
     user=Depends(get_current_user),
 ):

-    print(form_data)
-
+    r = None
     try:
-        if form_data.model:
-            set_model_handler(form_data.model)
+        if app.state.ENGINE == "openai":

-        width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+            headers = {}
+            headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
+            headers["Content-Type"] = "application/json"

-        data = {
-            "prompt": form_data.prompt,
-            "batch_size": form_data.n,
-            "width": width,
-            "height": height,
-        }
+            data = {
+                "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
+                "prompt": form_data.prompt,
+                "n": form_data.n,
+                "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
+                "response_format": "b64_json",
+            }
+            r = requests.post(
+                url=f"https://api.openai.com/v1/images/generations",
+                json=data,
+                headers=headers,
+            )

-        if app.state.IMAGE_STEPS != None:
-            data["steps"] = app.state.IMAGE_STEPS
+            r.raise_for_status()

-        if form_data.negative_prompt != None:
-            data["negative_prompt"] = form_data.negative_prompt
+            res = r.json()

-        print(data)
+            images = []

-        r = requests.post(
-            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
-            json=data,
-        )
+            for image in res["data"]:
+                image_id = save_b64_image(image["b64_json"])
+                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
+
+                with open(file_body_path, "w") as f:
+                    json.dump(data, f)
+
+            return images
+
+        else:
+            if form_data.model:
+                set_model_handler(form_data.model)
+
+            width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+
+            data = {
+                "prompt": form_data.prompt,
+                "batch_size": form_data.n,
+                "width": width,
+                "height": height,
+            }
+
+            if app.state.IMAGE_STEPS != None:
+                data["steps"] = app.state.IMAGE_STEPS
+
+            if form_data.negative_prompt != None:
+                data["negative_prompt"] = form_data.negative_prompt
+
+            r = requests.post(
+                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
+                json=data,
+            )
+
+            res = r.json()
+
+            print(res)
+
+            images = []
+
+            for image in res["images"]:
+                image_id = save_b64_image(image)
+                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
+
+                with open(file_body_path, "w") as f:
+                    json.dump({**data, "info": res["info"]}, f)
+
+            return images

-        return r.json()
     except Exception as e:
         print(e)
+        if r:
+            print(r.json())
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
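
Taken together, the new endpoints let an admin pick an engine, store a key, and generate images over plain HTTP. A hedged sketch of the flow (the /images/api/v1 prefix, the localhost:8080 address, and the tokens are assumptions for a local deployment):

    import requests

    BASE = "http://localhost:8080/images/api/v1"  # assumed mount path and host
    HEADERS = {"Authorization": "Bearer <admin-jwt>"}  # placeholder admin token

    # Select the OpenAI engine and enable image generation.
    requests.post(f"{BASE}/config/update",
                  json={"engine": "openai", "enabled": True}, headers=HEADERS)
    requests.post(f"{BASE}/key/update",
                  json={"key": "<openai-api-key>"}, headers=HEADERS)  # placeholder

    # Generate one image; the response points at PNGs cached under /cache/.
    r = requests.post(f"{BASE}/generations",
                      json={"prompt": "a lighthouse at dawn", "n": 1},
                      headers=HEADERS)
    print(r.json())  # e.g. [{"url": "/cache/image/generations/<uuid>.png"}]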

+ 41 - 0
backend/apps/litellm/main.py

@@ -0,0 +1,41 @@
+from litellm.proxy.proxy_server import ProxyConfig, initialize
+from litellm.proxy.proxy_server import app
+
+from fastapi import FastAPI, Request, Depends, status
+from fastapi.responses import JSONResponse
+from utils.utils import get_http_authorization_cred, get_current_user
+from config import ENV
+
+proxy_config = ProxyConfig()
+
+
+async def config():
+    router, model_list, general_settings = await proxy_config.load_config(
+        router=None, config_file_path="./data/litellm/config.yaml"
+    )
+
+    await initialize(config="./data/litellm/config.yaml", telemetry=False)
+
+
+async def startup():
+    await config()
+
+
+@app.on_event("startup")
+async def on_startup():
+    await startup()
+
+
+@app.middleware("http")
+async def auth_middleware(request: Request, call_next):
+    auth_header = request.headers.get("Authorization", "")
+
+    if ENV != "dev":
+        try:
+            user = get_current_user(get_http_authorization_cred(auth_header))
+            print(user)
+        except Exception as e:
+            return JSONResponse(status_code=400, content={"detail": str(e)})
+
+    response = await call_next(request)
+    return response
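
This new module takes over the LiteLLM proxy wiring that previously lived in backend/main.py (see that diff below): the proxy app now carries its own startup hook and auth middleware, so the main app only has to mount it. A sketch of calling the mounted proxy, assuming the /litellm/api mount from backend/main.py, a local deployment, and LiteLLM's OpenAI-compatible route naming:

    import requests

    # The middleware above validates the WebUI JWT before the request is
    # handed to LiteLLM; host, path, and token here are assumptions.
    r = requests.get(
        "http://localhost:8080/litellm/api/v1/models",
        headers={"Authorization": "Bearer <webui-jwt>"},
    )
    print(r.json())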

+ 17 - 3
backend/apps/ollama/main.py

@@ -15,7 +15,7 @@ import asyncio
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user, get_admin_user
-from config import OLLAMA_BASE_URLS
+from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST

 from typing import Optional, List, Union

@@ -29,6 +29,10 @@ app.add_middleware(
     allow_headers=["*"],
 )

+
+app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
+app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+
 app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}

@@ -129,9 +133,19 @@ async def get_all_models():
 async def get_ollama_tags(
     url_idx: Optional[int] = None, user=Depends(get_current_user)
 ):
-
     if url_idx == None:
-        return await get_all_models()
+        models = await get_all_models()
+
+        if app.state.MODEL_FILTER_ENABLED:
+            if user.role == "user":
+                models["models"] = list(
+                    filter(
+                        lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
+                        models["models"],
+                    )
+                )
+                return models
+        return models
     else:
         url = app.state.OLLAMA_BASE_URLS[url_idx]
         try:
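
The same whitelist gate is added to the OpenAI app below, matching on each model's id rather than its name. A toy illustration of the semantics, with made-up model names:

    # Non-admin users only see models whose name is on the whitelist.
    MODEL_FILTER_LIST = ["llama2:latest", "mistral:latest"]  # assumed values

    models = {"models": [{"name": "llama2:latest"}, {"name": "codellama:7b"}]}
    models["models"] = [
        m for m in models["models"] if m["name"] in MODEL_FILTER_LIST
    ]
    print(models)  # {'models': [{'name': 'llama2:latest'}]}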

+ 22 - 4
backend/apps/openai/main.py

@@ -18,7 +18,13 @@ from utils.utils import (
     get_verified_user,
     get_admin_user,
 )
-from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR
+from config import (
+    OPENAI_API_BASE_URLS,
+    OPENAI_API_KEYS,
+    CACHE_DIR,
+    MODEL_FILTER_ENABLED,
+    MODEL_FILTER_LIST,
+)
 from typing import List, Optional


@@ -34,6 +40,9 @@ app.add_middleware(
     allow_headers=["*"],
 )

+app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
+app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+
 app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.OPENAI_API_KEYS = OPENAI_API_KEYS

@@ -186,12 +195,21 @@ async def get_all_models():
     return models


-# , user=Depends(get_current_user)
 @app.get("/models")
 @app.get("/models/{url_idx}")
-async def get_models(url_idx: Optional[int] = None):
+async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
     if url_idx == None:
-        return await get_all_models()
+        models = await get_all_models()
+        if app.state.MODEL_FILTER_ENABLED:
+            if user.role == "user":
+                models["data"] = list(
+                    filter(
+                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
+                        models["data"],
+                    )
+                )
+                return models
+        return models
     else:
         url = app.state.OPENAI_API_BASE_URLS[url_idx]
         try:

+ 38 - 88
backend/apps/rag/main.py

@@ -44,6 +44,8 @@ from apps.web.models.documents import (
     DocumentResponse,
 )

+from apps.rag.utils import query_doc, query_collection
+
 from utils.misc import (
     calculate_sha256,
     calculate_sha256_string,
@@ -75,6 +77,7 @@ from constants import ERROR_MESSAGES

 app = FastAPI()

+app.state.PDF_EXTRACT_IMAGES = False
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
@@ -182,12 +185,15 @@ async def update_embedding_model(
     }


-@app.get("/chunk")
-async def get_chunk_params(user=Depends(get_admin_user)):
+@app.get("/config")
+async def get_rag_config(user=Depends(get_admin_user)):
     return {
         "status": True,
-        "chunk_size": app.state.CHUNK_SIZE,
-        "chunk_overlap": app.state.CHUNK_OVERLAP,
+        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
+        "chunk": {
+            "chunk_size": app.state.CHUNK_SIZE,
+            "chunk_overlap": app.state.CHUNK_OVERLAP,
+        },
     }


@@ -196,17 +202,24 @@ class ChunkParamUpdateForm(BaseModel):
     chunk_overlap: int


-@app.post("/chunk/update")
-async def update_chunk_params(
-    form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
-):
-    app.state.CHUNK_SIZE = form_data.chunk_size
-    app.state.CHUNK_OVERLAP = form_data.chunk_overlap
+class ConfigUpdateForm(BaseModel):
+    pdf_extract_images: bool
+    chunk: ChunkParamUpdateForm
+
+
+@app.post("/config/update")
+async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
+    app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
+    app.state.CHUNK_SIZE = form_data.chunk.chunk_size
+    app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

     return {
         "status": True,
-        "chunk_size": app.state.CHUNK_SIZE,
-        "chunk_overlap": app.state.CHUNK_OVERLAP,
+        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
+        "chunk": {
+            "chunk_size": app.state.CHUNK_SIZE,
+            "chunk_overlap": app.state.CHUNK_OVERLAP,
+        },
     }


@@ -248,21 +261,18 @@ class QueryDocForm(BaseModel):


 @app.post("/query/doc")
-def query_doc(
+def query_doc_handler(
     form_data: QueryDocForm,
     user=Depends(get_current_user),
 ):
+
     try:
-        # if you use docker use the model from the environment variable
-        collection = CHROMA_CLIENT.get_collection(
-            name=form_data.collection_name,
+        return query_doc(
+            collection_name=form_data.collection_name,
+            query=form_data.query,
+            k=form_data.k if form_data.k else app.state.TOP_K,
             embedding_function=app.state.sentence_transformer_ef,
         )
-        result = collection.query(
-            query_texts=[form_data.query],
-            n_results=form_data.k if form_data.k else app.state.TOP_K,
-        )
-        return result
     except Exception as e:
         print(e)
         raise HTTPException(
@@ -277,76 +287,16 @@ class QueryCollectionsForm(BaseModel):
     k: Optional[int] = None


-def merge_and_sort_query_results(query_results, k):
-    # Initialize lists to store combined data
-    combined_ids = []
-    combined_distances = []
-    combined_metadatas = []
-    combined_documents = []
-
-    # Combine data from each dictionary
-    for data in query_results:
-        combined_ids.extend(data["ids"][0])
-        combined_distances.extend(data["distances"][0])
-        combined_metadatas.extend(data["metadatas"][0])
-        combined_documents.extend(data["documents"][0])
-
-    # Create a list of tuples (distance, id, metadata, document)
-    combined = list(
-        zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
-    )
-
-    # Sort the list based on distances
-    combined.sort(key=lambda x: x[0])
-
-    # Unzip the sorted list
-    sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
-
-    # Slicing the lists to include only k elements
-    sorted_distances = list(sorted_distances)[:k]
-    sorted_ids = list(sorted_ids)[:k]
-    sorted_metadatas = list(sorted_metadatas)[:k]
-    sorted_documents = list(sorted_documents)[:k]
-
-    # Create the output dictionary
-    merged_query_results = {
-        "ids": [sorted_ids],
-        "distances": [sorted_distances],
-        "metadatas": [sorted_metadatas],
-        "documents": [sorted_documents],
-        "embeddings": None,
-        "uris": None,
-        "data": None,
-    }
-
-    return merged_query_results
-
-
 @app.post("/query/collection")
-def query_collection(
+def query_collection_handler(
     form_data: QueryCollectionsForm,
     user=Depends(get_current_user),
 ):
-    results = []
-
-    for collection_name in form_data.collection_names:
-        try:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.get_collection(
-                name=collection_name,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
-
-            result = collection.query(
-                query_texts=[form_data.query],
-                n_results=form_data.k if form_data.k else app.state.TOP_K,
-            )
-            results.append(result)
-        except:
-            pass
-
-    return merge_and_sort_query_results(
-        results, form_data.k if form_data.k else app.state.TOP_K
+    return query_collection(
+        collection_names=form_data.collection_names,
+        query=form_data.query,
+        k=form_data.k if form_data.k else app.state.TOP_K,
+        embedding_function=app.state.sentence_transformer_ef,
     )


@@ -425,7 +375,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
     ]

     if file_ext == "pdf":
-        loader = PyPDFLoader(file_path, extract_images=True)
+        loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
     elif file_ext == "csv":
         loader = CSVLoader(file_path)
     elif file_ext == "rst":
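
With image extraction now opt-in, PDFs are parsed without OCR unless an admin enables it, which is the "Toggle PDF OCR" item from the changelog. A sketch of driving the new endpoint (the /rag/api/v1 prefix, host, chunk values, and token are assumptions):

    import requests

    BASE = "http://localhost:8080/rag/api/v1"  # assumed mount path and host
    r = requests.post(
        f"{BASE}/config/update",
        json={
            "pdf_extract_images": False,  # skip image extraction for speed
            "chunk": {"chunk_size": 1500, "chunk_overlap": 100},
        },
        headers={"Authorization": "Bearer <admin-jwt>"},  # placeholder token
    )
    print(r.json())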

+ 97 - 0
backend/apps/rag/utils.py

@@ -0,0 +1,97 @@
+import re
+from typing import List
+
+from config import CHROMA_CLIENT
+
+
+def query_doc(collection_name: str, query: str, k: int, embedding_function):
+    try:
+        # if you use docker use the model from the environment variable
+        collection = CHROMA_CLIENT.get_collection(
+            name=collection_name,
+            embedding_function=embedding_function,
+        )
+        result = collection.query(
+            query_texts=[query],
+            n_results=k,
+        )
+        return result
+    except Exception as e:
+        raise e
+
+
+def merge_and_sort_query_results(query_results, k):
+    # Initialize lists to store combined data
+    combined_ids = []
+    combined_distances = []
+    combined_metadatas = []
+    combined_documents = []
+
+    # Combine data from each dictionary
+    for data in query_results:
+        combined_ids.extend(data["ids"][0])
+        combined_distances.extend(data["distances"][0])
+        combined_metadatas.extend(data["metadatas"][0])
+        combined_documents.extend(data["documents"][0])
+
+    # Create a list of tuples (distance, id, metadata, document)
+    combined = list(
+        zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
+    )
+
+    # Sort the list based on distances
+    combined.sort(key=lambda x: x[0])
+
+    # Unzip the sorted list
+    sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
+
+    # Slicing the lists to include only k elements
+    sorted_distances = list(sorted_distances)[:k]
+    sorted_ids = list(sorted_ids)[:k]
+    sorted_metadatas = list(sorted_metadatas)[:k]
+    sorted_documents = list(sorted_documents)[:k]
+
+    # Create the output dictionary
+    merged_query_results = {
+        "ids": [sorted_ids],
+        "distances": [sorted_distances],
+        "metadatas": [sorted_metadatas],
+        "documents": [sorted_documents],
+        "embeddings": None,
+        "uris": None,
+        "data": None,
+    }
+
+    return merged_query_results
+
+
+def query_collection(
+    collection_names: List[str], query: str, k: int, embedding_function
+):
+
+    results = []
+
+    for collection_name in collection_names:
+        try:
+            # if you use docker use the model from the environment variable
+            collection = CHROMA_CLIENT.get_collection(
+                name=collection_name,
+                embedding_function=embedding_function,
+            )
+
+            result = collection.query(
+                query_texts=[query],
+                n_results=k,
+            )
+            results.append(result)
+        except:
+            pass
+
+    return merge_and_sort_query_results(results, k)
+
+
+def rag_template(template: str, context: str, query: str):
+    template = re.sub(r"\[context\]", context, template)
+    template = re.sub(r"\[query\]", query, template)
+
+    return template
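
These helpers are shared by the /query/* handlers above and the new RAGMiddleware in backend/main.py. rag_template simply substitutes the [context] and [query] placeholders; a quick illustration with made-up strings:

    from apps.rag.utils import rag_template

    template = "Use the following context to answer:\n[context]\nQuestion: [query]"
    print(rag_template(template,
                       context="Paris is the capital of France.",
                       query="What is the capital of France?"))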

+ 6 - 1
backend/config.py

@@ -251,7 +251,7 @@ OPENAI_API_BASE_URLS = (
     OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
 )

-OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URL.split(";")]
+OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URLS.split(";")]


 ####################################
@@ -292,6 +292,11 @@ DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
 USER_PERMISSIONS = {"chat": {"deletion": True}}


+MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False)
+MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
+MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
+
+
 ####################################
 # WEBUI_VERSION
 ####################################
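
MODEL_FILTER_LIST is read as a semicolon-separated list. Two caveats follow from the code as written: os.environ.get returns a string, so any non-empty MODEL_FILTER_ENABLED value (even "false") is truthy, and splitting an empty string yields [""] rather than []. A small sketch with assumed values:

    import os

    os.environ["MODEL_FILTER_ENABLED"] = "True"  # assumed deployment value
    os.environ["MODEL_FILTER_LIST"] = "llama2:latest;mistral:latest"

    models = [m.strip() for m in os.environ["MODEL_FILTER_LIST"].split(";")]
    print(models)  # ['llama2:latest', 'mistral:latest']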

+ 1 - 0
backend/data/config.json

@@ -1,4 +1,5 @@
 {
+    "version": "0.0.1",
     "ui": {
         "prompt_suggestions": [
             {

+ 175 - 38
backend/main.py

@@ -9,27 +9,37 @@ import requests
 from fastapi import FastAPI, Request, Depends, status
 from fastapi.staticfiles import StaticFiles
 from fastapi import HTTPException
-from fastapi.responses import JSONResponse
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
+from starlette.middleware.base import BaseHTTPMiddleware


-from litellm.proxy.proxy_server import ProxyConfig, initialize
-from litellm.proxy.proxy_server import app as litellm_app
-
 from apps.ollama.main import app as ollama_app
 from apps.openai.main import app as openai_app
+from apps.litellm.main import app as litellm_app, startup as litellm_app_startup
 from apps.audio.main import app as audio_app
 from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
 from apps.web.main import app as webui_app

+from pydantic import BaseModel
+from typing import List

-from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
-from constants import ERROR_MESSAGES

-from utils.utils import get_http_authorization_cred, get_current_user
+from utils.utils import get_admin_user
+from apps.rag.utils import query_doc, query_collection, rag_template
+
+from config import (
+    WEBUI_NAME,
+    ENV,
+    VERSION,
+    CHANGELOG,
+    FRONTEND_BUILD_DIR,
+    MODEL_FILTER_ENABLED,
+    MODEL_FILTER_LIST,
+)
+from constants import ERROR_MESSAGES


 class SPAStaticFiles(StaticFiles):
@@ -43,23 +53,11 @@ class SPAStaticFiles(StaticFiles):
                 raise ex


-proxy_config = ProxyConfig()
-
-
-async def config():
-    router, model_list, general_settings = await proxy_config.load_config(
-        router=None, config_file_path="./data/litellm/config.yaml"
-    )
-
-    await initialize(config="./data/litellm/config.yaml", telemetry=False)
-
-
-async def startup():
-    await config()
-
-
 app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)

+app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
+app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+
 origins = ["*"]

 app.add_middleware(
@@ -73,7 +71,127 @@ app.add_middleware(

 @app.on_event("startup")
 async def on_startup():
-    await startup()
+    await litellm_app_startup()
+
+
+class RAGMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if request.method == "POST" and (
+            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
+        ):
+            print(request.url.path)
+
+            # Read the original request body
+            body = await request.body()
+            # Decode body to string
+            body_str = body.decode("utf-8")
+            # Parse string to JSON
+            data = json.loads(body_str) if body_str else {}
+
+            # Example: Add a new key-value pair or modify existing ones
+            # data["modified"] = True  # Example modification
+            if "docs" in data:
+                docs = data["docs"]
+                print(docs)
+
+                last_user_message_idx = None
+                for i in range(len(data["messages"]) - 1, -1, -1):
+                    if data["messages"][i]["role"] == "user":
+                        last_user_message_idx = i
+                        break
+
+                user_message = data["messages"][last_user_message_idx]
+
+                if isinstance(user_message["content"], list):
+                    # Handle list content input
+                    content_type = "list"
+                    query = ""
+                    for content_item in user_message["content"]:
+                        if content_item["type"] == "text":
+                            query = content_item["text"]
+                            break
+                elif isinstance(user_message["content"], str):
+                    # Handle text content input
+                    content_type = "text"
+                    query = user_message["content"]
+                else:
+                    # Fallback in case the input does not match expected types
+                    content_type = None
+                    query = ""
+
+                relevant_contexts = []
+
+                for doc in docs:
+                    context = None
+
+                    try:
+                        if doc["type"] == "collection":
+                            context = query_collection(
+                                collection_names=doc["collection_names"],
+                                query=query,
+                                k=rag_app.state.TOP_K,
+                                embedding_function=rag_app.state.sentence_transformer_ef,
+                            )
+                        else:
+                            context = query_doc(
+                                collection_name=doc["collection_name"],
+                                query=query,
+                                k=rag_app.state.TOP_K,
+                                embedding_function=rag_app.state.sentence_transformer_ef,
+                            )
+                    except Exception as e:
+                        print(e)
+                        context = None
+
+                    relevant_contexts.append(context)
+
+                context_string = ""
+                for context in relevant_contexts:
+                    if context:
+                        context_string += " ".join(context["documents"][0]) + "\n"
+
+                ra_content = rag_template(
+                    template=rag_app.state.RAG_TEMPLATE,
+                    context=context_string,
+                    query=query,
+                )
+
+                if content_type == "list":
+                    new_content = []
+                    for content_item in user_message["content"]:
+                        if content_item["type"] == "text":
+                            # Update the text item's content with ra_content
+                            new_content.append({"type": "text", "text": ra_content})
+                        else:
+                            # Keep other types of content as they are
+                            new_content.append(content_item)
+                    new_user_message = {**user_message, "content": new_content}
+                else:
+                    new_user_message = {
+                        **user_message,
+                        "content": ra_content,
+                    }
+
+                data["messages"][last_user_message_idx] = new_user_message
+                del data["docs"]
+
+                print(data["messages"])
+
+            modified_body_bytes = json.dumps(data).encode("utf-8")
+
+            # Create a new request with the modified body
+            scope = request.scope
+            scope["body"] = modified_body_bytes
+            request = Request(scope, receive=lambda: self._receive(modified_body_bytes))
+
+        response = await call_next(request)
+        return response
+
+    async def _receive(self, body: bytes):
+        return {"type": "http.request", "body": body, "more_body": False}
+
+
+app.add_middleware(RAGMiddleware)


 @app.middleware("http")
@@ -86,21 +204,6 @@ async def check_url(request: Request, call_next):
     return response


-@litellm_app.middleware("http")
-async def auth_middleware(request: Request, call_next):
-    auth_header = request.headers.get("Authorization", "")
-
-    if ENV != "dev":
-        try:
-            user = get_current_user(get_http_authorization_cred(auth_header))
-            print(user)
-        except Exception as e:
-            return JSONResponse(status_code=400, content={"detail": str(e)})
-
-    response = await call_next(request)
-    return response
-
-
 app.mount("/api/v1", webui_app)
 app.mount("/litellm/api", litellm_app)

@@ -125,6 +228,39 @@ async def get_app_config():
     }


+@app.get("/api/config/model/filter")
+async def get_model_filter_config(user=Depends(get_admin_user)):
+    return {
+        "enabled": app.state.MODEL_FILTER_ENABLED,
+        "models": app.state.MODEL_FILTER_LIST,
+    }
+
+
+class ModelFilterConfigForm(BaseModel):
+    enabled: bool
+    models: List[str]
+
+
+@app.post("/api/config/model/filter")
+async def get_model_filter_config(
+    form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
+):
+
+    app.state.MODEL_FILTER_ENABLED = form_data.enabled
+    app.state.MODEL_FILTER_LIST = form_data.models
+
+    ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
+    ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
+
+    openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
+    openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
+
+    return {
+        "enabled": app.state.MODEL_FILTER_ENABLED,
+        "models": app.state.MODEL_FILTER_LIST,
+    }
+
+
 @app.get("/api/version")
 async def get_app_config():

@@ -156,6 +292,7 @@ async def get_app_latest_release_version():


 app.mount("/static", StaticFiles(directory="static"), name="static")
+app.mount("/cache", StaticFiles(directory="data/cache"), name="cache")


 app.mount(
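
Two notes on the pieces added above. First, RAGMiddleware rewrites chat requests in flight: when a docs key is present it retrieves context, replaces the last user message with the filled-in RAG template, and strips docs before the request reaches the Ollama, OpenAI, or LiteLLM app. A hypothetical before/after of an /api/chat body:

    # Hypothetical request body; model and collection_name are made up.
    before = {
        "model": "llama2:latest",
        "messages": [{"role": "user", "content": "What does the report conclude?"}],
        "docs": [{"type": "doc", "collection_name": "report-abc123"}],
    }
    # After the middleware: no "docs" key, and messages[-1]["content"] is the
    # RAG template with [context] filled from retrieval and [query] set to
    # "What does the report conclude?".

Second, the model-filter endpoints propagate the whitelist into the mounted Ollama and OpenAI apps at runtime; a sketch of toggling it (host and token are placeholders):

    import requests

    r = requests.post(
        "http://localhost:8080/api/config/model/filter",  # assumed host
        json={"enabled": True, "models": ["llama2:latest", "mistral:latest"]},
        headers={"Authorization": "Bearer <admin-jwt>"},  # placeholder token
    )
    print(r.json())  # {"enabled": True, "models": [...]}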

+ 2 - 1
backend/requirements.txt

@@ -16,7 +16,8 @@ aiohttp
 peewee
 bcrypt

-litellm
+litellm==1.30.7
+argon2-cffi
 apscheduler
 google-generativeai


+ 1 - 1
package.json

@@ -1,6 +1,6 @@
 {
 	"name": "open-webui",
-	"version": "0.1.110",
+	"version": "0.1.111",
 	"private": true,
 	"scripts": {
 		"dev": "vite dev --host",

+ 83 - 8
src/lib/apis/images/index.ts

@@ -1,9 +1,9 @@
 import { IMAGES_API_BASE_URL } from '$lib/constants';

-export const getImageGenerationEnabledStatus = async (token: string = '') => {
+export const getImageGenerationConfig = async (token: string = '') => {
 	let error = null;

-	const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, {
+	const res = await fetch(`${IMAGES_API_BASE_URL}/config`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -32,10 +32,50 @@ export const getImageGenerationEnabledStatus = async (token: string = '') => {
 	return res;
 };

-export const toggleImageGenerationEnabledStatus = async (token: string = '') => {
+export const updateImageGenerationConfig = async (
+	token: string = '',
+	engine: string,
+	enabled: boolean
+) => {
 	let error = null;

-	const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, {
+	const res = await fetch(`${IMAGES_API_BASE_URL}/config/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			engine,
+			enabled
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getOpenAIKey = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/key`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -61,7 +101,42 @@ export const toggleImageGenerationEnabledStatus = async (token: string = '') =>
 		throw error;
 	}

-	return res;
+	return res.OPENAI_API_KEY;
+};
+
+export const updateOpenAIKey = async (token: string = '', key: string) => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/key/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			key: key
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.OPENAI_API_KEY;
 };

 export const getAUTOMATIC1111Url = async (token: string = '') => {
@@ -263,7 +338,7 @@ export const updateImageSteps = async (token: string = '', steps: number) => {
 	return res.IMAGE_STEPS;
 };

-export const getDiffusionModels = async (token: string = '') => {
+export const getImageGenerationModels = async (token: string = '') => {
 	let error = null;

 	const res = await fetch(`${IMAGES_API_BASE_URL}/models`, {
@@ -295,7 +370,7 @@ export const getDiffusionModels = async (token: string = '') => {
 	return res;
 };

-export const getDefaultDiffusionModel = async (token: string = '') => {
+export const getDefaultImageGenerationModel = async (token: string = '') => {
 	let error = null;

 	const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, {
@@ -327,7 +402,7 @@ export const getDefaultDiffusionModel = async (token: string = '') => {
 	return res.model;
 };

-export const updateDefaultDiffusionModel = async (token: string = '', model: string) => {
+export const updateDefaultImageGenerationModel = async (token: string = '', model: string) => {
 	let error = null;

 	const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, {

+ 62 - 0
src/lib/apis/index.ts

@@ -77,3 +77,65 @@ export const getVersionUpdates = async () => {

 	return res;
 };
+
+export const getModelFilterConfig = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateModelFilterConfig = async (
+	token: string,
+	enabled: boolean,
+	models: string[]
+) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			enabled: enabled,
+			models: models
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 3 - 1
src/lib/apis/litellm/index.ts

@@ -77,6 +77,7 @@ type AddLiteLLMModelForm = {
 	api_base: string;
 	api_key: string;
 	rpm: string;
+	max_tokens: string;
 };

 export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMModelForm) => {
@@ -95,7 +96,8 @@ export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMMod
 				model: payload.model,
 				...(payload.api_base === '' ? {} : { api_base: payload.api_base }),
 				...(payload.api_key === '' ? {} : { api_key: payload.api_key }),
-				...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) })
+				...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) }),
+				...(isNaN(parseInt(payload.max_tokens)) ? {} : { max_tokens: parseInt(payload.max_tokens) })
 			}
 		})
 	})
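
Usage sketch for the extended form; values are illustrative, and the name/model fields defined above this hunk are assumed from the rest of the type:

	await addLiteLLMModel(localStorage.token, {
		name: 'gpt-4-proxy', // model_name shown in the UI (assumed field)
		model: 'openai/gpt-4', // litellm_params.model
		api_base: '',
		api_key: '',
		rpm: '60',
		max_tokens: '4096'
	});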

+ 16 - 7
src/lib/apis/rag/index.ts

@@ -1,9 +1,9 @@
 import { RAG_API_BASE_URL } from '$lib/constants';
 
-export const getChunkParams = async (token: string) => {
+export const getRAGConfig = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/chunk`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/config`, {
 		method: 'GET',
 		headers: {
 			'Content-Type': 'application/json',
@@ -27,18 +27,27 @@ export const getChunkParams = async (token: string) => {
 	return res;
 };
 
-export const updateChunkParams = async (token: string, size: number, overlap: number) => {
+type ChunkConfigForm = {
+	chunk_size: number;
+	chunk_overlap: number;
+};
+
+type RAGConfigForm = {
+	pdf_extract_images: boolean;
+	chunk: ChunkConfigForm;
+};
+
+export const updateRAGConfig = async (token: string, payload: RAGConfigForm) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/chunk/update`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/config/update`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
-			chunk_size: size,
-			chunk_overlap: overlap
+			...payload
 		})
 	})
 		.then(async (res) => {
@@ -252,7 +261,7 @@ export const queryCollection = async (
 	token: string,
 	collection_names: string,
 	query: string,
-	k: number
+	k: number | null = null
 ) => {
 	let error = null;
 
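A sketch of the consolidated round-trip, with shapes taken from RAGConfigForm above (values illustrative):

	const current = await getRAGConfig(localStorage.token); // includes pdf_extract_images and chunk settings
	await updateRAGConfig(localStorage.token, {
		pdf_extract_images: true,
		chunk: { chunk_size: 1500, chunk_overlap: 100 }
	});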

+ 113 - 0
src/lib/components/admin/Settings/Users.svelte

@@ -1,10 +1,14 @@
 <script lang="ts">
+	import { getModelFilterConfig, updateModelFilterConfig } from '$lib/apis';
 	import { getSignUpEnabledStatus, toggleSignUpEnabledStatus } from '$lib/apis/auths';
 	import { getUserPermissions, updateUserPermissions } from '$lib/apis/users';
+	import { models } from '$lib/stores';
 	import { onMount } from 'svelte';
 
 	export let saveHandler: Function;
 
+	let whitelistEnabled = false;
+	let whitelistModels = [''];
 	let permissions = {
 		chat: {
 			deletion: true
@@ -13,6 +17,13 @@
 
 	onMount(async () => {
 		permissions = await getUserPermissions(localStorage.token);
+
+		const res = await getModelFilterConfig(localStorage.token);
+		if (res) {
+			whitelistEnabled = res.enabled;
+
+			whitelistModels = res.models.length > 0 ? res.models : [''];
+		}
 	});
 </script>
 
@@ -21,6 +32,8 @@
 	on:submit|preventDefault={async () => {
 		// console.log('submit');
 		await updateUserPermissions(localStorage.token, permissions);
+
+		await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels);
 		saveHandler();
 	}}
 >
@@ -69,6 +82,106 @@
 				</button>
 			</div>
 		</div>
+
+		<hr class=" dark:border-gray-700 my-2" />
+
+		<div class="mt-2 space-y-3 pr-1.5">
+			<div>
+				<div class="mb-2">
+					<div class="flex justify-between items-center text-xs">
+						<div class=" text-sm font-medium">Manage Models</div>
+					</div>
+				</div>
+
+				<div class=" space-y-3">
+					<div>
+						<div class="flex justify-between items-center text-xs">
+							<div class=" text-xs font-medium">Model Whitelisting</div>
+
+							<button
+								class=" text-xs font-medium text-gray-500"
+								type="button"
+								on:click={() => {
+									whitelistEnabled = !whitelistEnabled;
+								}}>{whitelistEnabled ? 'On' : 'Off'}</button
+							>
+						</div>
+					</div>
+
+					{#if whitelistEnabled}
+						<div>
+							<div class=" space-y-1.5">
+								{#each whitelistModels as modelId, modelIdx}
+									<div class="flex w-full">
+										<div class="flex-1 mr-2">
+											<select
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												bind:value={modelId}
+												placeholder="Select a model"
+											>
+												<option value="" disabled selected>Select a model</option>
+												{#each $models.filter((model) => model.id) as model}
+													<option value={model.id} class="bg-gray-100 dark:bg-gray-700"
+														>{model.name}</option
+													>
+												{/each}
+											</select>
+										</div>
+
+										{#if modelIdx === 0}
+											<button
+												class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition"
+												type="button"
+												on:click={() => {
+													if (whitelistModels.at(-1) !== '') {
+														whitelistModels = [...whitelistModels, ''];
+													}
+												}}
+											>
+												<svg
+													xmlns="http://www.w3.org/2000/svg"
+													viewBox="0 0 16 16"
+													fill="currentColor"
+													class="w-4 h-4"
+												>
+													<path
+														d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
+													/>
+												</svg>
+											</button>
+										{:else}
+											<button
+												class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition"
+												type="button"
+												on:click={() => {
+													whitelistModels.splice(modelIdx, 1);
+													whitelistModels = whitelistModels;
+												}}
+											>
+												<svg
+													xmlns="http://www.w3.org/2000/svg"
+													viewBox="0 0 16 16"
+													fill="currentColor"
+													class="w-4 h-4"
+												>
+													<path d="M3.75 7.25a.75.75 0 0 0 0 1.5h8.5a.75.75 0 0 0 0-1.5h-8.5Z" />
+												</svg>
+											</button>
+										{/if}
+									</div>
+								{/each}
+							</div>
+
+							<div class="flex justify-end items-center text-xs mt-1.5 text-right">
+								<div class=" text-xs font-medium">
+									{whitelistModels.length} Model(s) Whitelisted
+								</div>
+							</div>
+						</div>
+					{/if}
+				</div>
+			</div>
+		</div>
 	</div>
 
 	<div class="flex justify-end pt-3 text-sm font-medium">

+ 3 - 3
src/lib/components/chat/MessageInput.svelte

@@ -19,7 +19,7 @@
 
 	export let suggestionPrompts = [];
 	export let autoScroll = true;
-	let chatTextAreaElement:HTMLTextAreaElement
+	let chatTextAreaElement: HTMLTextAreaElement;
 	let filesInputElement;
 
 	let promptsElement;
@@ -359,12 +359,12 @@
 
 {#if dragged}
 	<div
-		class="fixed w-full h-full flex z-50 touch-none pointer-events-none"
+		class="fixed lg:w-[calc(100%-260px)] w-full h-full flex z-50 touch-none pointer-events-none"
 		id="dropzone"
 		role="region"
 		aria-label="Drag and Drop Container"
 	>
-		<div class="absolute rounded-xl w-full h-full backdrop-blur bg-gray-800/40 flex justify-center">
+		<div class="absolute w-full h-full backdrop-blur bg-gray-800/40 flex justify-center">
 			<div class="m-auto pt-64 flex flex-col justify-center">
 				<div class="max-w-md">
 					<AddFilesPlaceholder />

+ 14 - 11
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -6,6 +6,7 @@
 	import auto_render from 'katex/dist/contrib/auto-render.mjs';
 	import 'katex/dist/katex.min.css';
 
+	import { fade } from 'svelte/transition';
 	import { createEventDispatcher } from 'svelte';
 	import { onMount, tick } from 'svelte';
 
@@ -276,13 +277,15 @@
 
 	const generateImage = async (message) => {
 		generatingImage = true;
-		const res = await imageGenerations(localStorage.token, message.content);
+		const res = await imageGenerations(localStorage.token, message.content).catch((error) => {
+			toast.error(error);
+		});
 		console.log(res);
 
 		if (res) {
-			message.files = res.images.map((image) => ({
+			message.files = res.map((image) => ({
 				type: 'image',
-				url: `data:image/png;base64,${image}`
+				url: `${image.url}`
 			}));
 
 			dispatch('save', message);
@@ -477,7 +480,7 @@
 													xmlns="http://www.w3.org/2000/svg"
 													fill="none"
 													viewBox="0 0 24 24"
-													stroke-width="1.5"
+													stroke-width="2"
 													stroke="currentColor"
 													class="w-4 h-4"
 												>
@@ -503,7 +506,7 @@
 													xmlns="http://www.w3.org/2000/svg"
 													fill="none"
 													viewBox="0 0 24 24"
-													stroke-width="1.5"
+													stroke-width="2"
 													stroke="currentColor"
 													class="w-4 h-4"
 												>
@@ -622,7 +625,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -637,7 +640,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -703,7 +706,7 @@
 															xmlns="http://www.w3.org/2000/svg"
 															fill="none"
 															viewBox="0 0 24 24"
-															stroke-width="1.5"
+															stroke-width="2"
 															stroke="currentColor"
 															class="w-4 h-4"
 														>
@@ -733,7 +736,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -762,7 +765,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -792,7 +795,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>

+ 3 - 3
src/lib/components/chat/Messages/UserMessage.svelte

@@ -258,7 +258,7 @@
 									xmlns="http://www.w3.org/2000/svg"
 									fill="none"
 									viewBox="0 0 24 24"
-									stroke-width="1.5"
+									stroke-width="2"
 									stroke="currentColor"
 									class="w-4 h-4"
 								>
@@ -282,7 +282,7 @@
 									xmlns="http://www.w3.org/2000/svg"
 									fill="none"
 									viewBox="0 0 24 24"
-									stroke-width="1.5"
+									stroke-width="2"
 									stroke="currentColor"
 									class="w-4 h-4"
 								>
@@ -307,7 +307,7 @@
 										xmlns="http://www.w3.org/2000/svg"
 										fill="none"
 										viewBox="0 0 24 24"
-										stroke-width="1.5"
+										stroke-width="2"
 										stroke="currentColor"
 										class="w-4 h-4"
 									>

+ 1 - 1
src/lib/components/chat/Settings/Account.svelte

@@ -271,7 +271,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
+			class="  px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
 			on:click={async () => {
 				const res = await submitHandler();
 

+ 1 - 1
src/lib/components/chat/Settings/Audio.svelte

@@ -251,7 +251,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
+			class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
 			type="submit"
 		>
 			Save

+ 1 - 1
src/lib/components/chat/Settings/Connections.svelte

@@ -247,7 +247,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
+			class="  px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
 			type="submit"
 		>
 			Save

+ 2 - 2
src/lib/components/chat/Settings/General.svelte

@@ -176,7 +176,7 @@
 			<div class=" my-2.5 text-sm font-medium">System Prompt</div>
 			<textarea
 				bind:value={system}
-				class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
+				class="w-full rounded-lg p-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
 				rows="4"
 			/>
 		</div>
@@ -262,7 +262,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
+			class="  px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
 			on:click={() => {
 				saveSettings({
 					system: system !== '' ? system : undefined,

+ 135 - 76
src/lib/components/chat/Settings/Images.svelte

@@ -5,16 +5,18 @@
 	import { config, user } from '$lib/stores';
 	import {
 		getAUTOMATIC1111Url,
-		getDefaultDiffusionModel,
-		getDiffusionModels,
-		getImageGenerationEnabledStatus,
+		getImageGenerationModels,
+		getDefaultImageGenerationModel,
+		updateDefaultImageGenerationModel,
 		getImageSize,
-		toggleImageGenerationEnabledStatus,
+		getImageGenerationConfig,
+		updateImageGenerationConfig,
 		updateAUTOMATIC1111Url,
-		updateDefaultDiffusionModel,
 		updateImageSize,
 		getImageSteps,
-		updateImageSteps
+		updateImageSteps,
+		getOpenAIKey,
+		updateOpenAIKey
 	} from '$lib/apis/images';
 	import { getBackendConfig } from '$lib/apis';
 	const dispatch = createEventDispatcher();
@@ -23,8 +25,11 @@
 
 	let loading = false;
 
+	let imageGenerationEngine = '';
 	let enableImageGeneration = false;
+
 	let AUTOMATIC1111_BASE_URL = '';
+	let OPENAI_API_KEY = '';
 
 	let selectedModel = '';
 	let models = null;
@@ -33,11 +38,11 @@
 	let steps = 50;
 
 	const getModels = async () => {
-		models = await getDiffusionModels(localStorage.token).catch((error) => {
+		models = await getImageGenerationModels(localStorage.token).catch((error) => {
 			toast.error(error);
 			return null;
 		});
-		selectedModel = await getDefaultDiffusionModel(localStorage.token).catch((error) => {
+		selectedModel = await getDefaultImageGenerationModel(localStorage.token).catch((error) => {
 			return '';
 		});
 	};
@@ -62,33 +67,45 @@
 			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
 		}
 	};
-	const toggleImageGeneration = async () => {
-		if (AUTOMATIC1111_BASE_URL) {
-			enableImageGeneration = await toggleImageGenerationEnabledStatus(localStorage.token).catch(
-				(error) => {
-					toast.error(error);
-					return false;
-				}
-			);
+	const updateImageGeneration = async () => {
+		const res = await updateImageGenerationConfig(
+			localStorage.token,
+			imageGenerationEngine,
+			enableImageGeneration
+		).catch((error) => {
+			toast.error(error);
+			return null;
+		});
 
-			if (enableImageGeneration) {
-				config.set(await getBackendConfig(localStorage.token));
-				getModels();
-			}
-		} else {
-			enableImageGeneration = false;
-			toast.error('AUTOMATIC1111_BASE_URL not provided');
+		if (res) {
+			imageGenerationEngine = res.engine;
+			enableImageGeneration = res.enabled;
+		}
+
+		if (enableImageGeneration) {
+			config.set(await getBackendConfig(localStorage.token));
+			getModels();
 		}
 	};
 
 	onMount(async () => {
 		if ($user.role === 'admin') {
-			enableImageGeneration = await getImageGenerationEnabledStatus(localStorage.token);
+			const res = await getImageGenerationConfig(localStorage.token).catch((error) => {
+				toast.error(error);
+				return null;
+			});
+
+			if (res) {
+				imageGenerationEngine = res.engine;
+				enableImageGeneration = res.enabled;
+			}
 			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
+			OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
+
+			imageSize = await getImageSize(localStorage.token);
+			steps = await getImageSteps(localStorage.token);
 
-			if (enableImageGeneration && AUTOMATIC1111_BASE_URL) {
-				imageSize = await getImageSize(localStorage.token);
-				steps = await getImageSteps(localStorage.token);
+			if (enableImageGeneration) {
 				getModels();
 			}
 		}
@@ -99,7 +116,9 @@
 	class="flex flex-col h-full justify-between space-y-3 text-sm"
 	on:submit|preventDefault={async () => {
 		loading = true;
-		await updateDefaultDiffusionModel(localStorage.token, selectedModel);
+		await updateOpenAIKey(localStorage.token, OPENAI_API_KEY);
+
+		await updateDefaultImageGenerationModel(localStorage.token, selectedModel);
 		await updateImageSize(localStorage.token, imageSize).catch((error) => {
 			toast.error(error);
 			return null;
@@ -117,6 +138,23 @@
 		<div>
 			<div class=" mb-1 text-sm font-medium">Image Settings</div>
 
+			<div class=" py-0.5 flex w-full justify-between">
+				<div class=" self-center text-xs font-medium">Image Generation Engine</div>
+				<div class="flex items-center relative">
+					<select
+						class="w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
+						bind:value={imageGenerationEngine}
+						placeholder="Select a mode"
+						on:change={async () => {
+							await updateImageGeneration();
+						}}
+					>
+						<option value="">Default (Automatic1111)</option>
+						<option value="openai">Open AI (Dall-E)</option>
+					</select>
+				</div>
+			</div>
+
 			<div>
 				<div class=" py-0.5 flex w-full justify-between">
 					<div class=" self-center text-xs font-medium">Image Generation (Experimental)</div>
@@ -124,7 +162,17 @@
 					<button
 						class="p-1 px-3 text-xs flex rounded transition"
 						on:click={() => {
-							toggleImageGeneration();
+							if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') {
+								toast.error('AUTOMATIC1111 Base URL is required.');
+								enableImageGeneration = false;
+							} else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') {
+								toast.error('OpenAI API Key is required.');
+								enableImageGeneration = false;
+							} else {
+								enableImageGeneration = !enableImageGeneration;
+							}
+
+							updateImageGeneration();
 						}}
 						type="button"
 					>
@@ -139,49 +187,62 @@
 		</div>
 		<hr class=" dark:border-gray-700" />
 
-		<div class=" mb-2.5 text-sm font-medium">AUTOMATIC1111 Base URL</div>
-		<div class="flex w-full">
-			<div class="flex-1 mr-2">
-				<input
-					class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-					placeholder="Enter URL (e.g. http://127.0.0.1:7860/)"
-					bind:value={AUTOMATIC1111_BASE_URL}
-				/>
+		{#if imageGenerationEngine === ''}
+			<div class=" mb-2.5 text-sm font-medium">AUTOMATIC1111 Base URL</div>
+			<div class="flex w-full">
+				<div class="flex-1 mr-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder="Enter URL (e.g. http://127.0.0.1:7860/)"
+						bind:value={AUTOMATIC1111_BASE_URL}
+					/>
+				</div>
+				<button
+					class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded-lg transition"
+					type="button"
+					on:click={() => {
+						// updateOllamaAPIUrlHandler();
+
+						updateAUTOMATIC1111UrlHandler();
+					}}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						viewBox="0 0 20 20"
+						fill="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							fill-rule="evenodd"
+							d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
+							clip-rule="evenodd"
+						/>
+					</svg>
+				</button>
 			</div>
-			<button
-				class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded transition"
-				type="button"
-				on:click={() => {
-					// updateOllamaAPIUrlHandler();
-
-					updateAUTOMATIC1111UrlHandler();
-				}}
-			>
-				<svg
-					xmlns="http://www.w3.org/2000/svg"
-					viewBox="0 0 20 20"
-					fill="currentColor"
-					class="w-4 h-4"
+
+			<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
+				Include `--api` flag when running stable-diffusion-webui
+				<a
+					class=" text-gray-300 font-medium"
+					href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/3734"
+					target="_blank"
 				>
-					<path
-						fill-rule="evenodd"
-						d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
-						clip-rule="evenodd"
+					(e.g. `sh webui.sh --api`)
+				</a>
+			</div>
+		{:else if imageGenerationEngine === 'openai'}
+			<div class=" mb-2.5 text-sm font-medium">OpenAI API Key</div>
+			<div class="flex w-full">
+				<div class="flex-1 mr-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder="Enter API Key"
+						bind:value={OPENAI_API_KEY}
 					/>
-				</svg>
-			</button>
-		</div>
-
-		<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
-			Include `--api` flag when running stable-diffusion-webui
-			<a
-				class=" text-gray-300 font-medium"
-				href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/3734"
-				target="_blank"
-			>
-				(e.g. `sh webui.sh --api`)
-			</a>
-		</div>
+				</div>
+			</div>
+		{/if}
 
 		{#if enableImageGeneration}
 			<hr class=" dark:border-gray-700" />
@@ -191,7 +252,7 @@
 				<div class="flex w-full">
 					<div class="flex-1 mr-2">
 						<select
-							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							bind:value={selectedModel}
 							placeholder="Select a model"
 						>
@@ -199,9 +260,7 @@
 								<option value="" disabled selected>Select a model</option>
 							{/if}
 							{#each models ?? [] as model}
-								<option value={model.title} class="bg-gray-100 dark:bg-gray-700"
-									>{model.model_name}</option
-								>
+								<option value={model.id} class="bg-gray-100 dark:bg-gray-700">{model.name}</option>
 							{/each}
 						</select>
 					</div>
@@ -213,7 +272,7 @@
 				<div class="flex w-full">
 					<div class="flex-1 mr-2">
 						<input
-							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							placeholder="Enter Image Size (e.g. 512x512)"
 							bind:value={imageSize}
 						/>
@@ -226,7 +285,7 @@
 				<div class="flex w-full">
 					<div class="flex-1 mr-2">
 						<input
-							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							placeholder="Enter Number of Steps (e.g. 50)"
 							bind:value={steps}
 						/>
@@ -238,7 +297,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded flex flex-row space-x-1 items-center {loading
+			class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg flex flex-row space-x-1 items-center {loading
 				? ' cursor-not-allowed'
 				: ''}"
 			type="submit"
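
For reference, the panel above persists engine state through the images API; a minimal sketch of the round-trip it performs ('' selects AUTOMATIC1111, 'openai' selects DALL·E):

	const { engine, enabled } = await getImageGenerationConfig(localStorage.token);
	await updateImageGenerationConfig(localStorage.token, 'openai', true);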

+ 5 - 27
src/lib/components/chat/Settings/Interface.svelte

@@ -63,6 +63,7 @@
 		}
 
 		saveSettings({
+			titleAutoGenerateModel: titleAutoGenerateModel !== '' ? titleAutoGenerateModel : undefined,
 			titleGenerationPrompt: titleGenerationPrompt ? titleGenerationPrompt : undefined
 		});
 	};
@@ -186,7 +187,7 @@
 			<div class="flex w-full">
 				<div class="flex-1 mr-2">
 					<select
-						class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 						bind:value={titleAutoGenerateModel}
 						placeholder="Select a model"
 					>
@@ -200,35 +201,12 @@
 						{/each}
 					</select>
 				</div>
-				<button
-					class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-700 dark:hover:bg-gray-800 dark:text-gray-100 rounded transition"
-					on:click={() => {
-						saveSettings({
-							titleAutoGenerateModel:
-								titleAutoGenerateModel !== '' ? titleAutoGenerateModel : undefined
-						});
-					}}
-					type="button"
-				>
-					<svg
-						xmlns="http://www.w3.org/2000/svg"
-						viewBox="0 0 16 16"
-						fill="currentColor"
-						class="w-3.5 h-3.5"
-					>
-						<path
-							fill-rule="evenodd"
-							d="M13.836 2.477a.75.75 0 0 1 .75.75v3.182a.75.75 0 0 1-.75.75h-3.182a.75.75 0 0 1 0-1.5h1.37l-.84-.841a4.5 4.5 0 0 0-7.08.932.75.75 0 0 1-1.3-.75 6 6 0 0 1 9.44-1.242l.842.84V3.227a.75.75 0 0 1 .75-.75Zm-.911 7.5A.75.75 0 0 1 13.199 11a6 6 0 0 1-9.44 1.241l-.84-.84v1.371a.75.75 0 0 1-1.5 0V9.591a.75.75 0 0 1 .75-.75H5.35a.75.75 0 0 1 0 1.5H3.98l.841.841a4.5 4.5 0 0 0 7.08-.932.75.75 0 0 1 1.025-.273Z"
-							clip-rule="evenodd"
-						/>
-					</svg>
-				</button>
 			</div>
-			<div class="mt-3">
+			<div class="mt-3 mr-2">
 				<div class=" mb-2.5 text-sm font-medium">Title Generation Prompt</div>
 				<textarea
 					bind:value={titleGenerationPrompt}
-					class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
+					class="w-full rounded-lg p-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
 					rows="3"
 				/>
 			</div>
@@ -321,7 +299,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
+			class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
 			type="submit"
 		>
 			Save

+ 277 - 221
src/lib/components/chat/Settings/Models.svelte

@@ -14,6 +14,7 @@
 	import { splitStream } from '$lib/utils';
 	import { onMount } from 'svelte';
 	import { addLiteLLMModel, deleteLiteLLMModel, getLiteLLMModelInfo } from '$lib/apis/litellm';
+	import Tooltip from '$lib/components/common/Tooltip.svelte';
 
 	export let getModels: Function;
 
@@ -27,6 +28,7 @@
 	let liteLLMAPIBase = '';
 	let liteLLMAPIKey = '';
 	let liteLLMRPM = '';
+	let liteLLMMaxTokens = '';
 
 	let deleteLiteLLMModelId = '';
 
@@ -36,6 +38,10 @@
 
 	let OLLAMA_URLS = [];
 	let selectedOllamaUrlIdx: string | null = null;
+
+	let updateModelId = null;
+	let updateProgress = null;
+
 	let showExperimentalOllama = false;
 	let ollamaVersion = '';
 	const MAX_PARALLEL_DOWNLOADS = 3;
@@ -60,6 +66,71 @@
 
 	let deleteModelTag = '';
 
+	const updateModelsHandler = async () => {
+		for (const model of $models.filter(
+			(m) =>
+				m.size != null &&
+				(selectedOllamaUrlIdx === null ? true : (m?.urls ?? []).includes(selectedOllamaUrlIdx))
+		)) {
+			console.log(model);
+
+			updateModelId = model.id;
+			const res = await pullModel(localStorage.token, model.id, selectedOllamaUrlIdx).catch(
+				(error) => {
+					toast.error(error);
+					return null;
+				}
+			);
+
+			if (res) {
+				const reader = res.body
+					.pipeThrough(new TextDecoderStream())
+					.pipeThrough(splitStream('\n'))
+					.getReader();
+
+				while (true) {
+					try {
+						const { value, done } = await reader.read();
+						if (done) break;
+
+						let lines = value.split('\n');
+
+						for (const line of lines) {
+							if (line !== '') {
+								let data = JSON.parse(line);
+
+								console.log(data);
+								if (data.error) {
+									throw data.error;
+								}
+								if (data.detail) {
+									throw data.detail;
+								}
+								if (data.status) {
+									if (data.digest) {
+										updateProgress = 0;
+										if (data.completed) {
+											updateProgress = Math.round((data.completed / data.total) * 1000) / 10;
+										} else {
+											updateProgress = 100;
+										}
+									} else {
+										toast.success(data.status);
+									}
+								}
+							}
+						}
+					} catch (error) {
+						console.log(error);
+					}
+				}
+			}
+		}
+
+		updateModelId = null;
+		updateProgress = null;
+	};
+
 	const pullModelHandler = async () => {
 		const sanitizedModelTag = modelTag.trim();
 		if (modelDownloadStatus[sanitizedModelTag]) {
@@ -326,7 +397,8 @@
 				model: liteLLMModel,
 				api_base: liteLLMAPIBase,
 				api_key: liteLLMAPIKey,
-				rpm: liteLLMRPM
+				rpm: liteLLMRPM,
+				max_tokens: liteLLMMaxTokens
 			}).catch((error) => {
 				toast.error(error);
 				return null;
@@ -346,6 +418,7 @@
 		liteLLMAPIBase = '';
 		liteLLMAPIKey = '';
 		liteLLMRPM = '';
+		liteLLMMaxTokens = '';
 
 		liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
 		models.set(await getModels());
@@ -376,7 +449,7 @@
 			return [];
 		});
 
-		if (OLLAMA_URLS.length > 1) {
+		if (OLLAMA_URLS.length > 0) {
 			selectedOllamaUrlIdx = 0;
 		}
 
@@ -391,18 +464,51 @@
 			<div class="space-y-2 pr-1.5">
 				<div class="text-sm font-medium">Manage Ollama Models</div>
 
-				{#if OLLAMA_URLS.length > 1}
-					<div class="flex-1 pb-1">
-						<select
-							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-							bind:value={selectedOllamaUrlIdx}
-							placeholder="Select an Ollama instance"
-						>
-							{#each OLLAMA_URLS as url, idx}
-								<option value={idx} class="bg-gray-100 dark:bg-gray-700">{url}</option>
-							{/each}
-						</select>
+				{#if OLLAMA_URLS.length > 0}
+					<div class="flex gap-2">
+						<div class="flex-1 pb-1">
+							<select
+								class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+								bind:value={selectedOllamaUrlIdx}
+								placeholder="Select an Ollama instance"
+							>
+								{#each OLLAMA_URLS as url, idx}
+									<option value={idx} class="bg-gray-100 dark:bg-gray-700">{url}</option>
+								{/each}
+							</select>
+						</div>
+
+						<div>
+							<div class="flex w-full justify-end">
+								<Tooltip content="Update All Models" placement="top">
+									<button
+										class="p-2.5 flex gap-2 items-center bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+										on:click={() => {
+											updateModelsHandler();
+										}}
+									>
+										<svg
+											xmlns="http://www.w3.org/2000/svg"
+											viewBox="0 0 16 16"
+											fill="currentColor"
+											class="w-4 h-4"
+										>
+											<path
+												d="M7 1a.75.75 0 0 1 .75.75V6h-1.5V1.75A.75.75 0 0 1 7 1ZM6.25 6v2.94L5.03 7.72a.75.75 0 0 0-1.06 1.06l2.5 2.5a.75.75 0 0 0 1.06 0l2.5-2.5a.75.75 0 1 0-1.06-1.06L7.75 8.94V6H10a2 2 0 0 1 2 2v3a2 2 0 0 1-2 2H4a2 2 0 0 1-2-2V8a2 2 0 0 1 2-2h2.25Z"
+											/>
+											<path
+												d="M4.268 14A2 2 0 0 0 6 15h6a2 2 0 0 0 2-2v-3a2 2 0 0 0-1-1.732V11a3 3 0 0 1-3 3H4.268Z"
+											/>
+										</svg>
+									</button>
+								</Tooltip>
+							</div>
+						</div>
 					</div>
 					</div>
+					{#if updateModelId}
+						Updating "{updateModelId}" {updateProgress ? `(${updateProgress}%)` : ''}
+					{/if}
 				{/if}
 				{/if}
 
 				<div class="space-y-2">
 							</button>
 							</button>
 						</div>
 
-							To access the available model names for downloading, <a
-								class=" text-gray-500 dark:text-gray-300 font-medium underline"
-								href="https://ollama.com/library"
-								target="_blank">click here.</a
-							>
+						<div>
+							<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
+								To access the available model names for downloading, <a
+									class=" text-gray-500 dark:text-gray-300 font-medium underline"
+									href="https://ollama.com/library"
+									target="_blank">click here.</a
+								>
+							</div>
 						</div>
 						</div>
 
 						{#if Object.keys(modelDownloadStatus).length > 0}
@@ -589,7 +697,7 @@
 												on:change={() => {
 													console.log(modelInputFile);
 												}}
-												accept=".gguf"
+												accept=".gguf,.safetensors"
 												required
 												hidden
 											/>
@@ -722,245 +830,193 @@
 		<div class=" space-y-3">
 			<div class="mt-2 space-y-3 pr-1.5">
 				<div>
-					<div class=" mb-2 text-sm font-medium">Manage LiteLLM Models</div>
-
-					<div>
+					<div class="mb-2">
 						<div class="flex justify-between items-center text-xs">
-							<div class=" text-sm font-medium">Add a model</div>
+							<div class=" text-sm font-medium">Manage LiteLLM Models</div>
 							<button
 								class=" text-xs font-medium text-gray-500"
 								type="button"
 								on:click={() => {
-									showLiteLLMParams = !showLiteLLMParams;
-								}}>{showLiteLLMParams ? 'Hide Additional Params' : 'Show Additional Params'}</button
+									showLiteLLM = !showLiteLLM;
+								}}>{showLiteLLM ? 'Hide' : 'Show'}</button
 							>
 						</div>
 					</div>
 
-					<div class="my-2 space-y-2">
-						<div class="flex w-full mb-1.5">
-							<div class="flex-1 mr-2">
-								<input
-									class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-									placeholder="Enter LiteLLM Model (litellm_params.model)"
-									bind:value={liteLLMModel}
-									autocomplete="off"
-								/>
+					{#if showLiteLLM}
+						<div>
+							<div class="flex justify-between items-center text-xs">
+								<div class=" text-sm font-medium">Add a model</div>
+								<button
+									class=" text-xs font-medium text-gray-500"
+									type="button"
+									on:click={() => {
+										showLiteLLMParams = !showLiteLLMParams;
+									}}
+									>{showLiteLLMParams ? 'Hide Additional Params' : 'Show Additional Params'}</button
+								>
 							</div>
+						</div>
 
-							<button
-								class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
-								on:click={() => {
-									addLiteLLMModelHandler();
-								}}
-							>
-								<svg
-									xmlns="http://www.w3.org/2000/svg"
-									viewBox="0 0 16 16"
-									fill="currentColor"
-									class="w-4 h-4"
-								>
-									<path
-										d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
+						<div class="my-2 space-y-2">
+							<div class="flex w-full mb-1.5">
+								<div class="flex-1 mr-2">
+									<input
+										class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+										placeholder="Enter LiteLLM Model (litellm_params.model)"
+										bind:value={liteLLMModel}
+										autocomplete="off"
 									/>
-								</svg>
-							</button>
-						</div>
+								</div>
 
 
-						{#if showLiteLLMParams}
-							<div>
-								<div class=" mb-1.5 text-sm font-medium">Model Name</div>
-								<div class="flex w-full">
-									<div class="flex-1">
-										<input
-											class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-											placeholder="Enter Model Name (model_name)"
-											bind:value={liteLLMModelName}
-											autocomplete="off"
+								<button
+									class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+									on:click={() => {
+										addLiteLLMModelHandler();
+									}}
+								>
+									<svg
+										xmlns="http://www.w3.org/2000/svg"
+										viewBox="0 0 16 16"
+										fill="currentColor"
+										class="w-4 h-4"
+									>
+										<path
+											d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
 										/>
 										/>
-									</div>
-								</div>
+									</svg>
+								</button>
 							</div>
 							</div>
 
 
-							<div>
-								<div class=" mb-1.5 text-sm font-medium">API Base URL</div>
-								<div class="flex w-full">
-									<div class="flex-1">
-										<input
-											class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-											placeholder="Enter LiteLLM API Base URL (litellm_params.api_base)"
-											bind:value={liteLLMAPIBase}
-											autocomplete="off"
-										/>
+							{#if showLiteLLMParams}
+								<div>
+									<div class=" mb-1.5 text-sm font-medium">Model Name</div>
+									<div class="flex w-full">
+										<div class="flex-1">
+											<input
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												placeholder="Enter Model Name (model_name)"
+												bind:value={liteLLMModelName}
+												autocomplete="off"
+											/>
+										</div>
 									</div>
 									</div>
 								</div>
 								</div>
-							</div>
 
 
-							<div>
-								<div class=" mb-1.5 text-sm font-medium">API Key</div>
-								<div class="flex w-full">
-									<div class="flex-1">
-										<input
-											class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-											placeholder="Enter LiteLLM API Key (litellm_params.api_key)"
-											bind:value={liteLLMAPIKey}
-											autocomplete="off"
-										/>
+								<div>
+									<div class=" mb-1.5 text-sm font-medium">API Base URL</div>
+									<div class="flex w-full">
+										<div class="flex-1">
+											<input
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												placeholder="Enter LiteLLM API Base URL (litellm_params.api_base)"
+												bind:value={liteLLMAPIBase}
+												autocomplete="off"
+											/>
+										</div>
 									</div>
 									</div>
 								</div>
 								</div>
-							</div>
 
 
-							<div>
-								<div class="mb-1.5 text-sm font-medium">API RPM</div>
-								<div class="flex w-full">
-									<div class="flex-1">
-										<input
-											class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-											placeholder="Enter LiteLLM API RPM (litellm_params.rpm)"
-											bind:value={liteLLMRPM}
-											autocomplete="off"
-										/>
+								<div>
+									<div class=" mb-1.5 text-sm font-medium">API Key</div>
+									<div class="flex w-full">
+										<div class="flex-1">
+											<input
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												placeholder="Enter LiteLLM API Key (litellm_params.api_key)"
+												bind:value={liteLLMAPIKey}
+												autocomplete="off"
+											/>
+										</div>
 									</div>
 									</div>
 								</div>
 								</div>
-							</div>
-						{/if}
-					</div>
-
-					<div class="mb-2 text-xs text-gray-400 dark:text-gray-500">
-						Not sure what to add?
-						<a
-							class=" text-gray-300 font-medium underline"
-							href="https://litellm.vercel.app/docs/proxy/configs#quick-start"
-							target="_blank"
-						>
-							Click here for help.
-						</a>
-					</div>
-
-					<div>
-						<div class=" mb-2.5 text-sm font-medium">Delete a model</div>
-						<div class="flex w-full">
-							<div class="flex-1 mr-2">
-								<select
-									class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-									bind:value={deleteLiteLLMModelId}
-									placeholder="Select a model"
-								>
-									{#if !deleteLiteLLMModelId}
-										<option value="" disabled selected>Select a model</option>
-									{/if}
-									{#each liteLLMModelInfo as model}
-										<option value={model.model_info.id} class="bg-gray-100 dark:bg-gray-700"
-											>{model.model_name}</option
-										>
-									{/each}
-								</select>
-							</div>
-							<button
-								class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
-								on:click={() => {
-									deleteLiteLLMModelHandler();
-								}}
-							>
-								<svg
-									xmlns="http://www.w3.org/2000/svg"
-									viewBox="0 0 16 16"
-									fill="currentColor"
-									class="w-4 h-4"
-								>
-									<path
-										fill-rule="evenodd"
-										d="M5 3.25V4H2.75a.75.75 0 0 0 0 1.5h.3l.815 8.15A1.5 1.5 0 0 0 5.357 15h5.285a1.5 1.5 0 0 0 1.493-1.35l.815-8.15h.3a.75.75 0 0 0 0-1.5H11v-.75A2.25 2.25 0 0 0 8.75 1h-1.5A2.25 2.25 0 0 0 5 3.25Zm2.25-.75a.75.75 0 0 0-.75.75V4h3v-.75a.75.75 0 0 0-.75-.75h-1.5ZM6.05 6a.75.75 0 0 1 .787.713l.275 5.5a.75.75 0 0 1-1.498.075l-.275-5.5A.75.75 0 0 1 6.05 6Zm3.9 0a.75.75 0 0 1 .712.787l-.275 5.5a.75.75 0 0 1-1.498-.075l.275-5.5a.75.75 0 0 1 .786-.711Z"
-										clip-rule="evenodd"
-									/>
-								</svg>
-							</button>
-						</div>
-					</div>
-				</div>
-			</div>
-
-			<!-- <div class="mt-2 space-y-3 pr-1.5">
-				<div>
-					<div class=" mb-2.5 text-sm font-medium">Add LiteLLM Model</div>
-					<div class="flex w-full mb-2">
-						<div class="flex-1">
-							<input
-								class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-								placeholder="Enter LiteLLM Model (e.g. ollama/mistral)"
-								bind:value={liteLLMModel}
-								autocomplete="off"
-							/>
-						</div>
-					</div>
 
 
-					<div class="flex justify-between items-center text-sm">
-						<div class="  font-medium">Advanced Model Params</div>
-						<button
-							class=" text-xs font-medium text-gray-500"
-							type="button"
-							on:click={() => {
-								showLiteLLMParams = !showLiteLLMParams;
-							}}>{showLiteLLMParams ? 'Hide' : 'Show'}</button
-						>
-					</div>
+								<div>
+									<div class="mb-1.5 text-sm font-medium">API RPM</div>
+									<div class="flex w-full">
+										<div class="flex-1">
+											<input
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												placeholder="Enter LiteLLM API RPM (litellm_params.rpm)"
+												bind:value={liteLLMRPM}
+												autocomplete="off"
+											/>
+										</div>
+									</div>
+								</div>
 
 
-					{#if showLiteLLMParams}
-						<div>
-							<div class=" mb-2.5 text-sm font-medium">LiteLLM API Key</div>
-							<div class="flex w-full">
-								<div class="flex-1">
-									<input
-										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-										placeholder="Enter LiteLLM API Key (e.g. os.environ/AZURE_API_KEY_CA)"
-										bind:value={liteLLMAPIKey}
-										autocomplete="off"
-									/>
+								<div>
+									<div class="mb-1.5 text-sm font-medium">Max Tokens</div>
+									<div class="flex w-full">
+										<div class="flex-1">
+											<input
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												placeholder="Enter Max Tokens (litellm_params.max_tokens)"
+												bind:value={liteLLMMaxTokens}
+												type="number"
+												min="1"
+												autocomplete="off"
+											/>
+										</div>
+									</div>
 								</div>
 								</div>
-							</div>
+							{/if}
 						</div>
 						</div>
 
 
-						<div>
-							<div class=" mb-2.5 text-sm font-medium">LiteLLM API Base URL</div>
-							<div class="flex w-full">
-								<div class="flex-1">
-									<input
-										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-										placeholder="Enter LiteLLM API Base URL"
-										bind:value={liteLLMAPIBase}
-										autocomplete="off"
-									/>
-								</div>
-							</div>
+						<div class="mb-2 text-xs text-gray-400 dark:text-gray-500">
+							Not sure what to add?
+							<a
+								class=" text-gray-300 font-medium underline"
+								href="https://litellm.vercel.app/docs/proxy/configs#quick-start"
+								target="_blank"
+							>
+								Click here for help.
+							</a>
 						</div>
 						</div>
 
 
 						<div>
 						<div>
-							<div class=" mb-2.5 text-sm font-medium">LiteLLM API RPM</div>
+							<div class=" mb-2.5 text-sm font-medium">Delete a model</div>
 							<div class="flex w-full">
 							<div class="flex w-full">
-								<div class="flex-1">
-									<input
-										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-										placeholder="Enter LiteLLM API RPM"
-										bind:value={liteLLMRPM}
-										autocomplete="off"
-									/>
+								<div class="flex-1 mr-2">
+									<select
+										class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+										bind:value={deleteLiteLLMModelId}
+										placeholder="Select a model"
+									>
+										{#if !deleteLiteLLMModelId}
+											<option value="" disabled selected>Select a model</option>
+										{/if}
+										{#each liteLLMModelInfo as model}
+											<option value={model.model_info.id} class="bg-gray-100 dark:bg-gray-700"
+												>{model.model_name}</option
+											>
+										{/each}
+									</select>
 								</div>
 								</div>
+								<button
+									class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+									on:click={() => {
+										deleteLiteLLMModelHandler();
+									}}
+								>
+									<svg
+										xmlns="http://www.w3.org/2000/svg"
+										viewBox="0 0 16 16"
+										fill="currentColor"
+										class="w-4 h-4"
+									>
+										<path
+											fill-rule="evenodd"
+											d="M5 3.25V4H2.75a.75.75 0 0 0 0 1.5h.3l.815 8.15A1.5 1.5 0 0 0 5.357 15h5.285a1.5 1.5 0 0 0 1.493-1.35l.815-8.15h.3a.75.75 0 0 0 0-1.5H11v-.75A2.25 2.25 0 0 0 8.75 1h-1.5A2.25 2.25 0 0 0 5 3.25Zm2.25-.75a.75.75 0 0 0-.75.75V4h3v-.75a.75.75 0 0 0-.75-.75h-1.5ZM6.05 6a.75.75 0 0 1 .787.713l.275 5.5a.75.75 0 0 1-1.498.075l-.275-5.5A.75.75 0 0 1 6.05 6Zm3.9 0a.75.75 0 0 1 .712.787l-.275 5.5a.75.75 0 0 1-1.498-.075l.275-5.5a.75.75 0 0 1 .786-.711Z"
+											clip-rule="evenodd"
+										/>
+									</svg>
+								</button>
 							</div>
 						</div>
 					{/if}
-
-					<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
-						Not sure what to add?
-						<a
-							class=" text-gray-300 font-medium underline"
-							href="https://litellm.vercel.app/docs/proxy/configs#quick-start"
-							target="_blank"
-						>
-							Click here for help.
-						</a>
-					</div>
 				</div>
-			</div> -->
+			</div>
 		</div>
 	</div>
 </div>
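Note: the `deleteLiteLLMModelHandler` wired to the trash button above is defined outside this hunk and relies on the `deleteLiteLLMModelId` and `liteLLMModelInfo` state bound in the markup. A minimal sketch of what such a handler could look like, assuming `deleteLiteLLMModel` and `getLiteLLMModelInfo` helpers in `$lib/apis/litellm` (helper names are illustrative; that file is touched elsewhere in this commit, but its contents are not shown here):

```ts
// Hypothetical sketch only: the real handler lives elsewhere in Models.svelte.
// deleteLiteLLMModel and getLiteLLMModelInfo are assumed helper names.
const deleteLiteLLMModelHandler = async () => {
	const res = await deleteLiteLLMModel(localStorage.token, deleteLiteLLMModelId).catch(
		(error) => {
			toast.error(error);
			return null;
		}
	);

	if (res) {
		toast.success('Model deleted successfully');
	}

	// Reset the selection and refresh the model list shown in the <select>.
	deleteLiteLLMModelId = '';
	liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
};
```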

+ 6 - 6
src/lib/components/chat/SettingsModal.svelte

@@ -326,7 +326,7 @@
 						{getModels}
 						{saveSettings}
 						on:save={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'models'}
@@ -335,28 +335,28 @@
 					<Connections
 						{getModels}
 						on:save={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'interface'}
 					<Interface
 						{saveSettings}
 						on:save={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'audio'}
 					<Audio
 						{saveSettings}
 						on:save={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'images'}
 					<Images
 						{saveSettings}
 						on:save={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'chats'}
@@ -364,7 +364,7 @@
 				{:else if selectedTab === 'account'}
 					<Account
 						saveHandler={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'about'}
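Each tab's `on:save` callback now raises a success toast instead of closing the modal, so users can keep adjusting settings after saving. This presumes a `toast` helper is already imported at the top of the component; the import is outside this hunk, so the exact library is an assumption, e.g.:

```ts
// Assumed import; the toast library actually used by the component is not
// visible in this hunk.
import toast from 'svelte-french-toast';

toast.success('Settings saved successfully!');
```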

+ 7 - 2
src/lib/components/common/Image.svelte

@@ -1,18 +1,23 @@
 <script lang="ts">
+	import { WEBUI_BASE_URL } from '$lib/constants';
 	import ImagePreview from './ImagePreview.svelte';
 
 	export let src = '';
 	export let alt = '';
 
+	let _src = '';
+
+	$: _src = src.startsWith('/') ? `${WEBUI_BASE_URL}${src}` : src;
+
 	let showImagePreview = false;
 </script>
 
-<ImagePreview bind:show={showImagePreview} {src} {alt} />
+<ImagePreview bind:show={showImagePreview} src={_src} {alt} />
 <button
 	on:click={() => {
 		console.log('image preview');
 		showImagePreview = true;
 	}}
 >
-	<img {src} {alt} class=" max-h-96 rounded-lg" draggable="false" />
+	<img src={_src} {alt} class=" max-h-96 rounded-lg" draggable="false" />
</button>
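The new reactive `_src` statement rewrites server-relative image paths (anything starting with `/`) onto the WebUI base URL, while absolute URLs and base64 data URLs pass through untouched. The same logic restated as a standalone sketch, with an illustrative base URL and path:

```ts
// Standalone restatement of the reactive `_src` logic above; the base URL
// and example path are illustrative.
const WEBUI_BASE_URL = 'http://localhost:8080';

const resolveSrc = (src: string): string =>
	src.startsWith('/') ? `${WEBUI_BASE_URL}${src}` : src;

console.log(resolveSrc('/images/demo.png'));
// -> 'http://localhost:8080/images/demo.png'
console.log(resolveSrc('data:image/png;base64,iVBORw0KGgo='));
// -> unchanged: data URLs and absolute URLs pass through
```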

+ 99 - 72
src/lib/components/documents/Settings/General.svelte

@@ -1,10 +1,10 @@
 <script lang="ts">
 	import { getDocs } from '$lib/apis/documents';
 	import {
-		getChunkParams,
+		getRAGConfig,
+		updateRAGConfig,
 		getQuerySettings,
 		scanDocs,
-		updateChunkParams,
 		updateQuerySettings
 	} from '$lib/apis/rag';
 	import { documents } from '$lib/stores';
@@ -17,6 +17,7 @@
 
 	let chunkSize = 0;
 	let chunkOverlap = 0;
+	let pdfExtractImages = true;
 
 	let querySettings = {
 		template: '',
@@ -35,16 +36,24 @@
 	};
 
 	const submitHandler = async () => {
-		const res = await updateChunkParams(localStorage.token, chunkSize, chunkOverlap);
+		const res = await updateRAGConfig(localStorage.token, {
+			pdf_extract_images: pdfExtractImages,
+			chunk: {
+				chunk_overlap: chunkOverlap,
+				chunk_size: chunkSize
+			}
+		});
 		querySettings = await updateQuerySettings(localStorage.token, querySettings);
 	};
 
 	onMount(async () => {
-		const res = await getChunkParams(localStorage.token);
+		const res = await getRAGConfig(localStorage.token);
 
 		if (res) {
-			chunkSize = res.chunk_size;
-			chunkOverlap = res.chunk_overlap;
+			pdfExtractImages = res.pdf_extract_images;
+
+			chunkSize = res.chunk.chunk_size;
+			chunkOverlap = res.chunk.chunk_overlap;
 		}
 
 		querySettings = await getQuerySettings(localStorage.token);
@@ -124,82 +133,100 @@
 
 		<hr class=" dark:border-gray-700" />
 
-		<div class=" ">
-			<div class=" text-sm font-medium">Chunk Params</div>
-
-			<div class=" flex">
-				<div class="  flex w-full justify-between">
-					<div class="self-center text-xs font-medium min-w-fit">Chunk Size</div>
-
-					<div class="self-center p-3">
-						<input
-							class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
-							type="number"
-							placeholder="Enter Chunk Size"
-							bind:value={chunkSize}
-							autocomplete="off"
-							min="0"
-						/>
+		<div class=" space-y-3">
+			<div class=" space-y-3">
+				<div class=" text-sm font-medium">Chunk Params</div>
+
+				<div class=" flex gap-2">
+					<div class="  flex w-full justify-between gap-2">
+						<div class="self-center text-xs font-medium min-w-fit">Chunk Size</div>
+
+						<div class="self-center">
+							<input
+								class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+								type="number"
+								placeholder="Enter Chunk Size"
+								bind:value={chunkSize}
+								autocomplete="off"
+								min="0"
+							/>
+						</div>
 					</div>
 					</div>
-				</div>
 
-					<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
-
-					<div class="self-center p-3">
-						<input
-							class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
-							type="number"
-							placeholder="Enter Chunk Overlap"
-							bind:value={chunkOverlap}
-							autocomplete="off"
-							min="0"
-						/>
+					<div class="flex w-full gap-2">
+						<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
+
+						<div class="self-center">
+							<input
+								class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+								type="number"
+								placeholder="Enter Chunk Overlap"
+								bind:value={chunkOverlap}
+								autocomplete="off"
+								min="0"
+							/>
+						</div>
 					</div>
 				</div>
-			</div>
-
-			<div class=" text-sm font-medium">Query Params</div>
 
-			<div class=" flex">
-				<div class="  flex w-full justify-between">
-					<div class="self-center text-xs font-medium flex-1">Top K</div>
-
-					<div class="self-center p-3">
-						<input
-							class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
-							type="number"
-							placeholder="Enter Top K"
-							bind:value={querySettings.k}
-							autocomplete="off"
-							min="0"
-						/>
+				<div>
+					<div class="flex justify-between items-center text-xs">
+						<div class=" text-xs font-medium">PDF Extract Images (OCR)</div>
+
+						<button
+							class=" text-xs font-medium text-gray-500"
+							type="button"
+							on:click={() => {
+								pdfExtractImages = !pdfExtractImages;
+							}}>{pdfExtractImages ? 'On' : 'Off'}</button
+						>
 					</div>
 					</div>
 				</div>
-				<!-- <div class="flex w-full">
-					<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
-
-					<div class="self-center p-3">
-						<input
-							class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
-							type="number"
-							placeholder="Enter Chunk Overlap"
-							bind:value={chunkOverlap}
-							autocomplete="off"
-							min="0"
-						/>
-					</div>
-				</div> -->
 			</div>
 			</div>
 
 			<div>
-				<textarea
-					bind:value={querySettings.template}
-					class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
-					rows="4"
-				/>
+				<div class=" text-sm font-medium">Query Params</div>
+
+				<div class=" flex py-2">
+					<div class="  flex w-full justify-between gap-2">
+						<div class="self-center text-xs font-medium flex-1">Top K</div>
+
+						<div class="self-center">
+							<input
+								class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+								type="number"
+								placeholder="Enter Top K"
+								bind:value={querySettings.k}
+								autocomplete="off"
+								min="0"
+							/>
+						</div>
+					</div>
+
+					<!-- <div class="flex w-full">
+						<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
+	
+						<div class="self-center p-3">
+							<input
+								class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+								type="number"
+								placeholder="Enter Chunk Overlap"
+								bind:value={chunkOverlap}
+								autocomplete="off"
+								min="0"
+							/>
+						</div>
+					</div> -->
+				</div>
+
+				<div>
+					<div class=" mb-2.5 text-sm font-medium">RAG Template</div>
+					<textarea
+						bind:value={querySettings.template}
+						class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
+						rows="4"
+					/>
+				</div>
 			</div>
 		</div>
 	</div>
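The settings form now round-trips a single RAG config object (`pdf_extract_images` plus nested `chunk` params) through `getRAGConfig`/`updateRAGConfig` from `$lib/apis/rag`, replacing the separate chunk-param calls. A minimal sketch of what `updateRAGConfig` could look like; the base URL and route are assumptions, and only the payload shape is taken from the diff above:

```ts
// Sketch only: the real helper lives in src/lib/apis/rag/index.ts (changed
// elsewhere in this commit). RAG_API_BASE_URL and '/config/update' are assumed.
const RAG_API_BASE_URL = 'http://localhost:8080/rag'; // illustrative

export const updateRAGConfig = async (token: string, payload: object) => {
	const res = await fetch(`${RAG_API_BASE_URL}/config/update`, {
		method: 'POST',
		headers: {
			'Content-Type': 'application/json',
			Authorization: `Bearer ${token}`
		},
		body: JSON.stringify(payload)
	});
	if (!res.ok) throw await res.json();
	return res.json();
};

// Payload shape used by submitHandler above:
// { pdf_extract_images: boolean, chunk: { chunk_size: number, chunk_overlap: number } }
```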

+ 12 - 7
src/lib/components/layout/Sidebar.svelte

@@ -61,12 +61,16 @@
 	};
 
 	const editChatTitle = async (id, _title) => {
-		title = _title;
-
-		await updateChatById(localStorage.token, id, {
-			title: _title
-		});
-		await chats.set(await getChatList(localStorage.token));
+		if (_title === '') {
+			toast.error('Title cannot be an empty string.');
+		} else {
+			title = _title;
+
+			await updateChatById(localStorage.token, id, {
+				title: _title
+			});
+			await chats.set(await getChatList(localStorage.token));
+		}
 	};
 
 	const deleteChat = async (id) => {
@@ -388,12 +392,13 @@
 										show = false;
 									}
 								}}
+								draggable="false"
 							>
 								<div class=" flex self-center flex-1 w-full">
 									<div
 										class=" text-left self-center overflow-hidden {chat.id === $chatId
 											? 'w-[160px]'
-											: 'w-full'} "
+											: 'w-full'}  h-[20px]"
 									>
 										{chat.title}
 									</div>

+ 45 - 59
src/routes/(app)/+page.svelte

@@ -232,53 +232,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -342,15 +295,25 @@
 			...messages
 		]
 			.filter((message) => message)
-			.map((message, idx, arr) => ({
-				role: message.role,
-				content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content,
-				...(message.files && {
-					images: message.files
-						.filter((file) => file.type === 'image')
-						.map((file) => file.url.slice(file.url.indexOf(',') + 1))
-				})
-			}));
+			.map((message, idx, arr) => {
+				// Prepare the base message object
+				const baseMessage = {
+					role: message.role,
+					content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
+				};
+
+				// Extract and format image URLs if any exist
+				const imageUrls = message.files
+					?.filter((file) => file.type === 'image')
+					.map((file) => file.url.slice(file.url.indexOf(',') + 1));
+
+				// Add images array only if it contains elements
+				if (imageUrls && imageUrls.length > 0) {
+					baseMessage.images = imageUrls;
+				}
+
+				return baseMessage;
+			});
 
 		let lastImageIndex = -1;
 
@@ -368,6 +331,13 @@
 			}
 		});
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
 			model: model,
 			messages: messagesBody,
@@ -375,7 +345,8 @@
 				...($settings.options ?? {})
 			},
 			format: $settings.requestFormat ?? undefined,
-			keep_alive: $settings.keepAlive ?? undefined
+			keep_alive: $settings.keepAlive ?? undefined,
+			docs: docs.length > 0 ? docs : undefined
 		});
 
 		if (res && res.ok) {
@@ -535,6 +506,15 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
+		console.log(docs);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -583,7 +563,8 @@
 				top_p: $settings?.options?.top_p ?? undefined,
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined
+				max_tokens: $settings?.options?.num_predict ?? undefined,
+				docs: docs.length > 0 ? docs : undefined
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);
@@ -694,7 +675,12 @@
 
 		if (messages.length == 2) {
 			window.history.replaceState(history.state, '', `/c/${_chatId}`);
-			await setChatTitle(_chatId, userPrompt);
+
+			if ($settings?.titleAutoGenerateModel) {
+				await generateChatTitle(_chatId, userPrompt);
+			} else {
+				await setChatTitle(_chatId, userPrompt);
+			}
 		}
 	};
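Both chat routes now collect document/collection attachments into `docs` and pass them straight to the completion request, replacing the client-side `queryDoc`/`queryCollection` retrieval deleted above; retrieval moves to the backend in this release. The repeated filtering logic, restated as a standalone helper (the type shapes are illustrative, not the app's actual types):

```ts
// Restates the inlined blocks above: pull every 'doc' or 'collection' file
// attachment out of the message list. ChatFile/ChatMessage are illustrative.
interface ChatFile {
	type: string;
	collection_name?: string;
	collection_names?: string[];
}
interface ChatMessage {
	files?: ChatFile[];
}

const collectDocs = (messages: ChatMessage[]): ChatFile[] =>
	messages
		.filter((message) => message?.files ?? null)
		.map((message) =>
			(message.files ?? []).filter((item) => item.type === 'doc' || item.type === 'collection')
		)
		.flat(1);

// Sent only when non-empty, mirroring the diff:
//   docs: docs.length > 0 ? docs : undefined
```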
 
 

+ 40 - 58
src/routes/(app)/c/[id]/+page.svelte

@@ -245,53 +245,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -355,15 +308,25 @@
 			...messages
 		]
 			.filter((message) => message)
-			.map((message, idx, arr) => ({
-				role: message.role,
-				content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content,
-				...(message.files && {
-					images: message.files
-						.filter((file) => file.type === 'image')
-						.map((file) => file.url.slice(file.url.indexOf(',') + 1))
-				})
-			}));
+			.map((message, idx, arr) => {
+				// Prepare the base message object
+				const baseMessage = {
+					role: message.role,
+					content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
+				};
+
+				// Extract and format image URLs if any exist
+				const imageUrls = message.files
+					?.filter((file) => file.type === 'image')
+					.map((file) => file.url.slice(file.url.indexOf(',') + 1));
+
+				// Add images array only if it contains elements
+				if (imageUrls && imageUrls.length > 0) {
+					baseMessage.images = imageUrls;
+				}
+
+				return baseMessage;
+			});
 
 		let lastImageIndex = -1;
 
@@ -381,6 +344,13 @@
 			}
 		});
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
 			model: model,
 			messages: messagesBody,
@@ -388,7 +358,8 @@
 				...($settings.options ?? {})
 			},
 			format: $settings.requestFormat ?? undefined,
-			keep_alive: $settings.keepAlive ?? undefined
+			keep_alive: $settings.keepAlive ?? undefined,
+			docs: docs.length > 0 ? docs : undefined
 		});
 
 		if (res && res.ok) {
@@ -548,6 +519,15 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
+		console.log(docs);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -596,7 +576,8 @@
 				top_p: $settings?.options?.top_p ?? undefined,
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined
+				max_tokens: $settings?.options?.num_predict ?? undefined,
+				docs: docs.length > 0 ? docs : undefined
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);
@@ -710,6 +691,7 @@
 			await setChatTitle(_chatId, userPrompt);
 		}
 	};
+
 	const stopResponse = () => {
 		stopResponseFlag = true;
 		console.log('stopResponse');

+ 1 - 1
src/routes/(app)/playground/+page.svelte

@@ -267,7 +267,7 @@
 
 
 <div class="min-h-screen max-h-[100dvh] w-full flex justify-center dark:text-white">
 	<div class=" flex flex-col justify-between w-full overflow-y-auto h-[100dvh]">
-		<div class="max-w-2xl mx-auto w-full px-3 p-3 md:px-0 h-full">
+		<div class="max-w-2xl mx-auto w-full px-3 md:px-0 my-10 h-full">
 			<div class=" flex flex-col h-full">
 				<div class="flex flex-col justify-between mb-2.5 gap-1">
 					<div class="flex justify-between items-center gap-2">