Pārlūkot izejas kodu

Merge pull request #1107 from open-webui/dev

0.1.111
Timothy Jaeryang Baek 1 gadu atpakaļ
vecāks
revīzija
89634046e7
37 mainītis faili ar 1554 papildinājumiem un 765 dzēšanām
  1. 5 5
      .github/workflows/build-release.yml
  2. 18 0
      CHANGELOG.md
  3. 2 2
      Dockerfile
  4. 10 2
      README.md
  5. 190 53
      backend/apps/images/main.py
  6. 41 0
      backend/apps/litellm/main.py
  7. 17 3
      backend/apps/ollama/main.py
  8. 22 4
      backend/apps/openai/main.py
  9. 38 88
      backend/apps/rag/main.py
  10. 97 0
      backend/apps/rag/utils.py
  11. 6 1
      backend/config.py
  12. 1 0
      backend/data/config.json
  13. 175 38
      backend/main.py
  14. 2 1
      backend/requirements.txt
  15. 1 1
      package.json
  16. 83 8
      src/lib/apis/images/index.ts
  17. 62 0
      src/lib/apis/index.ts
  18. 3 1
      src/lib/apis/litellm/index.ts
  19. 16 7
      src/lib/apis/rag/index.ts
  20. 113 0
      src/lib/components/admin/Settings/Users.svelte
  21. 3 3
      src/lib/components/chat/MessageInput.svelte
  22. 14 11
      src/lib/components/chat/Messages/ResponseMessage.svelte
  23. 3 3
      src/lib/components/chat/Messages/UserMessage.svelte
  24. 1 1
      src/lib/components/chat/Settings/Account.svelte
  25. 1 1
      src/lib/components/chat/Settings/Audio.svelte
  26. 1 1
      src/lib/components/chat/Settings/Connections.svelte
  27. 2 2
      src/lib/components/chat/Settings/General.svelte
  28. 135 76
      src/lib/components/chat/Settings/Images.svelte
  29. 5 27
      src/lib/components/chat/Settings/Interface.svelte
  30. 277 221
      src/lib/components/chat/Settings/Models.svelte
  31. 6 6
      src/lib/components/chat/SettingsModal.svelte
  32. 7 2
      src/lib/components/common/Image.svelte
  33. 99 72
      src/lib/components/documents/Settings/General.svelte
  34. 12 7
      src/lib/components/layout/Sidebar.svelte
  35. 45 59
      src/routes/(app)/+page.svelte
  36. 40 58
      src/routes/(app)/c/[id]/+page.svelte
  37. 1 1
      src/routes/(app)/playground/+page.svelte

+ 5 - 5
.github/workflows/build-release.yml

@@ -29,11 +29,11 @@ jobs:
     - name: Extract latest CHANGELOG entry
       id: changelog
       run: |
-        CHANGELOG_CONTENT=$(awk '/^## \[/{n++} n==1' CHANGELOG.md)
-        echo "CHANGELOG_CONTENT<<EOF" 
-        echo "$CHANGELOG_CONTENT"
-        echo "EOF" 
-        echo "::set-output name=content::${CHANGELOG_CONTENT}"
+        CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md)
+        CHANGELOG_ESCAPED=$(echo "$CHANGELOG_CONTENT" | sed ':a;N;$!ba;s/\n/%0A/g')
+        echo "Extracted latest release notes from CHANGELOG.md:" 
+        echo -e "$CHANGELOG_CONTENT" 
+        echo "::set-output name=content::$CHANGELOG_ESCAPED"
 
     - name: Create GitHub release
       uses: actions/github-script@v5

+ 18 - 0
CHANGELOG.md

@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.1.111] - 2024-03-10
+
+### Added
+
+- 🛡️ **Model Whitelisting**: Admins now have the ability to whitelist models for users with the 'user' role.
+- 🔄 **Update All Models**: Added a convenient button to update all models at once.
+- 📄 **Toggle PDF OCR**: Users can now toggle PDF OCR option for improved parsing performance.
+- 🎨 **DALL-E Integration**: Introduced DALL-E integration for image generation alongside automatic1111.
+- 🛠️ **RAG API Refactoring**: Refactored RAG logic and exposed its API, with additional documentation to follow.
+
+### Fixed
+
+- 🔒 **Max Token Settings**: Added max token settings for anthropic/claude-3-sonnet-20240229 (Issue #1094).
+- 🔧 **Misalignment Issue**: Corrected misalignment of Edit and Delete Icons when Chat Title is Empty (Issue #1104).
+- 🔄 **Context Loss Fix**: Resolved RAG losing context on model response regeneration with Groq models via API key (Issue #1105).
+- 📁 **File Handling Bug**: Addressed File Not Found Notification when Dropping a Conversation Element (Issue #1098).
+- 🖱️ **Dragged File Styling**: Fixed dragged file layover styling issue.
+
 ## [0.1.110] - 2024-03-06
 
 ### Added

+ 2 - 2
Dockerfile

@@ -41,7 +41,7 @@ ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
 # for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
 # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
 ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
-# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
+# device type for whisper tts and embbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
 ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
 ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
 ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
@@ -81,4 +81,4 @@ COPY --from=build /app/package.json /app/package.json
 # copy backend files
 COPY ./backend .
 
-CMD [ "bash", "start.sh"]
+CMD [ "bash", "start.sh"]

+ 10 - 2
README.md

@@ -53,8 +53,6 @@ User-friendly WebUI for LLMs, supported LLM runners include Ollama and OpenAI-co
 
 - 💬 **Collaborative Chat**: Harness the collective intelligence of multiple models by seamlessly orchestrating group conversations. Use the `@` command to specify the model, enabling dynamic and diverse dialogues within your chat interface. Immerse yourself in the collective intelligence woven into your chat environment.
 
-- 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
-
 - 🔄 **Regeneration History Access**: Easily revisit and explore your entire regeneration history.
 
 - 📜 **Chat History**: Effortlessly access and manage your conversation history.
@@ -65,8 +63,18 @@ User-friendly WebUI for LLMs, supported LLM runners include Ollama and OpenAI-co
 
 - ⚙️ **Fine-Tuned Control with Advanced Parameters**: Gain a deeper level of control by adjusting parameters such as temperature and defining your system prompts to tailor the conversation to your specific preferences and needs.
 
+- 🎨🤖 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using AUTOMATIC1111 API (local) and DALL-E, enriching your chat experience with dynamic visual content.
+
+- 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
+
+- ✨ **Multiple OpenAI-Compatible API Support**: Seamlessly integrate and customize various OpenAI-compatible APIs, enhancing the versatility of your chat interactions.
+
 - 🔗 **External Ollama Server Connection**: Seamlessly link to an external Ollama server hosted on a different address by configuring the environment variable.
 
+- 🔀 **Multiple Ollama Instance Load Balancing**: Effortlessly distribute chat requests across multiple Ollama instances for enhanced performance and reliability.
+
+- 👥 **Multi-User Management**: Easily oversee and administer users via our intuitive admin panel, streamlining user management processes.
+
 - 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
 
 - 🔒 **Backend Reverse Proxy Support**: Bolster security through direct communication between Open WebUI backend and Ollama. This key feature eliminates the need to expose Ollama over LAN. Requests made to the '/ollama/api' route from the web UI are seamlessly redirected to Ollama from the backend, enhancing overall system security.

+ 190 - 53
backend/apps/images/main.py

@@ -21,7 +21,16 @@ from utils.utils import (
 from utils.misc import calculate_sha256
 from typing import Optional
 from pydantic import BaseModel
-from config import AUTOMATIC1111_BASE_URL
+from pathlib import Path
+import uuid
+import base64
+import json
+
+from config import CACHE_DIR, AUTOMATIC1111_BASE_URL
+
+
+IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
+IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
 app = FastAPI()
 app.add_middleware(
@@ -32,25 +41,34 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+app.state.ENGINE = ""
+app.state.ENABLED = False
+
+app.state.OPENAI_API_KEY = ""
+app.state.MODEL = ""
+
+
 app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
-app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
+
 app.state.IMAGE_SIZE = "512x512"
 app.state.IMAGE_STEPS = 50
 
 
-@app.get("/enabled", response_model=bool)
-async def get_enable_status(request: Request, user=Depends(get_admin_user)):
-    return app.state.ENABLED
+@app.get("/config")
+async def get_config(request: Request, user=Depends(get_admin_user)):
+    return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
 
 
-@app.get("/enabled/toggle", response_model=bool)
-async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
-    try:
-        r = requests.head(app.state.AUTOMATIC1111_BASE_URL)
-        app.state.ENABLED = not app.state.ENABLED
-        return app.state.ENABLED
-    except Exception as e:
-        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
+class ConfigUpdateForm(BaseModel):
+    engine: str
+    enabled: bool
+
+
+@app.post("/config/update")
+async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
+    app.state.ENGINE = form_data.engine
+    app.state.ENABLED = form_data.enabled
+    return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
 
 
 class UrlUpdateForm(BaseModel):
@@ -58,17 +76,24 @@ class UrlUpdateForm(BaseModel):
 
 
 @app.get("/url")
-async def get_openai_url(user=Depends(get_admin_user)):
+async def get_automatic1111_url(user=Depends(get_admin_user)):
     return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
 
 
 @app.post("/url/update")
-async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
+async def update_automatic1111_url(
+    form_data: UrlUpdateForm, user=Depends(get_admin_user)
+):
 
     if form_data.url == "":
         app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
     else:
-        app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
+        url = form_data.url.strip("/")
+        try:
+            r = requests.head(url)
+            app.state.AUTOMATIC1111_BASE_URL = url
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 
     return {
         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
@@ -76,6 +101,30 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_use
     }
 
 
+class OpenAIKeyUpdateForm(BaseModel):
+    key: str
+
+
+@app.get("/key")
+async def get_openai_key(user=Depends(get_admin_user)):
+    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
+
+
+@app.post("/key/update")
+async def update_openai_key(
+    form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
+):
+
+    if form_data.key == "":
+        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+
+    app.state.OPENAI_API_KEY = form_data.key
+    return {
+        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
+        "status": True,
+    }
+
+
 class ImageSizeUpdateForm(BaseModel):
     size: str
 
@@ -132,9 +181,22 @@ async def update_image_size(
 @app.get("/models")
 def get_models(user=Depends(get_current_user)):
     try:
-        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
-        models = r.json()
-        return models
+        if app.state.ENGINE == "openai":
+            return [
+                {"id": "dall-e-2", "name": "DALL·E 2"},
+                {"id": "dall-e-3", "name": "DALL·E 3"},
+            ]
+        else:
+            r = requests.get(
+                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
+            )
+            models = r.json()
+            return list(
+                map(
+                    lambda model: {"id": model["title"], "name": model["model_name"]},
+                    models,
+                )
+            )
     except Exception as e:
         app.state.ENABLED = False
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@@ -143,10 +205,12 @@ def get_models(user=Depends(get_current_user)):
 @app.get("/models/default")
 async def get_default_model(user=Depends(get_admin_user)):
     try:
-        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
-        options = r.json()
-
-        return {"model": options["sd_model_checkpoint"]}
+        if app.state.ENGINE == "openai":
+            return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
+        else:
+            r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
+            options = r.json()
+            return {"model": options["sd_model_checkpoint"]}
     except Exception as e:
         app.state.ENABLED = False
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@@ -157,16 +221,21 @@ class UpdateModelForm(BaseModel):
 
 
 def set_model_handler(model: str):
-    r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
-    options = r.json()
 
-    if model != options["sd_model_checkpoint"]:
-        options["sd_model_checkpoint"] = model
-        r = requests.post(
-            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
-        )
+    if app.state.ENGINE == "openai":
+        app.state.MODEL = model
+        return app.state.MODEL
+    else:
+        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
+        options = r.json()
+
+        if model != options["sd_model_checkpoint"]:
+            options["sd_model_checkpoint"] = model
+            r = requests.post(
+                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
+            )
 
-    return options
+        return options
 
 
 @app.post("/models/default/update")
@@ -181,45 +250,113 @@ class GenerateImageForm(BaseModel):
     model: Optional[str] = None
     prompt: str
     n: int = 1
-    size: str = "512x512"
+    size: Optional[str] = None
     negative_prompt: Optional[str] = None
 
 
+def save_b64_image(b64_str):
+    image_id = str(uuid.uuid4())
+    file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
+
+    try:
+        # Split the base64 string to get the actual image data
+        img_data = base64.b64decode(b64_str)
+
+        # Write the image data to a file
+        with open(file_path, "wb") as f:
+            f.write(img_data)
+
+        return image_id
+    except Exception as e:
+        print(f"Error saving image: {e}")
+        return None
+
+
 @app.post("/generations")
 def generate_image(
     form_data: GenerateImageForm,
     user=Depends(get_current_user),
 ):
 
-    print(form_data)
-
+    r = None
     try:
-        if form_data.model:
-            set_model_handler(form_data.model)
+        if app.state.ENGINE == "openai":
 
-        width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+            headers = {}
+            headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
+            headers["Content-Type"] = "application/json"
 
-        data = {
-            "prompt": form_data.prompt,
-            "batch_size": form_data.n,
-            "width": width,
-            "height": height,
-        }
+            data = {
+                "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
+                "prompt": form_data.prompt,
+                "n": form_data.n,
+                "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
+                "response_format": "b64_json",
+            }
+            r = requests.post(
+                url=f"https://api.openai.com/v1/images/generations",
+                json=data,
+                headers=headers,
+            )
 
-        if app.state.IMAGE_STEPS != None:
-            data["steps"] = app.state.IMAGE_STEPS
+            r.raise_for_status()
 
-        if form_data.negative_prompt != None:
-            data["negative_prompt"] = form_data.negative_prompt
+            res = r.json()
 
-        print(data)
+            images = []
 
-        r = requests.post(
-            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
-            json=data,
-        )
+            for image in res["data"]:
+                image_id = save_b64_image(image["b64_json"])
+                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
+
+                with open(file_body_path, "w") as f:
+                    json.dump(data, f)
+
+            return images
+
+        else:
+            if form_data.model:
+                set_model_handler(form_data.model)
+
+            width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+
+            data = {
+                "prompt": form_data.prompt,
+                "batch_size": form_data.n,
+                "width": width,
+                "height": height,
+            }
+
+            if app.state.IMAGE_STEPS != None:
+                data["steps"] = app.state.IMAGE_STEPS
+
+            if form_data.negative_prompt != None:
+                data["negative_prompt"] = form_data.negative_prompt
+
+            r = requests.post(
+                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
+                json=data,
+            )
+
+            res = r.json()
+
+            print(res)
+
+            images = []
+
+            for image in res["images"]:
+                image_id = save_b64_image(image)
+                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
+
+                with open(file_body_path, "w") as f:
+                    json.dump({**data, "info": res["info"]}, f)
+
+            return images
 
-        return r.json()
     except Exception as e:
         print(e)
+        if r:
+            print(r.json())
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

+ 41 - 0
backend/apps/litellm/main.py

@@ -0,0 +1,41 @@
+from litellm.proxy.proxy_server import ProxyConfig, initialize
+from litellm.proxy.proxy_server import app
+
+from fastapi import FastAPI, Request, Depends, status
+from fastapi.responses import JSONResponse
+from utils.utils import get_http_authorization_cred, get_current_user
+from config import ENV
+
+proxy_config = ProxyConfig()
+
+
+async def config():
+    router, model_list, general_settings = await proxy_config.load_config(
+        router=None, config_file_path="./data/litellm/config.yaml"
+    )
+
+    await initialize(config="./data/litellm/config.yaml", telemetry=False)
+
+
+async def startup():
+    await config()
+
+
+@app.on_event("startup")
+async def on_startup():
+    await startup()
+
+
+@app.middleware("http")
+async def auth_middleware(request: Request, call_next):
+    auth_header = request.headers.get("Authorization", "")
+
+    if ENV != "dev":
+        try:
+            user = get_current_user(get_http_authorization_cred(auth_header))
+            print(user)
+        except Exception as e:
+            return JSONResponse(status_code=400, content={"detail": str(e)})
+
+    response = await call_next(request)
+    return response

+ 17 - 3
backend/apps/ollama/main.py

@@ -15,7 +15,7 @@ import asyncio
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user, get_admin_user
-from config import OLLAMA_BASE_URLS
+from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
 
 from typing import Optional, List, Union
 
@@ -29,6 +29,10 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
+app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
+app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+
 app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}
 
@@ -129,9 +133,19 @@ async def get_all_models():
 async def get_ollama_tags(
     url_idx: Optional[int] = None, user=Depends(get_current_user)
 ):
-
     if url_idx == None:
-        return await get_all_models()
+        models = await get_all_models()
+
+        if app.state.MODEL_FILTER_ENABLED:
+            if user.role == "user":
+                models["models"] = list(
+                    filter(
+                        lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
+                        models["models"],
+                    )
+                )
+                return models
+        return models
     else:
         url = app.state.OLLAMA_BASE_URLS[url_idx]
         try:

+ 22 - 4
backend/apps/openai/main.py

@@ -18,7 +18,13 @@ from utils.utils import (
     get_verified_user,
     get_admin_user,
 )
-from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR
+from config import (
+    OPENAI_API_BASE_URLS,
+    OPENAI_API_KEYS,
+    CACHE_DIR,
+    MODEL_FILTER_ENABLED,
+    MODEL_FILTER_LIST,
+)
 from typing import List, Optional
 
 
@@ -34,6 +40,9 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
+app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+
 app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
 
@@ -186,12 +195,21 @@ async def get_all_models():
     return models
 
 
-# , user=Depends(get_current_user)
 @app.get("/models")
 @app.get("/models/{url_idx}")
-async def get_models(url_idx: Optional[int] = None):
+async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
     if url_idx == None:
-        return await get_all_models()
+        models = await get_all_models()
+        if app.state.MODEL_FILTER_ENABLED:
+            if user.role == "user":
+                models["data"] = list(
+                    filter(
+                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
+                        models["data"],
+                    )
+                )
+                return models
+        return models
     else:
         url = app.state.OPENAI_API_BASE_URLS[url_idx]
         try:

+ 38 - 88
backend/apps/rag/main.py

@@ -44,6 +44,8 @@ from apps.web.models.documents import (
     DocumentResponse,
 )
 
+from apps.rag.utils import query_doc, query_collection
+
 from utils.misc import (
     calculate_sha256,
     calculate_sha256_string,
@@ -75,6 +77,7 @@ from constants import ERROR_MESSAGES
 
 app = FastAPI()
 
+app.state.PDF_EXTRACT_IMAGES = False
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
@@ -182,12 +185,15 @@ async def update_embedding_model(
     }
 
 
-@app.get("/chunk")
-async def get_chunk_params(user=Depends(get_admin_user)):
+@app.get("/config")
+async def get_rag_config(user=Depends(get_admin_user)):
     return {
         "status": True,
-        "chunk_size": app.state.CHUNK_SIZE,
-        "chunk_overlap": app.state.CHUNK_OVERLAP,
+        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
+        "chunk": {
+            "chunk_size": app.state.CHUNK_SIZE,
+            "chunk_overlap": app.state.CHUNK_OVERLAP,
+        },
     }
 
 
@@ -196,17 +202,24 @@ class ChunkParamUpdateForm(BaseModel):
     chunk_overlap: int
 
 
-@app.post("/chunk/update")
-async def update_chunk_params(
-    form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
-):
-    app.state.CHUNK_SIZE = form_data.chunk_size
-    app.state.CHUNK_OVERLAP = form_data.chunk_overlap
+class ConfigUpdateForm(BaseModel):
+    pdf_extract_images: bool
+    chunk: ChunkParamUpdateForm
+
+
+@app.post("/config/update")
+async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
+    app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
+    app.state.CHUNK_SIZE = form_data.chunk.chunk_size
+    app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
 
     return {
         "status": True,
-        "chunk_size": app.state.CHUNK_SIZE,
-        "chunk_overlap": app.state.CHUNK_OVERLAP,
+        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
+        "chunk": {
+            "chunk_size": app.state.CHUNK_SIZE,
+            "chunk_overlap": app.state.CHUNK_OVERLAP,
+        },
     }
 
 
@@ -248,21 +261,18 @@ class QueryDocForm(BaseModel):
 
 
 @app.post("/query/doc")
-def query_doc(
+def query_doc_handler(
     form_data: QueryDocForm,
     user=Depends(get_current_user),
 ):
+
     try:
-        # if you use docker use the model from the environment variable
-        collection = CHROMA_CLIENT.get_collection(
-            name=form_data.collection_name,
+        return query_doc(
+            collection_name=form_data.collection_name,
+            query=form_data.query,
+            k=form_data.k if form_data.k else app.state.TOP_K,
             embedding_function=app.state.sentence_transformer_ef,
         )
-        result = collection.query(
-            query_texts=[form_data.query],
-            n_results=form_data.k if form_data.k else app.state.TOP_K,
-        )
-        return result
     except Exception as e:
         print(e)
         raise HTTPException(
@@ -277,76 +287,16 @@ class QueryCollectionsForm(BaseModel):
     k: Optional[int] = None
 
 
-def merge_and_sort_query_results(query_results, k):
-    # Initialize lists to store combined data
-    combined_ids = []
-    combined_distances = []
-    combined_metadatas = []
-    combined_documents = []
-
-    # Combine data from each dictionary
-    for data in query_results:
-        combined_ids.extend(data["ids"][0])
-        combined_distances.extend(data["distances"][0])
-        combined_metadatas.extend(data["metadatas"][0])
-        combined_documents.extend(data["documents"][0])
-
-    # Create a list of tuples (distance, id, metadata, document)
-    combined = list(
-        zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
-    )
-
-    # Sort the list based on distances
-    combined.sort(key=lambda x: x[0])
-
-    # Unzip the sorted list
-    sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
-
-    # Slicing the lists to include only k elements
-    sorted_distances = list(sorted_distances)[:k]
-    sorted_ids = list(sorted_ids)[:k]
-    sorted_metadatas = list(sorted_metadatas)[:k]
-    sorted_documents = list(sorted_documents)[:k]
-
-    # Create the output dictionary
-    merged_query_results = {
-        "ids": [sorted_ids],
-        "distances": [sorted_distances],
-        "metadatas": [sorted_metadatas],
-        "documents": [sorted_documents],
-        "embeddings": None,
-        "uris": None,
-        "data": None,
-    }
-
-    return merged_query_results
-
-
 @app.post("/query/collection")
-def query_collection(
+def query_collection_handler(
     form_data: QueryCollectionsForm,
     user=Depends(get_current_user),
 ):
-    results = []
-
-    for collection_name in form_data.collection_names:
-        try:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.get_collection(
-                name=collection_name,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
-
-            result = collection.query(
-                query_texts=[form_data.query],
-                n_results=form_data.k if form_data.k else app.state.TOP_K,
-            )
-            results.append(result)
-        except:
-            pass
-
-    return merge_and_sort_query_results(
-        results, form_data.k if form_data.k else app.state.TOP_K
+    return query_collection(
+        collection_names=form_data.collection_names,
+        query=form_data.query,
+        k=form_data.k if form_data.k else app.state.TOP_K,
+        embedding_function=app.state.sentence_transformer_ef,
     )
 
 
@@ -425,7 +375,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
     ]
 
     if file_ext == "pdf":
-        loader = PyPDFLoader(file_path, extract_images=True)
+        loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
     elif file_ext == "csv":
         loader = CSVLoader(file_path)
     elif file_ext == "rst":

+ 97 - 0
backend/apps/rag/utils.py

@@ -0,0 +1,97 @@
+import re
+from typing import List
+
+from config import CHROMA_CLIENT
+
+
+def query_doc(collection_name: str, query: str, k: int, embedding_function):
+    try:
+        # if you use docker use the model from the environment variable
+        collection = CHROMA_CLIENT.get_collection(
+            name=collection_name,
+            embedding_function=embedding_function,
+        )
+        result = collection.query(
+            query_texts=[query],
+            n_results=k,
+        )
+        return result
+    except Exception as e:
+        raise e
+
+
+def merge_and_sort_query_results(query_results, k):
+    # Initialize lists to store combined data
+    combined_ids = []
+    combined_distances = []
+    combined_metadatas = []
+    combined_documents = []
+
+    # Combine data from each dictionary
+    for data in query_results:
+        combined_ids.extend(data["ids"][0])
+        combined_distances.extend(data["distances"][0])
+        combined_metadatas.extend(data["metadatas"][0])
+        combined_documents.extend(data["documents"][0])
+
+    # Create a list of tuples (distance, id, metadata, document)
+    combined = list(
+        zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
+    )
+
+    # Sort the list based on distances
+    combined.sort(key=lambda x: x[0])
+
+    # Unzip the sorted list
+    sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
+
+    # Slicing the lists to include only k elements
+    sorted_distances = list(sorted_distances)[:k]
+    sorted_ids = list(sorted_ids)[:k]
+    sorted_metadatas = list(sorted_metadatas)[:k]
+    sorted_documents = list(sorted_documents)[:k]
+
+    # Create the output dictionary
+    merged_query_results = {
+        "ids": [sorted_ids],
+        "distances": [sorted_distances],
+        "metadatas": [sorted_metadatas],
+        "documents": [sorted_documents],
+        "embeddings": None,
+        "uris": None,
+        "data": None,
+    }
+
+    return merged_query_results
+
+
+def query_collection(
+    collection_names: List[str], query: str, k: int, embedding_function
+):
+
+    results = []
+
+    for collection_name in collection_names:
+        try:
+            # if you use docker use the model from the environment variable
+            collection = CHROMA_CLIENT.get_collection(
+                name=collection_name,
+                embedding_function=embedding_function,
+            )
+
+            result = collection.query(
+                query_texts=[query],
+                n_results=k,
+            )
+            results.append(result)
+        except:
+            pass
+
+    return merge_and_sort_query_results(results, k)
+
+
+def rag_template(template: str, context: str, query: str):
+    template = re.sub(r"\[context\]", context, template)
+    template = re.sub(r"\[query\]", query, template)
+
+    return template

+ 6 - 1
backend/config.py

@@ -251,7 +251,7 @@ OPENAI_API_BASE_URLS = (
     OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
 )
 
-OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URL.split(";")]
+OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URLS.split(";")]
 
 
 ####################################
@@ -292,6 +292,11 @@ DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
 USER_PERMISSIONS = {"chat": {"deletion": True}}
 
 
+MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False)
+MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
+MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
+
+
 ####################################
 # WEBUI_VERSION
 ####################################

+ 1 - 0
backend/data/config.json

@@ -1,4 +1,5 @@
 {
+    "version": "0.0.1",
     "ui": {
         "prompt_suggestions": [
             {

+ 175 - 38
backend/main.py

@@ -9,27 +9,37 @@ import requests
 from fastapi import FastAPI, Request, Depends, status
 from fastapi.staticfiles import StaticFiles
 from fastapi import HTTPException
-from fastapi.responses import JSONResponse
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
+from starlette.middleware.base import BaseHTTPMiddleware
 
 
-from litellm.proxy.proxy_server import ProxyConfig, initialize
-from litellm.proxy.proxy_server import app as litellm_app
-
 from apps.ollama.main import app as ollama_app
 from apps.openai.main import app as openai_app
+from apps.litellm.main import app as litellm_app, startup as litellm_app_startup
 from apps.audio.main import app as audio_app
 from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
 from apps.web.main import app as webui_app
 
+from pydantic import BaseModel
+from typing import List
 
-from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
-from constants import ERROR_MESSAGES
 
-from utils.utils import get_http_authorization_cred, get_current_user
+from utils.utils import get_admin_user
+from apps.rag.utils import query_doc, query_collection, rag_template
+
+from config import (
+    WEBUI_NAME,
+    ENV,
+    VERSION,
+    CHANGELOG,
+    FRONTEND_BUILD_DIR,
+    MODEL_FILTER_ENABLED,
+    MODEL_FILTER_LIST,
+)
+from constants import ERROR_MESSAGES
 
 
 class SPAStaticFiles(StaticFiles):
@@ -43,23 +53,11 @@ class SPAStaticFiles(StaticFiles):
                 raise ex
 
 
-proxy_config = ProxyConfig()
-
-
-async def config():
-    router, model_list, general_settings = await proxy_config.load_config(
-        router=None, config_file_path="./data/litellm/config.yaml"
-    )
-
-    await initialize(config="./data/litellm/config.yaml", telemetry=False)
-
-
-async def startup():
-    await config()
-
-
 app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
 
+app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
+app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+
 origins = ["*"]
 
 app.add_middleware(
@@ -73,7 +71,127 @@ app.add_middleware(
 
 @app.on_event("startup")
 async def on_startup():
-    await startup()
+    await litellm_app_startup()
+
+
+class RAGMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if request.method == "POST" and (
+            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
+        ):
+            print(request.url.path)
+
+            # Read the original request body
+            body = await request.body()
+            # Decode body to string
+            body_str = body.decode("utf-8")
+            # Parse string to JSON
+            data = json.loads(body_str) if body_str else {}
+
+            # Example: Add a new key-value pair or modify existing ones
+            # data["modified"] = True  # Example modification
+            if "docs" in data:
+                docs = data["docs"]
+                print(docs)
+
+                last_user_message_idx = None
+                for i in range(len(data["messages"]) - 1, -1, -1):
+                    if data["messages"][i]["role"] == "user":
+                        last_user_message_idx = i
+                        break
+
+                user_message = data["messages"][last_user_message_idx]
+
+                if isinstance(user_message["content"], list):
+                    # Handle list content input
+                    content_type = "list"
+                    query = ""
+                    for content_item in user_message["content"]:
+                        if content_item["type"] == "text":
+                            query = content_item["text"]
+                            break
+                elif isinstance(user_message["content"], str):
+                    # Handle text content input
+                    content_type = "text"
+                    query = user_message["content"]
+                else:
+                    # Fallback in case the input does not match expected types
+                    content_type = None
+                    query = ""
+
+                relevant_contexts = []
+
+                for doc in docs:
+                    context = None
+
+                    try:
+                        if doc["type"] == "collection":
+                            context = query_collection(
+                                collection_names=doc["collection_names"],
+                                query=query,
+                                k=rag_app.state.TOP_K,
+                                embedding_function=rag_app.state.sentence_transformer_ef,
+                            )
+                        else:
+                            context = query_doc(
+                                collection_name=doc["collection_name"],
+                                query=query,
+                                k=rag_app.state.TOP_K,
+                                embedding_function=rag_app.state.sentence_transformer_ef,
+                            )
+                    except Exception as e:
+                        print(e)
+                        context = None
+
+                    relevant_contexts.append(context)
+
+                context_string = ""
+                for context in relevant_contexts:
+                    if context:
+                        context_string += " ".join(context["documents"][0]) + "\n"
+
+                ra_content = rag_template(
+                    template=rag_app.state.RAG_TEMPLATE,
+                    context=context_string,
+                    query=query,
+                )
+
+                if content_type == "list":
+                    new_content = []
+                    for content_item in user_message["content"]:
+                        if content_item["type"] == "text":
+                            # Update the text item's content with ra_content
+                            new_content.append({"type": "text", "text": ra_content})
+                        else:
+                            # Keep other types of content as they are
+                            new_content.append(content_item)
+                    new_user_message = {**user_message, "content": new_content}
+                else:
+                    new_user_message = {
+                        **user_message,
+                        "content": ra_content,
+                    }
+
+                data["messages"][last_user_message_idx] = new_user_message
+                del data["docs"]
+
+                print(data["messages"])
+
+            modified_body_bytes = json.dumps(data).encode("utf-8")
+
+            # Create a new request with the modified body
+            scope = request.scope
+            scope["body"] = modified_body_bytes
+            request = Request(scope, receive=lambda: self._receive(modified_body_bytes))
+
+        response = await call_next(request)
+        return response
+
+    async def _receive(self, body: bytes):
+        return {"type": "http.request", "body": body, "more_body": False}
+
+
+app.add_middleware(RAGMiddleware)
 
 
 @app.middleware("http")
@@ -86,21 +204,6 @@ async def check_url(request: Request, call_next):
     return response
 
 
-@litellm_app.middleware("http")
-async def auth_middleware(request: Request, call_next):
-    auth_header = request.headers.get("Authorization", "")
-
-    if ENV != "dev":
-        try:
-            user = get_current_user(get_http_authorization_cred(auth_header))
-            print(user)
-        except Exception as e:
-            return JSONResponse(status_code=400, content={"detail": str(e)})
-
-    response = await call_next(request)
-    return response
-
-
 app.mount("/api/v1", webui_app)
 app.mount("/litellm/api", litellm_app)
 
@@ -125,6 +228,39 @@ async def get_app_config():
     }
 
 
+@app.get("/api/config/model/filter")
+async def get_model_filter_config(user=Depends(get_admin_user)):
+    return {
+        "enabled": app.state.MODEL_FILTER_ENABLED,
+        "models": app.state.MODEL_FILTER_LIST,
+    }
+
+
+class ModelFilterConfigForm(BaseModel):
+    enabled: bool
+    models: List[str]
+
+
+@app.post("/api/config/model/filter")
+async def get_model_filter_config(
+    form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
+):
+
+    app.state.MODEL_FILTER_ENABLED = form_data.enabled
+    app.state.MODEL_FILTER_LIST = form_data.models
+
+    ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
+    ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
+
+    openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
+    openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
+
+    return {
+        "enabled": app.state.MODEL_FILTER_ENABLED,
+        "models": app.state.MODEL_FILTER_LIST,
+    }
+
+
 @app.get("/api/version")
 async def get_app_config():
 
@@ -156,6 +292,7 @@ async def get_app_latest_release_version():
 
 
 app.mount("/static", StaticFiles(directory="static"), name="static")
+app.mount("/cache", StaticFiles(directory="data/cache"), name="cache")
 
 
 app.mount(

+ 2 - 1
backend/requirements.txt

@@ -16,7 +16,8 @@ aiohttp
 peewee
 bcrypt
 
-litellm
+litellm==1.30.7
+argon2-cffi
 apscheduler
 google-generativeai
 

+ 1 - 1
package.json

@@ -1,6 +1,6 @@
 {
 	"name": "open-webui",
-	"version": "0.1.110",
+	"version": "0.1.111",
 	"private": true,
 	"scripts": {
 		"dev": "vite dev --host",

+ 83 - 8
src/lib/apis/images/index.ts

@@ -1,9 +1,9 @@
 import { IMAGES_API_BASE_URL } from '$lib/constants';
 
-export const getImageGenerationEnabledStatus = async (token: string = '') => {
+export const getImageGenerationConfig = async (token: string = '') => {
 	let error = null;
 
-	const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, {
+	const res = await fetch(`${IMAGES_API_BASE_URL}/config`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -32,10 +32,50 @@ export const getImageGenerationEnabledStatus = async (token: string = '') => {
 	return res;
 };
 
-export const toggleImageGenerationEnabledStatus = async (token: string = '') => {
+export const updateImageGenerationConfig = async (
+	token: string = '',
+	engine: string,
+	enabled: boolean
+) => {
 	let error = null;
 
-	const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, {
+	const res = await fetch(`${IMAGES_API_BASE_URL}/config/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			engine,
+			enabled
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getOpenAIKey = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/key`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -61,7 +101,42 @@ export const toggleImageGenerationEnabledStatus = async (token: string = '') =>
 		throw error;
 	}
 
-	return res;
+	return res.OPENAI_API_KEY;
+};
+
+export const updateOpenAIKey = async (token: string = '', key: string) => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/key/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			key: key
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.OPENAI_API_KEY;
 };
 
 export const getAUTOMATIC1111Url = async (token: string = '') => {
@@ -263,7 +338,7 @@ export const updateImageSteps = async (token: string = '', steps: number) => {
 	return res.IMAGE_STEPS;
 };
 
-export const getDiffusionModels = async (token: string = '') => {
+export const getImageGenerationModels = async (token: string = '') => {
 	let error = null;
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/models`, {
@@ -295,7 +370,7 @@ export const getDiffusionModels = async (token: string = '') => {
 	return res;
 };
 
-export const getDefaultDiffusionModel = async (token: string = '') => {
+export const getDefaultImageGenerationModel = async (token: string = '') => {
 	let error = null;
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, {
@@ -327,7 +402,7 @@ export const getDefaultDiffusionModel = async (token: string = '') => {
 	return res.model;
 };
 
-export const updateDefaultDiffusionModel = async (token: string = '', model: string) => {
+export const updateDefaultImageGenerationModel = async (token: string = '', model: string) => {
 	let error = null;
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, {

+ 62 - 0
src/lib/apis/index.ts

@@ -77,3 +77,65 @@ export const getVersionUpdates = async () => {
 
 	return res;
 };
+
+export const getModelFilterConfig = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateModelFilterConfig = async (
+	token: string,
+	enabled: boolean,
+	models: string[]
+) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			enabled: enabled,
+			models: models
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 3 - 1
src/lib/apis/litellm/index.ts

@@ -77,6 +77,7 @@ type AddLiteLLMModelForm = {
 	api_base: string;
 	api_key: string;
 	rpm: string;
+	max_tokens: string;
 };
 
 export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMModelForm) => {
@@ -95,7 +96,8 @@ export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMMod
 				model: payload.model,
 				...(payload.api_base === '' ? {} : { api_base: payload.api_base }),
 				...(payload.api_key === '' ? {} : { api_key: payload.api_key }),
-				...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) })
+				...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) }),
+				...(payload.max_tokens === '' ? {} : { max_tokens: payload.max_tokens })
 			}
 		})
 	})

+ 16 - 7
src/lib/apis/rag/index.ts

@@ -1,9 +1,9 @@
 import { RAG_API_BASE_URL } from '$lib/constants';
 
-export const getChunkParams = async (token: string) => {
+export const getRAGConfig = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/chunk`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/config`, {
 		method: 'GET',
 		headers: {
 			'Content-Type': 'application/json',
@@ -27,18 +27,27 @@ export const getChunkParams = async (token: string) => {
 	return res;
 };
 
-export const updateChunkParams = async (token: string, size: number, overlap: number) => {
+type ChunkConfigForm = {
+	chunk_size: number;
+	chunk_overlap: number;
+};
+
+type RAGConfigForm = {
+	pdf_extract_images: boolean;
+	chunk: ChunkConfigForm;
+};
+
+export const updateRAGConfig = async (token: string, payload: RAGConfigForm) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/chunk/update`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/config/update`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
-			chunk_size: size,
-			chunk_overlap: overlap
+			...payload
 		})
 	})
 		.then(async (res) => {
@@ -252,7 +261,7 @@ export const queryCollection = async (
 	token: string,
 	collection_names: string,
 	query: string,
-	k: number
+	k: number | null = null
 ) => {
 	let error = null;
 

+ 113 - 0
src/lib/components/admin/Settings/Users.svelte

@@ -1,10 +1,14 @@
 <script lang="ts">
+	import { getModelFilterConfig, updateModelFilterConfig } from '$lib/apis';
 	import { getSignUpEnabledStatus, toggleSignUpEnabledStatus } from '$lib/apis/auths';
 	import { getUserPermissions, updateUserPermissions } from '$lib/apis/users';
+	import { models } from '$lib/stores';
 	import { onMount } from 'svelte';
 
 	export let saveHandler: Function;
 
+	let whitelistEnabled = false;
+	let whitelistModels = [''];
 	let permissions = {
 		chat: {
 			deletion: true
@@ -13,6 +17,13 @@
 
 	onMount(async () => {
 		permissions = await getUserPermissions(localStorage.token);
+
+		const res = await getModelFilterConfig(localStorage.token);
+		if (res) {
+			whitelistEnabled = res.enabled;
+
+			whitelistModels = res.models.length > 0 ? res.models : [''];
+		}
 	});
 </script>
 
@@ -21,6 +32,8 @@
 	on:submit|preventDefault={async () => {
 		// console.log('submit');
 		await updateUserPermissions(localStorage.token, permissions);
+
+		await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels);
 		saveHandler();
 	}}
 >
@@ -69,6 +82,106 @@
 				</button>
 			</div>
 		</div>
+
+		<hr class=" dark:border-gray-700 my-2" />
+
+		<div class="mt-2 space-y-3 pr-1.5">
+			<div>
+				<div class="mb-2">
+					<div class="flex justify-between items-center text-xs">
+						<div class=" text-sm font-medium">Manage Models</div>
+					</div>
+				</div>
+
+				<div class=" space-y-3">
+					<div>
+						<div class="flex justify-between items-center text-xs">
+							<div class=" text-xs font-medium">Model Whitelisting</div>
+
+							<button
+								class=" text-xs font-medium text-gray-500"
+								type="button"
+								on:click={() => {
+									whitelistEnabled = !whitelistEnabled;
+								}}>{whitelistEnabled ? 'On' : 'Off'}</button
+							>
+						</div>
+					</div>
+
+					{#if whitelistEnabled}
+						<div>
+							<div class=" space-y-1.5">
+								{#each whitelistModels as modelId, modelIdx}
+									<div class="flex w-full">
+										<div class="flex-1 mr-2">
+											<select
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												bind:value={modelId}
+												placeholder="Select a model"
+											>
+												<option value="" disabled selected>Select a model</option>
+												{#each $models.filter((model) => model.id) as model}
+													<option value={model.id} class="bg-gray-100 dark:bg-gray-700"
+														>{model.name}</option
+													>
+												{/each}
+											</select>
+										</div>
+
+										{#if modelIdx === 0}
+											<button
+												class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition"
+												type="button"
+												on:click={() => {
+													if (whitelistModels.at(-1) !== '') {
+														whitelistModels = [...whitelistModels, ''];
+													}
+												}}
+											>
+												<svg
+													xmlns="http://www.w3.org/2000/svg"
+													viewBox="0 0 16 16"
+													fill="currentColor"
+													class="w-4 h-4"
+												>
+													<path
+														d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
+													/>
+												</svg>
+											</button>
+										{:else}
+											<button
+												class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition"
+												type="button"
+												on:click={() => {
+													whitelistModels.splice(modelIdx, 1);
+													whitelistModels = whitelistModels;
+												}}
+											>
+												<svg
+													xmlns="http://www.w3.org/2000/svg"
+													viewBox="0 0 16 16"
+													fill="currentColor"
+													class="w-4 h-4"
+												>
+													<path d="M3.75 7.25a.75.75 0 0 0 0 1.5h8.5a.75.75 0 0 0 0-1.5h-8.5Z" />
+												</svg>
+											</button>
+										{/if}
+									</div>
+								{/each}
+							</div>
+
+							<div class="flex justify-end items-center text-xs mt-1.5 text-right">
+								<div class=" text-xs font-medium">
+									{whitelistModels.length} Model(s) Whitelisted
+								</div>
+							</div>
+						</div>
+					{/if}
+				</div>
+			</div>
+		</div>
 	</div>
 
 	<div class="flex justify-end pt-3 text-sm font-medium">

+ 3 - 3
src/lib/components/chat/MessageInput.svelte

@@ -19,7 +19,7 @@
 
 	export let suggestionPrompts = [];
 	export let autoScroll = true;
-	let chatTextAreaElement:HTMLTextAreaElement
+	let chatTextAreaElement: HTMLTextAreaElement;
 	let filesInputElement;
 
 	let promptsElement;
@@ -359,12 +359,12 @@
 
 {#if dragged}
 	<div
-		class="fixed w-full h-full flex z-50 touch-none pointer-events-none"
+		class="fixed lg:w-[calc(100%-260px)] w-full h-full flex z-50 touch-none pointer-events-none"
 		id="dropzone"
 		role="region"
 		aria-label="Drag and Drop Container"
 	>
-		<div class="absolute rounded-xl w-full h-full backdrop-blur bg-gray-800/40 flex justify-center">
+		<div class="absolute w-full h-full backdrop-blur bg-gray-800/40 flex justify-center">
 			<div class="m-auto pt-64 flex flex-col justify-center">
 				<div class="max-w-md">
 					<AddFilesPlaceholder />

+ 14 - 11
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -6,6 +6,7 @@
 	import auto_render from 'katex/dist/contrib/auto-render.mjs';
 	import 'katex/dist/katex.min.css';
 
+	import { fade } from 'svelte/transition';
 	import { createEventDispatcher } from 'svelte';
 	import { onMount, tick } from 'svelte';
 
@@ -276,13 +277,15 @@
 
 	const generateImage = async (message) => {
 		generatingImage = true;
-		const res = await imageGenerations(localStorage.token, message.content);
+		const res = await imageGenerations(localStorage.token, message.content).catch((error) => {
+			toast.error(error);
+		});
 		console.log(res);
 
 		if (res) {
-			message.files = res.images.map((image) => ({
+			message.files = res.map((image) => ({
 				type: 'image',
-				url: `data:image/png;base64,${image}`
+				url: `${image.url}`
 			}));
 
 			dispatch('save', message);
@@ -477,7 +480,7 @@
 													xmlns="http://www.w3.org/2000/svg"
 													fill="none"
 													viewBox="0 0 24 24"
-													stroke-width="1.5"
+													stroke-width="2"
 													stroke="currentColor"
 													class="w-4 h-4"
 												>
@@ -503,7 +506,7 @@
 													xmlns="http://www.w3.org/2000/svg"
 													fill="none"
 													viewBox="0 0 24 24"
-													stroke-width="1.5"
+													stroke-width="2"
 													stroke="currentColor"
 													class="w-4 h-4"
 												>
@@ -622,7 +625,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -637,7 +640,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -703,7 +706,7 @@
 															xmlns="http://www.w3.org/2000/svg"
 															fill="none"
 															viewBox="0 0 24 24"
-															stroke-width="1.5"
+															stroke-width="2"
 															stroke="currentColor"
 															class="w-4 h-4"
 														>
@@ -733,7 +736,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -762,7 +765,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -792,7 +795,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>

+ 3 - 3
src/lib/components/chat/Messages/UserMessage.svelte

@@ -258,7 +258,7 @@
 									xmlns="http://www.w3.org/2000/svg"
 									fill="none"
 									viewBox="0 0 24 24"
-									stroke-width="1.5"
+									stroke-width="2"
 									stroke="currentColor"
 									class="w-4 h-4"
 								>
@@ -282,7 +282,7 @@
 									xmlns="http://www.w3.org/2000/svg"
 									fill="none"
 									viewBox="0 0 24 24"
-									stroke-width="1.5"
+									stroke-width="2"
 									stroke="currentColor"
 									class="w-4 h-4"
 								>
@@ -307,7 +307,7 @@
 										xmlns="http://www.w3.org/2000/svg"
 										fill="none"
 										viewBox="0 0 24 24"
-										stroke-width="1.5"
+										stroke-width="2"
 										stroke="currentColor"
 										class="w-4 h-4"
 									>

+ 1 - 1
src/lib/components/chat/Settings/Account.svelte

@@ -271,7 +271,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
+			class="  px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
 			on:click={async () => {
 				const res = await submitHandler();
 

+ 1 - 1
src/lib/components/chat/Settings/Audio.svelte

@@ -251,7 +251,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
+			class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
 			type="submit"
 		>
 			Save

+ 1 - 1
src/lib/components/chat/Settings/Connections.svelte

@@ -247,7 +247,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
+			class="  px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
 			type="submit"
 		>
 			Save

+ 2 - 2
src/lib/components/chat/Settings/General.svelte

@@ -176,7 +176,7 @@
 			<div class=" my-2.5 text-sm font-medium">System Prompt</div>
 			<textarea
 				bind:value={system}
-				class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
+				class="w-full rounded-lg p-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
 				rows="4"
 			/>
 		</div>
@@ -262,7 +262,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
+			class="  px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
 			on:click={() => {
 				saveSettings({
 					system: system !== '' ? system : undefined,

+ 135 - 76
src/lib/components/chat/Settings/Images.svelte

@@ -5,16 +5,18 @@
 	import { config, user } from '$lib/stores';
 	import {
 		getAUTOMATIC1111Url,
-		getDefaultDiffusionModel,
-		getDiffusionModels,
-		getImageGenerationEnabledStatus,
+		getImageGenerationModels,
+		getDefaultImageGenerationModel,
+		updateDefaultImageGenerationModel,
 		getImageSize,
-		toggleImageGenerationEnabledStatus,
+		getImageGenerationConfig,
+		updateImageGenerationConfig,
 		updateAUTOMATIC1111Url,
-		updateDefaultDiffusionModel,
 		updateImageSize,
 		getImageSteps,
-		updateImageSteps
+		updateImageSteps,
+		getOpenAIKey,
+		updateOpenAIKey
 	} from '$lib/apis/images';
 	import { getBackendConfig } from '$lib/apis';
 	const dispatch = createEventDispatcher();
@@ -23,8 +25,11 @@
 
 	let loading = false;
 
+	let imageGenerationEngine = '';
 	let enableImageGeneration = false;
+
 	let AUTOMATIC1111_BASE_URL = '';
+	let OPENAI_API_KEY = '';
 
 	let selectedModel = '';
 	let models = null;
@@ -33,11 +38,11 @@
 	let steps = 50;
 
 	const getModels = async () => {
-		models = await getDiffusionModels(localStorage.token).catch((error) => {
+		models = await getImageGenerationModels(localStorage.token).catch((error) => {
 			toast.error(error);
 			return null;
 		});
-		selectedModel = await getDefaultDiffusionModel(localStorage.token).catch((error) => {
+		selectedModel = await getDefaultImageGenerationModel(localStorage.token).catch((error) => {
 			return '';
 		});
 	};
@@ -62,33 +67,45 @@
 			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
 		}
 	};
-	const toggleImageGeneration = async () => {
-		if (AUTOMATIC1111_BASE_URL) {
-			enableImageGeneration = await toggleImageGenerationEnabledStatus(localStorage.token).catch(
-				(error) => {
-					toast.error(error);
-					return false;
-				}
-			);
+	const updateImageGeneration = async () => {
+		const res = await updateImageGenerationConfig(
+			localStorage.token,
+			imageGenerationEngine,
+			enableImageGeneration
+		).catch((error) => {
+			toast.error(error);
+			return null;
+		});
 
-			if (enableImageGeneration) {
-				config.set(await getBackendConfig(localStorage.token));
-				getModels();
-			}
-		} else {
-			enableImageGeneration = false;
-			toast.error('AUTOMATIC1111_BASE_URL not provided');
+		if (res) {
+			imageGenerationEngine = res.engine;
+			enableImageGeneration = res.enabled;
+		}
+
+		if (enableImageGeneration) {
+			config.set(await getBackendConfig(localStorage.token));
+			getModels();
 		}
 	};
 
 	onMount(async () => {
 		if ($user.role === 'admin') {
-			enableImageGeneration = await getImageGenerationEnabledStatus(localStorage.token);
+			const res = await getImageGenerationConfig(localStorage.token).catch((error) => {
+				toast.error(error);
+				return null;
+			});
+
+			if (res) {
+				imageGenerationEngine = res.engine;
+				enableImageGeneration = res.enabled;
+			}
 			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
+			OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
+
+			imageSize = await getImageSize(localStorage.token);
+			steps = await getImageSteps(localStorage.token);
 
-			if (enableImageGeneration && AUTOMATIC1111_BASE_URL) {
-				imageSize = await getImageSize(localStorage.token);
-				steps = await getImageSteps(localStorage.token);
+			if (enableImageGeneration) {
 				getModels();
 			}
 		}
@@ -99,7 +116,11 @@
 	class="flex flex-col h-full justify-between space-y-3 text-sm"
 	on:submit|preventDefault={async () => {
 		loading = true;
-		await updateDefaultDiffusionModel(localStorage.token, selectedModel);
+		await updateOpenAIKey(localStorage.token, OPENAI_API_KEY);
+
+		await updateDefaultImageGenerationModel(localStorage.token, selectedModel);
+
+		await updateDefaultImageGenerationModel(localStorage.token, selectedModel);
 		await updateImageSize(localStorage.token, imageSize).catch((error) => {
 			toast.error(error);
 			return null;
@@ -117,6 +138,23 @@
 		<div>
 			<div class=" mb-1 text-sm font-medium">Image Settings</div>
 
+			<div class=" py-0.5 flex w-full justify-between">
+				<div class=" self-center text-xs font-medium">Image Generation Engine</div>
+				<div class="flex items-center relative">
+					<select
+						class="w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
+						bind:value={imageGenerationEngine}
+						placeholder="Select a mode"
+						on:change={async () => {
+							await updateImageGeneration();
+						}}
+					>
+						<option value="">Default (Automatic1111)</option>
+						<option value="openai">Open AI (Dall-E)</option>
+					</select>
+				</div>
+			</div>
+
 			<div>
 				<div class=" py-0.5 flex w-full justify-between">
 					<div class=" self-center text-xs font-medium">Image Generation (Experimental)</div>
@@ -124,7 +162,17 @@
 					<button
 						class="p-1 px-3 text-xs flex rounded transition"
 						on:click={() => {
-							toggleImageGeneration();
+							if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') {
+								toast.error('AUTOMATIC1111 Base URL is required.');
+								enableImageGeneration = false;
+							} else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') {
+								toast.error('OpenAI API Key is required.');
+								enableImageGeneration = false;
+							} else {
+								enableImageGeneration = !enableImageGeneration;
+							}
+
+							updateImageGeneration();
 						}}
 						type="button"
 					>
@@ -139,49 +187,62 @@
 		</div>
 		<hr class=" dark:border-gray-700" />
 
-		<div class=" mb-2.5 text-sm font-medium">AUTOMATIC1111 Base URL</div>
-		<div class="flex w-full">
-			<div class="flex-1 mr-2">
-				<input
-					class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-					placeholder="Enter URL (e.g. http://127.0.0.1:7860/)"
-					bind:value={AUTOMATIC1111_BASE_URL}
-				/>
+		{#if imageGenerationEngine === ''}
+			<div class=" mb-2.5 text-sm font-medium">AUTOMATIC1111 Base URL</div>
+			<div class="flex w-full">
+				<div class="flex-1 mr-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder="Enter URL (e.g. http://127.0.0.1:7860/)"
+						bind:value={AUTOMATIC1111_BASE_URL}
+					/>
+				</div>
+				<button
+					class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded-lg transition"
+					type="button"
+					on:click={() => {
+						// updateOllamaAPIUrlHandler();
+
+						updateAUTOMATIC1111UrlHandler();
+					}}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						viewBox="0 0 20 20"
+						fill="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							fill-rule="evenodd"
+							d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
+							clip-rule="evenodd"
+						/>
+					</svg>
+				</button>
 			</div>
-			<button
-				class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded transition"
-				type="button"
-				on:click={() => {
-					// updateOllamaAPIUrlHandler();
-
-					updateAUTOMATIC1111UrlHandler();
-				}}
-			>
-				<svg
-					xmlns="http://www.w3.org/2000/svg"
-					viewBox="0 0 20 20"
-					fill="currentColor"
-					class="w-4 h-4"
+
+			<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
+				Include `--api` flag when running stable-diffusion-webui
+				<a
+					class=" text-gray-300 font-medium"
+					href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/3734"
+					target="_blank"
 				>
-					<path
-						fill-rule="evenodd"
-						d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
-						clip-rule="evenodd"
+					(e.g. `sh webui.sh --api`)
+				</a>
+			</div>
+		{:else if imageGenerationEngine === 'openai'}
+			<div class=" mb-2.5 text-sm font-medium">OpenAI API Key</div>
+			<div class="flex w-full">
+				<div class="flex-1 mr-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder="Enter API Key"
+						bind:value={OPENAI_API_KEY}
 					/>
-				</svg>
-			</button>
-		</div>
-
-		<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
-			Include `--api` flag when running stable-diffusion-webui
-			<a
-				class=" text-gray-300 font-medium"
-				href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/3734"
-				target="_blank"
-			>
-				(e.g. `sh webui.sh --api`)
-			</a>
-		</div>
+				</div>
+			</div>
+		{/if}
 
 		{#if enableImageGeneration}
 			<hr class=" dark:border-gray-700" />
@@ -191,7 +252,7 @@
 				<div class="flex w-full">
 					<div class="flex-1 mr-2">
 						<select
-							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							bind:value={selectedModel}
 							placeholder="Select a model"
 						>
@@ -199,9 +260,7 @@
 								<option value="" disabled selected>Select a model</option>
 							{/if}
 							{#each models ?? [] as model}
-								<option value={model.title} class="bg-gray-100 dark:bg-gray-700"
-									>{model.model_name}</option
-								>
+								<option value={model.id} class="bg-gray-100 dark:bg-gray-700">{model.name}</option>
 							{/each}
 						</select>
 					</div>
@@ -213,7 +272,7 @@
 				<div class="flex w-full">
 					<div class="flex-1 mr-2">
 						<input
-							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							placeholder="Enter Image Size (e.g. 512x512)"
 							bind:value={imageSize}
 						/>
@@ -226,7 +285,7 @@
 				<div class="flex w-full">
 					<div class="flex-1 mr-2">
 						<input
-							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							placeholder="Enter Number of Steps (e.g. 50)"
 							bind:value={steps}
 						/>
@@ -238,7 +297,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded flex flex-row space-x-1 items-center {loading
+			class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg flex flex-row space-x-1 items-center {loading
 				? ' cursor-not-allowed'
 				: ''}"
 			type="submit"

+ 5 - 27
src/lib/components/chat/Settings/Interface.svelte

@@ -63,6 +63,7 @@
 		}
 
 		saveSettings({
+			titleAutoGenerateModel: titleAutoGenerateModel !== '' ? titleAutoGenerateModel : undefined,
 			titleGenerationPrompt: titleGenerationPrompt ? titleGenerationPrompt : undefined
 		});
 	};
@@ -186,7 +187,7 @@
 			<div class="flex w-full">
 				<div class="flex-1 mr-2">
 					<select
-						class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 						bind:value={titleAutoGenerateModel}
 						placeholder="Select a model"
 					>
@@ -200,35 +201,12 @@
 						{/each}
 					</select>
 				</div>
-				<button
-					class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-700 dark:hover:bg-gray-800 dark:text-gray-100 rounded transition"
-					on:click={() => {
-						saveSettings({
-							titleAutoGenerateModel:
-								titleAutoGenerateModel !== '' ? titleAutoGenerateModel : undefined
-						});
-					}}
-					type="button"
-				>
-					<svg
-						xmlns="http://www.w3.org/2000/svg"
-						viewBox="0 0 16 16"
-						fill="currentColor"
-						class="w-3.5 h-3.5"
-					>
-						<path
-							fill-rule="evenodd"
-							d="M13.836 2.477a.75.75 0 0 1 .75.75v3.182a.75.75 0 0 1-.75.75h-3.182a.75.75 0 0 1 0-1.5h1.37l-.84-.841a4.5 4.5 0 0 0-7.08.932.75.75 0 0 1-1.3-.75 6 6 0 0 1 9.44-1.242l.842.84V3.227a.75.75 0 0 1 .75-.75Zm-.911 7.5A.75.75 0 0 1 13.199 11a6 6 0 0 1-9.44 1.241l-.84-.84v1.371a.75.75 0 0 1-1.5 0V9.591a.75.75 0 0 1 .75-.75H5.35a.75.75 0 0 1 0 1.5H3.98l.841.841a4.5 4.5 0 0 0 7.08-.932.75.75 0 0 1 1.025-.273Z"
-							clip-rule="evenodd"
-						/>
-					</svg>
-				</button>
 			</div>
-			<div class="mt-3">
+			<div class="mt-3 mr-2">
 				<div class=" mb-2.5 text-sm font-medium">Title Generation Prompt</div>
 				<textarea
 					bind:value={titleGenerationPrompt}
-					class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
+					class="w-full rounded-lg p-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
 					rows="3"
 				/>
 			</div>
@@ -321,7 +299,7 @@
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 		<button
-			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
+			class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
 			type="submit"
 		>
 			Save

+ 277 - 221
src/lib/components/chat/Settings/Models.svelte

@@ -14,6 +14,7 @@
 	import { splitStream } from '$lib/utils';
 	import { onMount } from 'svelte';
 	import { addLiteLLMModel, deleteLiteLLMModel, getLiteLLMModelInfo } from '$lib/apis/litellm';
+	import Tooltip from '$lib/components/common/Tooltip.svelte';
 
 	export let getModels: Function;
 
@@ -27,6 +28,7 @@
 	let liteLLMAPIBase = '';
 	let liteLLMAPIKey = '';
 	let liteLLMRPM = '';
+	let liteLLMMaxTokens = '';
 
 	let deleteLiteLLMModelId = '';
 
@@ -36,6 +38,10 @@
 
 	let OLLAMA_URLS = [];
 	let selectedOllamaUrlIdx: string | null = null;
+
+	let updateModelId = null;
+	let updateProgress = null;
+
 	let showExperimentalOllama = false;
 	let ollamaVersion = '';
 	const MAX_PARALLEL_DOWNLOADS = 3;
@@ -60,6 +66,71 @@
 
 	let deleteModelTag = '';
 
+	const updateModelsHandler = async () => {
+		for (const model of $models.filter(
+			(m) =>
+				m.size != null &&
+				(selectedOllamaUrlIdx === null ? true : (m?.urls ?? []).includes(selectedOllamaUrlIdx))
+		)) {
+			console.log(model);
+
+			updateModelId = model.id;
+			const res = await pullModel(localStorage.token, model.id, selectedOllamaUrlIdx).catch(
+				(error) => {
+					toast.error(error);
+					return null;
+				}
+			);
+
+			if (res) {
+				const reader = res.body
+					.pipeThrough(new TextDecoderStream())
+					.pipeThrough(splitStream('\n'))
+					.getReader();
+
+				while (true) {
+					try {
+						const { value, done } = await reader.read();
+						if (done) break;
+
+						let lines = value.split('\n');
+
+						for (const line of lines) {
+							if (line !== '') {
+								let data = JSON.parse(line);
+
+								console.log(data);
+								if (data.error) {
+									throw data.error;
+								}
+								if (data.detail) {
+									throw data.detail;
+								}
+								if (data.status) {
+									if (data.digest) {
+										updateProgress = 0;
+										if (data.completed) {
+											updateProgress = Math.round((data.completed / data.total) * 1000) / 10;
+										} else {
+											updateProgress = 100;
+										}
+									} else {
+										toast.success(data.status);
+									}
+								}
+							}
+						}
+					} catch (error) {
+						console.log(error);
+					}
+				}
+			}
+		}
+
+		updateModelId = null;
+		updateProgress = null;
+	};
+
 	const pullModelHandler = async () => {
 		const sanitizedModelTag = modelTag.trim();
 		if (modelDownloadStatus[sanitizedModelTag]) {
@@ -326,7 +397,8 @@
 				model: liteLLMModel,
 				api_base: liteLLMAPIBase,
 				api_key: liteLLMAPIKey,
-				rpm: liteLLMRPM
+				rpm: liteLLMRPM,
+				max_tokens: liteLLMMaxTokens
 			}).catch((error) => {
 				toast.error(error);
 				return null;
@@ -346,6 +418,7 @@
 		liteLLMAPIBase = '';
 		liteLLMAPIKey = '';
 		liteLLMRPM = '';
+		liteLLMMaxTokens = '';
 
 		liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
 		models.set(await getModels());
@@ -376,7 +449,7 @@
 			return [];
 		});
 
-		if (OLLAMA_URLS.length > 1) {
+		if (OLLAMA_URLS.length > 0) {
 			selectedOllamaUrlIdx = 0;
 		}
 
@@ -391,18 +464,51 @@
 			<div class="space-y-2 pr-1.5">
 				<div class="text-sm font-medium">Manage Ollama Models</div>
 
-				{#if OLLAMA_URLS.length > 1}
-					<div class="flex-1 pb-1">
-						<select
-							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-							bind:value={selectedOllamaUrlIdx}
-							placeholder="Select an Ollama instance"
-						>
-							{#each OLLAMA_URLS as url, idx}
-								<option value={idx} class="bg-gray-100 dark:bg-gray-700">{url}</option>
-							{/each}
-						</select>
+				{#if OLLAMA_URLS.length > 0}
+					<div class="flex gap-2">
+						<div class="flex-1 pb-1">
+							<select
+								class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+								bind:value={selectedOllamaUrlIdx}
+								placeholder="Select an Ollama instance"
+							>
+								{#each OLLAMA_URLS as url, idx}
+									<option value={idx} class="bg-gray-100 dark:bg-gray-700">{url}</option>
+								{/each}
+							</select>
+						</div>
+
+						<div>
+							<div class="flex w-full justify-end">
+								<Tooltip content="Update All Models" placement="top">
+									<button
+										class="p-2.5 flex gap-2 items-center bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+										on:click={() => {
+											updateModelsHandler();
+										}}
+									>
+										<svg
+											xmlns="http://www.w3.org/2000/svg"
+											viewBox="0 0 16 16"
+											fill="currentColor"
+											class="w-4 h-4"
+										>
+											<path
+												d="M7 1a.75.75 0 0 1 .75.75V6h-1.5V1.75A.75.75 0 0 1 7 1ZM6.25 6v2.94L5.03 7.72a.75.75 0 0 0-1.06 1.06l2.5 2.5a.75.75 0 0 0 1.06 0l2.5-2.5a.75.75 0 1 0-1.06-1.06L7.75 8.94V6H10a2 2 0 0 1 2 2v3a2 2 0 0 1-2 2H4a2 2 0 0 1-2-2V8a2 2 0 0 1 2-2h2.25Z"
+											/>
+											<path
+												d="M4.268 14A2 2 0 0 0 6 15h6a2 2 0 0 0 2-2v-3a2 2 0 0 0-1-1.732V11a3 3 0 0 1-3 3H4.268Z"
+											/>
+										</svg>
+									</button>
+								</Tooltip>
+							</div>
+						</div>
 					</div>
+
+					{#if updateModelId}
+						Updating "{updateModelId}" {updateProgress ? `(${updateProgress}%)` : ''}
+					{/if}
 				{/if}
 
 				<div class="space-y-2">
@@ -467,12 +573,14 @@
 							</button>
 						</div>
 
-						<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
-							To access the available model names for downloading, <a
-								class=" text-gray-500 dark:text-gray-300 font-medium underline"
-								href="https://ollama.com/library"
-								target="_blank">click here.</a
-							>
+						<div>
+							<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
+								To access the available model names for downloading, <a
+									class=" text-gray-500 dark:text-gray-300 font-medium underline"
+									href="https://ollama.com/library"
+									target="_blank">click here.</a
+								>
+							</div>
 						</div>
 
 						{#if Object.keys(modelDownloadStatus).length > 0}
@@ -589,7 +697,7 @@
 												on:change={() => {
 													console.log(modelInputFile);
 												}}
-												accept=".gguf"
+												accept=".gguf,.safetensors"
 												required
 												hidden
 											/>
@@ -722,245 +830,193 @@
 		<div class=" space-y-3">
 			<div class="mt-2 space-y-3 pr-1.5">
 				<div>
-					<div class=" mb-2 text-sm font-medium">Manage LiteLLM Models</div>
-
-					<div>
+					<div class="mb-2">
 						<div class="flex justify-between items-center text-xs">
-							<div class=" text-sm font-medium">Add a model</div>
+							<div class=" text-sm font-medium">Manage LiteLLM Models</div>
 							<button
 								class=" text-xs font-medium text-gray-500"
 								type="button"
 								on:click={() => {
-									showLiteLLMParams = !showLiteLLMParams;
-								}}>{showLiteLLMParams ? 'Hide Additional Params' : 'Show Additional Params'}</button
+									showLiteLLM = !showLiteLLM;
+								}}>{showLiteLLM ? 'Hide' : 'Show'}</button
 							>
 						</div>
 					</div>
 
-					<div class="my-2 space-y-2">
-						<div class="flex w-full mb-1.5">
-							<div class="flex-1 mr-2">
-								<input
-									class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-									placeholder="Enter LiteLLM Model (litellm_params.model)"
-									bind:value={liteLLMModel}
-									autocomplete="off"
-								/>
+					{#if showLiteLLM}
+						<div>
+							<div class="flex justify-between items-center text-xs">
+								<div class=" text-sm font-medium">Add a model</div>
+								<button
+									class=" text-xs font-medium text-gray-500"
+									type="button"
+									on:click={() => {
+										showLiteLLMParams = !showLiteLLMParams;
+									}}
+									>{showLiteLLMParams ? 'Hide Additional Params' : 'Show Additional Params'}</button
+								>
 							</div>
+						</div>
 
-							<button
-								class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
-								on:click={() => {
-									addLiteLLMModelHandler();
-								}}
-							>
-								<svg
-									xmlns="http://www.w3.org/2000/svg"
-									viewBox="0 0 16 16"
-									fill="currentColor"
-									class="w-4 h-4"
-								>
-									<path
-										d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
+						<div class="my-2 space-y-2">
+							<div class="flex w-full mb-1.5">
+								<div class="flex-1 mr-2">
+									<input
+										class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+										placeholder="Enter LiteLLM Model (litellm_params.model)"
+										bind:value={liteLLMModel}
+										autocomplete="off"
 									/>
-								</svg>
-							</button>
-						</div>
+								</div>
 
-						{#if showLiteLLMParams}
-							<div>
-								<div class=" mb-1.5 text-sm font-medium">Model Name</div>
-								<div class="flex w-full">
-									<div class="flex-1">
-										<input
-											class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-											placeholder="Enter Model Name (model_name)"
-											bind:value={liteLLMModelName}
-											autocomplete="off"
+								<button
+									class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+									on:click={() => {
+										addLiteLLMModelHandler();
+									}}
+								>
+									<svg
+										xmlns="http://www.w3.org/2000/svg"
+										viewBox="0 0 16 16"
+										fill="currentColor"
+										class="w-4 h-4"
+									>
+										<path
+											d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
 										/>
-									</div>
-								</div>
+									</svg>
+								</button>
 							</div>
 
-							<div>
-								<div class=" mb-1.5 text-sm font-medium">API Base URL</div>
-								<div class="flex w-full">
-									<div class="flex-1">
-										<input
-											class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-											placeholder="Enter LiteLLM API Base URL (litellm_params.api_base)"
-											bind:value={liteLLMAPIBase}
-											autocomplete="off"
-										/>
+							{#if showLiteLLMParams}
+								<div>
+									<div class=" mb-1.5 text-sm font-medium">Model Name</div>
+									<div class="flex w-full">
+										<div class="flex-1">
+											<input
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												placeholder="Enter Model Name (model_name)"
+												bind:value={liteLLMModelName}
+												autocomplete="off"
+											/>
+										</div>
 									</div>
 								</div>
-							</div>
 
-							<div>
-								<div class=" mb-1.5 text-sm font-medium">API Key</div>
-								<div class="flex w-full">
-									<div class="flex-1">
-										<input
-											class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-											placeholder="Enter LiteLLM API Key (litellm_params.api_key)"
-											bind:value={liteLLMAPIKey}
-											autocomplete="off"
-										/>
+								<div>
+									<div class=" mb-1.5 text-sm font-medium">API Base URL</div>
+									<div class="flex w-full">
+										<div class="flex-1">
+											<input
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												placeholder="Enter LiteLLM API Base URL (litellm_params.api_base)"
+												bind:value={liteLLMAPIBase}
+												autocomplete="off"
+											/>
+										</div>
 									</div>
 								</div>
-							</div>
 
-							<div>
-								<div class="mb-1.5 text-sm font-medium">API RPM</div>
-								<div class="flex w-full">
-									<div class="flex-1">
-										<input
-											class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-											placeholder="Enter LiteLLM API RPM (litellm_params.rpm)"
-											bind:value={liteLLMRPM}
-											autocomplete="off"
-										/>
+								<div>
+									<div class=" mb-1.5 text-sm font-medium">API Key</div>
+									<div class="flex w-full">
+										<div class="flex-1">
+											<input
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												placeholder="Enter LiteLLM API Key (litellm_params.api_key)"
+												bind:value={liteLLMAPIKey}
+												autocomplete="off"
+											/>
+										</div>
 									</div>
 								</div>
-							</div>
-						{/if}
-					</div>
-
-					<div class="mb-2 text-xs text-gray-400 dark:text-gray-500">
-						Not sure what to add?
-						<a
-							class=" text-gray-300 font-medium underline"
-							href="https://litellm.vercel.app/docs/proxy/configs#quick-start"
-							target="_blank"
-						>
-							Click here for help.
-						</a>
-					</div>
-
-					<div>
-						<div class=" mb-2.5 text-sm font-medium">Delete a model</div>
-						<div class="flex w-full">
-							<div class="flex-1 mr-2">
-								<select
-									class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-									bind:value={deleteLiteLLMModelId}
-									placeholder="Select a model"
-								>
-									{#if !deleteLiteLLMModelId}
-										<option value="" disabled selected>Select a model</option>
-									{/if}
-									{#each liteLLMModelInfo as model}
-										<option value={model.model_info.id} class="bg-gray-100 dark:bg-gray-700"
-											>{model.model_name}</option
-										>
-									{/each}
-								</select>
-							</div>
-							<button
-								class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
-								on:click={() => {
-									deleteLiteLLMModelHandler();
-								}}
-							>
-								<svg
-									xmlns="http://www.w3.org/2000/svg"
-									viewBox="0 0 16 16"
-									fill="currentColor"
-									class="w-4 h-4"
-								>
-									<path
-										fill-rule="evenodd"
-										d="M5 3.25V4H2.75a.75.75 0 0 0 0 1.5h.3l.815 8.15A1.5 1.5 0 0 0 5.357 15h5.285a1.5 1.5 0 0 0 1.493-1.35l.815-8.15h.3a.75.75 0 0 0 0-1.5H11v-.75A2.25 2.25 0 0 0 8.75 1h-1.5A2.25 2.25 0 0 0 5 3.25Zm2.25-.75a.75.75 0 0 0-.75.75V4h3v-.75a.75.75 0 0 0-.75-.75h-1.5ZM6.05 6a.75.75 0 0 1 .787.713l.275 5.5a.75.75 0 0 1-1.498.075l-.275-5.5A.75.75 0 0 1 6.05 6Zm3.9 0a.75.75 0 0 1 .712.787l-.275 5.5a.75.75 0 0 1-1.498-.075l.275-5.5a.75.75 0 0 1 .786-.711Z"
-										clip-rule="evenodd"
-									/>
-								</svg>
-							</button>
-						</div>
-					</div>
-				</div>
-			</div>
-
-			<!-- <div class="mt-2 space-y-3 pr-1.5">
-				<div>
-					<div class=" mb-2.5 text-sm font-medium">Add LiteLLM Model</div>
-					<div class="flex w-full mb-2">
-						<div class="flex-1">
-							<input
-								class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-								placeholder="Enter LiteLLM Model (e.g. ollama/mistral)"
-								bind:value={liteLLMModel}
-								autocomplete="off"
-							/>
-						</div>
-					</div>
 
-					<div class="flex justify-between items-center text-sm">
-						<div class="  font-medium">Advanced Model Params</div>
-						<button
-							class=" text-xs font-medium text-gray-500"
-							type="button"
-							on:click={() => {
-								showLiteLLMParams = !showLiteLLMParams;
-							}}>{showLiteLLMParams ? 'Hide' : 'Show'}</button
-						>
-					</div>
+								<div>
+									<div class="mb-1.5 text-sm font-medium">API RPM</div>
+									<div class="flex w-full">
+										<div class="flex-1">
+											<input
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												placeholder="Enter LiteLLM API RPM (litellm_params.rpm)"
+												bind:value={liteLLMRPM}
+												autocomplete="off"
+											/>
+										</div>
+									</div>
+								</div>
 
-					{#if showLiteLLMParams}
-						<div>
-							<div class=" mb-2.5 text-sm font-medium">LiteLLM API Key</div>
-							<div class="flex w-full">
-								<div class="flex-1">
-									<input
-										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-										placeholder="Enter LiteLLM API Key (e.g. os.environ/AZURE_API_KEY_CA)"
-										bind:value={liteLLMAPIKey}
-										autocomplete="off"
-									/>
+								<div>
+									<div class="mb-1.5 text-sm font-medium">Max Tokens</div>
+									<div class="flex w-full">
+										<div class="flex-1">
+											<input
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												placeholder="Enter Max Tokens (litellm_params.max_tokens)"
+												bind:value={liteLLMMaxTokens}
+												type="number"
+												min="1"
+												autocomplete="off"
+											/>
+										</div>
+									</div>
 								</div>
-							</div>
+							{/if}
 						</div>
 
-						<div>
-							<div class=" mb-2.5 text-sm font-medium">LiteLLM API Base URL</div>
-							<div class="flex w-full">
-								<div class="flex-1">
-									<input
-										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-										placeholder="Enter LiteLLM API Base URL"
-										bind:value={liteLLMAPIBase}
-										autocomplete="off"
-									/>
-								</div>
-							</div>
+						<div class="mb-2 text-xs text-gray-400 dark:text-gray-500">
+							Not sure what to add?
+							<a
+								class=" text-gray-300 font-medium underline"
+								href="https://litellm.vercel.app/docs/proxy/configs#quick-start"
+								target="_blank"
+							>
+								Click here for help.
+							</a>
 						</div>
 
 						<div>
-							<div class=" mb-2.5 text-sm font-medium">LiteLLM API RPM</div>
+							<div class=" mb-2.5 text-sm font-medium">Delete a model</div>
 							<div class="flex w-full">
-								<div class="flex-1">
-									<input
-										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-										placeholder="Enter LiteLLM API RPM"
-										bind:value={liteLLMRPM}
-										autocomplete="off"
-									/>
+								<div class="flex-1 mr-2">
+									<select
+										class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+										bind:value={deleteLiteLLMModelId}
+										placeholder="Select a model"
+									>
+										{#if !deleteLiteLLMModelId}
+											<option value="" disabled selected>Select a model</option>
+										{/if}
+										{#each liteLLMModelInfo as model}
+											<option value={model.model_info.id} class="bg-gray-100 dark:bg-gray-700"
+												>{model.model_name}</option
+											>
+										{/each}
+									</select>
 								</div>
+								<button
+									class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+									on:click={() => {
+										deleteLiteLLMModelHandler();
+									}}
+								>
+									<svg
+										xmlns="http://www.w3.org/2000/svg"
+										viewBox="0 0 16 16"
+										fill="currentColor"
+										class="w-4 h-4"
+									>
+										<path
+											fill-rule="evenodd"
+											d="M5 3.25V4H2.75a.75.75 0 0 0 0 1.5h.3l.815 8.15A1.5 1.5 0 0 0 5.357 15h5.285a1.5 1.5 0 0 0 1.493-1.35l.815-8.15h.3a.75.75 0 0 0 0-1.5H11v-.75A2.25 2.25 0 0 0 8.75 1h-1.5A2.25 2.25 0 0 0 5 3.25Zm2.25-.75a.75.75 0 0 0-.75.75V4h3v-.75a.75.75 0 0 0-.75-.75h-1.5ZM6.05 6a.75.75 0 0 1 .787.713l.275 5.5a.75.75 0 0 1-1.498.075l-.275-5.5A.75.75 0 0 1 6.05 6Zm3.9 0a.75.75 0 0 1 .712.787l-.275 5.5a.75.75 0 0 1-1.498-.075l.275-5.5a.75.75 0 0 1 .786-.711Z"
+											clip-rule="evenodd"
+										/>
+									</svg>
+								</button>
 							</div>
 						</div>
 					{/if}
-
-					<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
-						Not sure what to add?
-						<a
-							class=" text-gray-300 font-medium underline"
-							href="https://litellm.vercel.app/docs/proxy/configs#quick-start"
-							target="_blank"
-						>
-							Click here for help.
-						</a>
-					</div>
 				</div>
-			</div> -->
+			</div>
 		</div>
 	</div>
 </div>

+ 6 - 6
src/lib/components/chat/SettingsModal.svelte

@@ -326,7 +326,7 @@
 						{getModels}
 						{saveSettings}
 						on:save={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'models'}
@@ -335,28 +335,28 @@
 					<Connections
 						{getModels}
 						on:save={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'interface'}
 					<Interface
 						{saveSettings}
 						on:save={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'audio'}
 					<Audio
 						{saveSettings}
 						on:save={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'images'}
 					<Images
 						{saveSettings}
 						on:save={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'chats'}
@@ -364,7 +364,7 @@
 				{:else if selectedTab === 'account'}
 					<Account
 						saveHandler={() => {
-							show = false;
+							toast.success('Settings saved successfully!');
 						}}
 					/>
 				{:else if selectedTab === 'about'}

+ 7 - 2
src/lib/components/common/Image.svelte

@@ -1,18 +1,23 @@
 <script lang="ts">
+	import { WEBUI_BASE_URL } from '$lib/constants';
 	import ImagePreview from './ImagePreview.svelte';
 
 	export let src = '';
 	export let alt = '';
 
+	let _src = '';
+
+	$: _src = src.startsWith('/') ? `${WEBUI_BASE_URL}${src}` : src;
+
 	let showImagePreview = false;
 </script>
 
-<ImagePreview bind:show={showImagePreview} {src} {alt} />
+<ImagePreview bind:show={showImagePreview} src={_src} {alt} />
 <button
 	on:click={() => {
 		console.log('image preview');
 		showImagePreview = true;
 	}}
 >
-	<img {src} {alt} class=" max-h-96 rounded-lg" draggable="false" />
+	<img src={_src} {alt} class=" max-h-96 rounded-lg" draggable="false" />
 </button>

+ 99 - 72
src/lib/components/documents/Settings/General.svelte

@@ -1,10 +1,10 @@
 <script lang="ts">
 	import { getDocs } from '$lib/apis/documents';
 	import {
-		getChunkParams,
+		getRAGConfig,
+		updateRAGConfig,
 		getQuerySettings,
 		scanDocs,
-		updateChunkParams,
 		updateQuerySettings
 	} from '$lib/apis/rag';
 	import { documents } from '$lib/stores';
@@ -17,6 +17,7 @@
 
 	let chunkSize = 0;
 	let chunkOverlap = 0;
+	let pdfExtractImages = true;
 
 	let querySettings = {
 		template: '',
@@ -35,16 +36,24 @@
 	};
 
 	const submitHandler = async () => {
-		const res = await updateChunkParams(localStorage.token, chunkSize, chunkOverlap);
+		const res = await updateRAGConfig(localStorage.token, {
+			pdf_extract_images: pdfExtractImages,
+			chunk: {
+				chunk_overlap: chunkOverlap,
+				chunk_size: chunkSize
+			}
+		});
 		querySettings = await updateQuerySettings(localStorage.token, querySettings);
 	};
 
 	onMount(async () => {
-		const res = await getChunkParams(localStorage.token);
+		const res = await getRAGConfig(localStorage.token);
 
 		if (res) {
-			chunkSize = res.chunk_size;
-			chunkOverlap = res.chunk_overlap;
+			pdfExtractImages = res.pdf_extract_images;
+
+			chunkSize = res.chunk.chunk_size;
+			chunkOverlap = res.chunk.chunk_overlap;
 		}
 
 		querySettings = await getQuerySettings(localStorage.token);
@@ -124,82 +133,100 @@
 
 		<hr class=" dark:border-gray-700" />
 
-		<div class=" ">
-			<div class=" text-sm font-medium">Chunk Params</div>
-
-			<div class=" flex">
-				<div class="  flex w-full justify-between">
-					<div class="self-center text-xs font-medium min-w-fit">Chunk Size</div>
-
-					<div class="self-center p-3">
-						<input
-							class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
-							type="number"
-							placeholder="Enter Chunk Size"
-							bind:value={chunkSize}
-							autocomplete="off"
-							min="0"
-						/>
+		<div class=" space-y-3">
+			<div class=" space-y-3">
+				<div class=" text-sm font-medium">Chunk Params</div>
+
+				<div class=" flex gap-2">
+					<div class="  flex w-full justify-between gap-2">
+						<div class="self-center text-xs font-medium min-w-fit">Chunk Size</div>
+
+						<div class="self-center">
+							<input
+								class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+								type="number"
+								placeholder="Enter Chunk Size"
+								bind:value={chunkSize}
+								autocomplete="off"
+								min="0"
+							/>
+						</div>
 					</div>
-				</div>
 
-				<div class="flex w-full">
-					<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
-
-					<div class="self-center p-3">
-						<input
-							class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
-							type="number"
-							placeholder="Enter Chunk Overlap"
-							bind:value={chunkOverlap}
-							autocomplete="off"
-							min="0"
-						/>
+					<div class="flex w-full gap-2">
+						<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
+
+						<div class="self-center">
+							<input
+								class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+								type="number"
+								placeholder="Enter Chunk Overlap"
+								bind:value={chunkOverlap}
+								autocomplete="off"
+								min="0"
+							/>
+						</div>
 					</div>
 				</div>
-			</div>
-
-			<div class=" text-sm font-medium">Query Params</div>
 
-			<div class=" flex">
-				<div class="  flex w-full justify-between">
-					<div class="self-center text-xs font-medium flex-1">Top K</div>
-
-					<div class="self-center p-3">
-						<input
-							class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
-							type="number"
-							placeholder="Enter Top K"
-							bind:value={querySettings.k}
-							autocomplete="off"
-							min="0"
-						/>
+				<div>
+					<div class="flex justify-between items-center text-xs">
+						<div class=" text-xs font-medium">PDF Extract Images (OCR)</div>
+
+						<button
+							class=" text-xs font-medium text-gray-500"
+							type="button"
+							on:click={() => {
+								pdfExtractImages = !pdfExtractImages;
+							}}>{pdfExtractImages ? 'On' : 'Off'}</button
+						>
 					</div>
 				</div>
-
-				<!-- <div class="flex w-full">
-					<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
-
-					<div class="self-center p-3">
-						<input
-							class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
-							type="number"
-							placeholder="Enter Chunk Overlap"
-							bind:value={chunkOverlap}
-							autocomplete="off"
-							min="0"
-						/>
-					</div>
-				</div> -->
 			</div>
 
 			<div>
-				<div class=" mb-2.5 text-sm font-medium">RAG Template</div>
-				<textarea
-					bind:value={querySettings.template}
-					class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
-					rows="4"
-				/>
+				<div class=" text-sm font-medium">Query Params</div>
+
+				<div class=" flex py-2">
+					<div class="  flex w-full justify-between gap-2">
+						<div class="self-center text-xs font-medium flex-1">Top K</div>
+
+						<div class="self-center">
+							<input
+								class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+								type="number"
+								placeholder="Enter Top K"
+								bind:value={querySettings.k}
+								autocomplete="off"
+								min="0"
+							/>
+						</div>
+					</div>
+
+					<!-- <div class="flex w-full">
+						<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
+	
+						<div class="self-center p-3">
+							<input
+								class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
+								type="number"
+								placeholder="Enter Chunk Overlap"
+								bind:value={chunkOverlap}
+								autocomplete="off"
+								min="0"
+							/>
+						</div>
+					</div> -->
+				</div>
+
+				<div>
+					<div class=" mb-2.5 text-sm font-medium">RAG Template</div>
+					<textarea
+						bind:value={querySettings.template}
+						class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
+						rows="4"
+					/>
+				</div>
 			</div>
 		</div>
 	</div>

+ 12 - 7
src/lib/components/layout/Sidebar.svelte

@@ -61,12 +61,16 @@
 	};
 
 	const editChatTitle = async (id, _title) => {
-		title = _title;
-
-		await updateChatById(localStorage.token, id, {
-			title: _title
-		});
-		await chats.set(await getChatList(localStorage.token));
+		if (_title === '') {
+			toast.error('Title cannot be an empty string.');
+		} else {
+			title = _title;
+
+			await updateChatById(localStorage.token, id, {
+				title: _title
+			});
+			await chats.set(await getChatList(localStorage.token));
+		}
 	};
 
 	const deleteChat = async (id) => {
@@ -388,12 +392,13 @@
 										show = false;
 									}
 								}}
+								draggable="false"
 							>
 								<div class=" flex self-center flex-1 w-full">
 									<div
 										class=" text-left self-center overflow-hidden {chat.id === $chatId
 											? 'w-[160px]'
-											: 'w-full'} "
+											: 'w-full'}  h-[20px]"
 									>
 										{chat.title}
 									</div>

+ 45 - 59
src/routes/(app)/+page.svelte

@@ -232,53 +232,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -342,15 +295,25 @@
 			...messages
 		]
 			.filter((message) => message)
-			.map((message, idx, arr) => ({
-				role: message.role,
-				content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content,
-				...(message.files && {
-					images: message.files
-						.filter((file) => file.type === 'image')
-						.map((file) => file.url.slice(file.url.indexOf(',') + 1))
-				})
-			}));
+			.map((message, idx, arr) => {
+				// Prepare the base message object
+				const baseMessage = {
+					role: message.role,
+					content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
+				};
+
+				// Extract and format image URLs if any exist
+				const imageUrls = message.files
+					?.filter((file) => file.type === 'image')
+					.map((file) => file.url.slice(file.url.indexOf(',') + 1));
+
+				// Add images array only if it contains elements
+				if (imageUrls && imageUrls.length > 0) {
+					baseMessage.images = imageUrls;
+				}
+
+				return baseMessage;
+			});
 
 		let lastImageIndex = -1;
 
@@ -368,6 +331,13 @@
 			}
 		});
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
 			model: model,
 			messages: messagesBody,
@@ -375,7 +345,8 @@
 				...($settings.options ?? {})
 			},
 			format: $settings.requestFormat ?? undefined,
-			keep_alive: $settings.keepAlive ?? undefined
+			keep_alive: $settings.keepAlive ?? undefined,
+			docs: docs.length > 0 ? docs : undefined
 		});
 
 		if (res && res.ok) {
@@ -535,6 +506,15 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
+		console.log(docs);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -583,7 +563,8 @@
 				top_p: $settings?.options?.top_p ?? undefined,
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined
+				max_tokens: $settings?.options?.num_predict ?? undefined,
+				docs: docs.length > 0 ? docs : undefined
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);
@@ -694,7 +675,12 @@
 
 		if (messages.length == 2) {
 			window.history.replaceState(history.state, '', `/c/${_chatId}`);
-			await setChatTitle(_chatId, userPrompt);
+
+			if ($settings?.titleAutoGenerateModel) {
+				await generateChatTitle(_chatId, userPrompt);
+			} else {
+				await setChatTitle(_chatId, userPrompt);
+			}
 		}
 	};
 

+ 40 - 58
src/routes/(app)/c/[id]/+page.svelte

@@ -245,53 +245,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -355,15 +308,25 @@
 			...messages
 		]
 			.filter((message) => message)
-			.map((message, idx, arr) => ({
-				role: message.role,
-				content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content,
-				...(message.files && {
-					images: message.files
-						.filter((file) => file.type === 'image')
-						.map((file) => file.url.slice(file.url.indexOf(',') + 1))
-				})
-			}));
+			.map((message, idx, arr) => {
+				// Prepare the base message object
+				const baseMessage = {
+					role: message.role,
+					content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
+				};
+
+				// Extract and format image URLs if any exist
+				const imageUrls = message.files
+					?.filter((file) => file.type === 'image')
+					.map((file) => file.url.slice(file.url.indexOf(',') + 1));
+
+				// Add images array only if it contains elements
+				if (imageUrls && imageUrls.length > 0) {
+					baseMessage.images = imageUrls;
+				}
+
+				return baseMessage;
+			});
 
 		let lastImageIndex = -1;
 
@@ -381,6 +344,13 @@
 			}
 		});
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
 			model: model,
 			messages: messagesBody,
@@ -388,7 +358,8 @@
 				...($settings.options ?? {})
 			},
 			format: $settings.requestFormat ?? undefined,
-			keep_alive: $settings.keepAlive ?? undefined
+			keep_alive: $settings.keepAlive ?? undefined,
+			docs: docs.length > 0 ? docs : undefined
 		});
 
 		if (res && res.ok) {
@@ -548,6 +519,15 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
+		console.log(docs);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -596,7 +576,8 @@
 				top_p: $settings?.options?.top_p ?? undefined,
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined
+				max_tokens: $settings?.options?.num_predict ?? undefined,
+				docs: docs.length > 0 ? docs : undefined
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);
@@ -710,6 +691,7 @@
 			await setChatTitle(_chatId, userPrompt);
 		}
 	};
+
 	const stopResponse = () => {
 		stopResponseFlag = true;
 		console.log('stopResponse');

+ 1 - 1
src/routes/(app)/playground/+page.svelte

@@ -267,7 +267,7 @@
 
 <div class="min-h-screen max-h-[100dvh] w-full flex justify-center dark:text-white">
 	<div class=" flex flex-col justify-between w-full overflow-y-auto h-[100dvh]">
-		<div class="max-w-2xl mx-auto w-full px-3 p-3 md:px-0 h-full">
+		<div class="max-w-2xl mx-auto w-full px-3 md:px-0 my-10 h-full">
 			<div class=" flex flex-col h-full">
 				<div class="flex flex-col justify-between mb-2.5 gap-1">
 					<div class="flex justify-between items-center gap-2">