浏览代码

Merge remote-tracking branch 'upstream/dev' into feat/add-i18n

Ased Mammad 1 年之前
父节点
当前提交
f221e39c24

+ 5 - 5
.github/workflows/build-release.yml

@@ -29,11 +29,11 @@ jobs:
     - name: Extract latest CHANGELOG entry
       id: changelog
       run: |
-        CHANGELOG_CONTENT=$(awk '/^## \[/{n++} n==1' CHANGELOG.md)
-        echo "CHANGELOG_CONTENT<<EOF" 
-        echo "$CHANGELOG_CONTENT"
-        echo "EOF" 
-        echo "::set-output name=content::${CHANGELOG_CONTENT}"
+        CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md)
+        CHANGELOG_ESCAPED=$(echo "$CHANGELOG_CONTENT" | sed ':a;N;$!ba;s/\n/%0A/g')
+        echo "Extracted latest release notes from CHANGELOG.md:" 
+        echo -e "$CHANGELOG_CONTENT" 
+        echo "::set-output name=content::$CHANGELOG_ESCAPED"
 
     - name: Create GitHub release
       uses: actions/github-script@v5

+ 2 - 2
Dockerfile

@@ -41,7 +41,7 @@ ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
 # for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
 # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
 ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
-# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
+# device type for whisper tts and embbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
 ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
 ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
 ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
@@ -81,4 +81,4 @@ COPY --from=build /app/package.json /app/package.json
 # copy backend files
 COPY ./backend .
 
-CMD [ "bash", "start.sh"]
+CMD [ "bash", "start.sh"]

+ 10 - 2
README.md

@@ -53,8 +53,6 @@ User-friendly WebUI for LLMs, supported LLM runners include Ollama and OpenAI-co
 
 - 💬 **Collaborative Chat**: Harness the collective intelligence of multiple models by seamlessly orchestrating group conversations. Use the `@` command to specify the model, enabling dynamic and diverse dialogues within your chat interface. Immerse yourself in the collective intelligence woven into your chat environment.
 
-- 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
-
 - 🔄 **Regeneration History Access**: Easily revisit and explore your entire regeneration history.
 
 - 📜 **Chat History**: Effortlessly access and manage your conversation history.
@@ -65,8 +63,18 @@ User-friendly WebUI for LLMs, supported LLM runners include Ollama and OpenAI-co
 
 - ⚙️ **Fine-Tuned Control with Advanced Parameters**: Gain a deeper level of control by adjusting parameters such as temperature and defining your system prompts to tailor the conversation to your specific preferences and needs.
 
+- 🎨🤖 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using AUTOMATIC1111 API (local) and DALL-E, enriching your chat experience with dynamic visual content.
+
+- 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
+
+- ✨ **Multiple OpenAI-Compatible API Support**: Seamlessly integrate and customize various OpenAI-compatible APIs, enhancing the versatility of your chat interactions.
+
 - 🔗 **External Ollama Server Connection**: Seamlessly link to an external Ollama server hosted on a different address by configuring the environment variable.
 
+- 🔀 **Multiple Ollama Instance Load Balancing**: Effortlessly distribute chat requests across multiple Ollama instances for enhanced performance and reliability.
+
+- 👥 **Multi-User Management**: Easily oversee and administer users via our intuitive admin panel, streamlining user management processes.
+
 - 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
 
 - 🔒 **Backend Reverse Proxy Support**: Bolster security through direct communication between Open WebUI backend and Ollama. This key feature eliminates the need to expose Ollama over LAN. Requests made to the '/ollama/api' route from the web UI are seamlessly redirected to Ollama from the backend, enhancing overall system security.

+ 187 - 50
backend/apps/images/main.py

@@ -21,7 +21,16 @@ from utils.utils import (
 from utils.misc import calculate_sha256
 from typing import Optional
 from pydantic import BaseModel
-from config import AUTOMATIC1111_BASE_URL
+from pathlib import Path
+import uuid
+import base64
+import json
+
+from config import CACHE_DIR, AUTOMATIC1111_BASE_URL
+
+
+IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
+IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
 app = FastAPI()
 app.add_middleware(
@@ -32,25 +41,34 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+app.state.ENGINE = ""
+app.state.ENABLED = False
+
+app.state.OPENAI_API_KEY = ""
+app.state.MODEL = ""
+
+
 app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
-app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
+
 app.state.IMAGE_SIZE = "512x512"
 app.state.IMAGE_STEPS = 50
 
 
-@app.get("/enabled", response_model=bool)
-async def get_enable_status(request: Request, user=Depends(get_admin_user)):
-    return app.state.ENABLED
+@app.get("/config")
+async def get_config(request: Request, user=Depends(get_admin_user)):
+    return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
 
 
-@app.get("/enabled/toggle", response_model=bool)
-async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
-    try:
-        r = requests.head(app.state.AUTOMATIC1111_BASE_URL)
-        app.state.ENABLED = not app.state.ENABLED
-        return app.state.ENABLED
-    except Exception as e:
-        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
+class ConfigUpdateForm(BaseModel):
+    engine: str
+    enabled: bool
+
+
+@app.post("/config/update")
+async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
+    app.state.ENGINE = form_data.engine
+    app.state.ENABLED = form_data.enabled
+    return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
 
 
 class UrlUpdateForm(BaseModel):
@@ -58,17 +76,24 @@ class UrlUpdateForm(BaseModel):
 
 
 @app.get("/url")
-async def get_openai_url(user=Depends(get_admin_user)):
+async def get_automatic1111_url(user=Depends(get_admin_user)):
     return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
 
 
 @app.post("/url/update")
-async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
+async def update_automatic1111_url(
+    form_data: UrlUpdateForm, user=Depends(get_admin_user)
+):
 
     if form_data.url == "":
         app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
     else:
-        app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
+        url = form_data.url.strip("/")
+        try:
+            r = requests.head(url)
+            app.state.AUTOMATIC1111_BASE_URL = url
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 
     return {
         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
@@ -76,6 +101,30 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_use
     }
 
 
+class OpenAIKeyUpdateForm(BaseModel):
+    key: str
+
+
+@app.get("/key")
+async def get_openai_key(user=Depends(get_admin_user)):
+    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
+
+
+@app.post("/key/update")
+async def update_openai_key(
+    form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
+):
+
+    if form_data.key == "":
+        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+
+    app.state.OPENAI_API_KEY = form_data.key
+    return {
+        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
+        "status": True,
+    }
+
+
 class ImageSizeUpdateForm(BaseModel):
     size: str
 
@@ -132,9 +181,22 @@ async def update_image_size(
 @app.get("/models")
 def get_models(user=Depends(get_current_user)):
     try:
-        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
-        models = r.json()
-        return models
+        if app.state.ENGINE == "openai":
+            return [
+                {"id": "dall-e-2", "name": "DALL·E 2"},
+                {"id": "dall-e-3", "name": "DALL·E 3"},
+            ]
+        else:
+            r = requests.get(
+                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
+            )
+            models = r.json()
+            return list(
+                map(
+                    lambda model: {"id": model["title"], "name": model["model_name"]},
+                    models,
+                )
+            )
     except Exception as e:
         app.state.ENABLED = False
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@@ -143,10 +205,12 @@ def get_models(user=Depends(get_current_user)):
 @app.get("/models/default")
 async def get_default_model(user=Depends(get_admin_user)):
     try:
-        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
-        options = r.json()
-
-        return {"model": options["sd_model_checkpoint"]}
+        if app.state.ENGINE == "openai":
+            return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
+        else:
+            r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
+            options = r.json()
+            return {"model": options["sd_model_checkpoint"]}
     except Exception as e:
         app.state.ENABLED = False
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@@ -157,16 +221,21 @@ class UpdateModelForm(BaseModel):
 
 
 def set_model_handler(model: str):
-    r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
-    options = r.json()
 
-    if model != options["sd_model_checkpoint"]:
-        options["sd_model_checkpoint"] = model
-        r = requests.post(
-            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
-        )
+    if app.state.ENGINE == "openai":
+        app.state.MODEL = model
+        return app.state.MODEL
+    else:
+        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
+        options = r.json()
+
+        if model != options["sd_model_checkpoint"]:
+            options["sd_model_checkpoint"] = model
+            r = requests.post(
+                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
+            )
 
-    return options
+        return options
 
 
 @app.post("/models/default/update")
@@ -185,6 +254,24 @@ class GenerateImageForm(BaseModel):
     negative_prompt: Optional[str] = None
 
 
+def save_b64_image(b64_str):
+    image_id = str(uuid.uuid4())
+    file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
+
+    try:
+        # Split the base64 string to get the actual image data
+        img_data = base64.b64decode(b64_str)
+
+        # Write the image data to a file
+        with open(file_path, "wb") as f:
+            f.write(img_data)
+
+        return image_id
+    except Exception as e:
+        print(f"Error saving image: {e}")
+        return None
+
+
 @app.post("/generations")
 def generate_image(
     form_data: GenerateImageForm,
@@ -194,32 +281,82 @@ def generate_image(
     print(form_data)
 
     try:
-        if form_data.model:
-            set_model_handler(form_data.model)
+        if app.state.ENGINE == "openai":
 
-        width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+            headers = {}
+            headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
+            headers["Content-Type"] = "application/json"
 
-        data = {
-            "prompt": form_data.prompt,
-            "batch_size": form_data.n,
-            "width": width,
-            "height": height,
-        }
+            data = {
+                "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
+                "prompt": form_data.prompt,
+                "n": form_data.n,
+                "size": form_data.size,
+                "response_format": "b64_json",
+            }
 
-        if app.state.IMAGE_STEPS != None:
-            data["steps"] = app.state.IMAGE_STEPS
+            r = requests.post(
+                url=f"https://api.openai.com/v1/images/generations",
+                json=data,
+                headers=headers,
+            )
 
-        if form_data.negative_prompt != None:
-            data["negative_prompt"] = form_data.negative_prompt
+            r.raise_for_status()
 
-        print(data)
+            res = r.json()
 
-        r = requests.post(
-            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
-            json=data,
-        )
+            images = []
+
+            for image in res["data"]:
+                image_id = save_b64_image(image["b64_json"])
+                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
+
+                with open(file_body_path, "w") as f:
+                    json.dump(data, f)
+
+            return images
+
+        else:
+            if form_data.model:
+                set_model_handler(form_data.model)
+
+            width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+
+            data = {
+                "prompt": form_data.prompt,
+                "batch_size": form_data.n,
+                "width": width,
+                "height": height,
+            }
+
+            if app.state.IMAGE_STEPS != None:
+                data["steps"] = app.state.IMAGE_STEPS
+
+            if form_data.negative_prompt != None:
+                data["negative_prompt"] = form_data.negative_prompt
+
+            r = requests.post(
+                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
+                json=data,
+            )
+
+            res = r.json()
+
+            print(res)
+
+            images = []
+
+            for image in res["images"]:
+                image_id = save_b64_image(image)
+                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
+
+                with open(file_body_path, "w") as f:
+                    json.dump({**data, "info": res["info"]}, f)
+
+            return images
 
-        return r.json()
     except Exception as e:
         print(e)
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

+ 41 - 0
backend/apps/litellm/main.py

@@ -0,0 +1,41 @@
+from litellm.proxy.proxy_server import ProxyConfig, initialize
+from litellm.proxy.proxy_server import app
+
+from fastapi import FastAPI, Request, Depends, status
+from fastapi.responses import JSONResponse
+from utils.utils import get_http_authorization_cred, get_current_user
+from config import ENV
+
+proxy_config = ProxyConfig()
+
+
+async def config():
+    router, model_list, general_settings = await proxy_config.load_config(
+        router=None, config_file_path="./data/litellm/config.yaml"
+    )
+
+    await initialize(config="./data/litellm/config.yaml", telemetry=False)
+
+
+async def startup():
+    await config()
+
+
+@app.on_event("startup")
+async def on_startup():
+    await startup()
+
+
+@app.middleware("http")
+async def auth_middleware(request: Request, call_next):
+    auth_header = request.headers.get("Authorization", "")
+
+    if ENV != "dev":
+        try:
+            user = get_current_user(get_http_authorization_cred(auth_header))
+            print(user)
+        except Exception as e:
+            return JSONResponse(status_code=400, content={"detail": str(e)})
+
+    response = await call_next(request)
+    return response

+ 3 - 37
backend/main.py

@@ -9,17 +9,14 @@ import requests
 from fastapi import FastAPI, Request, Depends, status
 from fastapi.staticfiles import StaticFiles
 from fastapi import HTTPException
-from fastapi.responses import JSONResponse
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
 
 
-from litellm.proxy.proxy_server import ProxyConfig, initialize
-from litellm.proxy.proxy_server import app as litellm_app
-
 from apps.ollama.main import app as ollama_app
 from apps.openai.main import app as openai_app
+from apps.litellm.main import app as litellm_app, startup as litellm_app_startup
 from apps.audio.main import app as audio_app
 from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
@@ -29,8 +26,6 @@ from apps.web.main import app as webui_app
 from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
 from constants import ERROR_MESSAGES
 
-from utils.utils import get_http_authorization_cred, get_current_user
-
 
 class SPAStaticFiles(StaticFiles):
     async def get_response(self, path: str, scope):
@@ -43,21 +38,6 @@ class SPAStaticFiles(StaticFiles):
                 raise ex
 
 
-proxy_config = ProxyConfig()
-
-
-async def config():
-    router, model_list, general_settings = await proxy_config.load_config(
-        router=None, config_file_path="./data/litellm/config.yaml"
-    )
-
-    await initialize(config="./data/litellm/config.yaml", telemetry=False)
-
-
-async def startup():
-    await config()
-
-
 app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
 
 origins = ["*"]
@@ -73,7 +53,7 @@ app.add_middleware(
 
 @app.on_event("startup")
 async def on_startup():
-    await startup()
+    await litellm_app_startup()
 
 
 @app.middleware("http")
@@ -86,21 +66,6 @@ async def check_url(request: Request, call_next):
     return response
 
 
-@litellm_app.middleware("http")
-async def auth_middleware(request: Request, call_next):
-    auth_header = request.headers.get("Authorization", "")
-
-    if ENV != "dev":
-        try:
-            user = get_current_user(get_http_authorization_cred(auth_header))
-            print(user)
-        except Exception as e:
-            return JSONResponse(status_code=400, content={"detail": str(e)})
-
-    response = await call_next(request)
-    return response
-
-
 app.mount("/api/v1", webui_app)
 app.mount("/litellm/api", litellm_app)
 
@@ -156,6 +121,7 @@ async def get_app_latest_release_version():
 
 
 app.mount("/static", StaticFiles(directory="static"), name="static")
+app.mount("/cache", StaticFiles(directory="data/cache"), name="cache")
 
 
 app.mount(

+ 83 - 8
src/lib/apis/images/index.ts

@@ -1,9 +1,9 @@
 import { IMAGES_API_BASE_URL } from '$lib/constants';
 
-export const getImageGenerationEnabledStatus = async (token: string = '') => {
+export const getImageGenerationConfig = async (token: string = '') => {
 	let error = null;
 
-	const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, {
+	const res = await fetch(`${IMAGES_API_BASE_URL}/config`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -32,10 +32,50 @@ export const getImageGenerationEnabledStatus = async (token: string = '') => {
 	return res;
 };
 
-export const toggleImageGenerationEnabledStatus = async (token: string = '') => {
+export const updateImageGenerationConfig = async (
+	token: string = '',
+	engine: string,
+	enabled: boolean
+) => {
 	let error = null;
 
-	const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, {
+	const res = await fetch(`${IMAGES_API_BASE_URL}/config/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			engine,
+			enabled
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getOpenAIKey = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/key`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -61,7 +101,42 @@ export const toggleImageGenerationEnabledStatus = async (token: string = '') =>
 		throw error;
 	}
 
-	return res;
+	return res.OPENAI_API_KEY;
+};
+
+export const updateOpenAIKey = async (token: string = '', key: string) => {
+	let error = null;
+
+	const res = await fetch(`${IMAGES_API_BASE_URL}/key/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			key: key
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.OPENAI_API_KEY;
 };
 
 export const getAUTOMATIC1111Url = async (token: string = '') => {
@@ -263,7 +338,7 @@ export const updateImageSteps = async (token: string = '', steps: number) => {
 	return res.IMAGE_STEPS;
 };
 
-export const getDiffusionModels = async (token: string = '') => {
+export const getImageGenerationModels = async (token: string = '') => {
 	let error = null;
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/models`, {
@@ -295,7 +370,7 @@ export const getDiffusionModels = async (token: string = '') => {
 	return res;
 };
 
-export const getDefaultDiffusionModel = async (token: string = '') => {
+export const getDefaultImageGenerationModel = async (token: string = '') => {
 	let error = null;
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, {
@@ -327,7 +402,7 @@ export const getDefaultDiffusionModel = async (token: string = '') => {
 	return res.model;
 };
 
-export const updateDefaultDiffusionModel = async (token: string = '', model: string) => {
+export const updateDefaultImageGenerationModel = async (token: string = '', model: string) => {
 	let error = null;
 
 	const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, {

+ 3 - 1
src/lib/apis/litellm/index.ts

@@ -77,6 +77,7 @@ type AddLiteLLMModelForm = {
 	api_base: string;
 	api_key: string;
 	rpm: string;
+	max_tokens: string;
 };
 
 export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMModelForm) => {
@@ -95,7 +96,8 @@ export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMMod
 				model: payload.model,
 				...(payload.api_base === '' ? {} : { api_base: payload.api_base }),
 				...(payload.api_key === '' ? {} : { api_key: payload.api_key }),
-				...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) })
+				...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) }),
+				...(payload.max_tokens === '' ? {} : { max_tokens: payload.max_tokens })
 			}
 		})
 	})

+ 14 - 11
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -6,6 +6,7 @@
 	import auto_render from 'katex/dist/contrib/auto-render.mjs';
 	import 'katex/dist/katex.min.css';
 
+	import { fade } from 'svelte/transition';
 	import { createEventDispatcher } from 'svelte';
 	import { onMount, tick, getContext } from 'svelte';
 
@@ -278,13 +279,15 @@
 
 	const generateImage = async (message) => {
 		generatingImage = true;
-		const res = await imageGenerations(localStorage.token, message.content);
+		const res = await imageGenerations(localStorage.token, message.content).catch((error) => {
+			toast.error(error);
+		});
 		console.log(res);
 
 		if (res) {
-			message.files = res.images.map((image) => ({
+			message.files = res.map((image) => ({
 				type: 'image',
-				url: `data:image/png;base64,${image}`
+				url: `${image.url}`
 			}));
 
 			dispatch('save', message);
@@ -479,7 +482,7 @@
 													xmlns="http://www.w3.org/2000/svg"
 													fill="none"
 													viewBox="0 0 24 24"
-													stroke-width="1.5"
+													stroke-width="2"
 													stroke="currentColor"
 													class="w-4 h-4"
 												>
@@ -505,7 +508,7 @@
 													xmlns="http://www.w3.org/2000/svg"
 													fill="none"
 													viewBox="0 0 24 24"
-													stroke-width="1.5"
+													stroke-width="2"
 													stroke="currentColor"
 													class="w-4 h-4"
 												>
@@ -624,7 +627,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -639,7 +642,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -705,7 +708,7 @@
 															xmlns="http://www.w3.org/2000/svg"
 															fill="none"
 															viewBox="0 0 24 24"
-															stroke-width="1.5"
+															stroke-width="2"
 															stroke="currentColor"
 															class="w-4 h-4"
 														>
@@ -735,7 +738,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -764,7 +767,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>
@@ -794,7 +797,7 @@
 														xmlns="http://www.w3.org/2000/svg"
 														fill="none"
 														viewBox="0 0 24 24"
-														stroke-width="1.5"
+														stroke-width="2"
 														stroke="currentColor"
 														class="w-4 h-4"
 													>

+ 3 - 3
src/lib/components/chat/Messages/UserMessage.svelte

@@ -261,7 +261,7 @@
 									xmlns="http://www.w3.org/2000/svg"
 									fill="none"
 									viewBox="0 0 24 24"
-									stroke-width="1.5"
+									stroke-width="2"
 									stroke="currentColor"
 									class="w-4 h-4"
 								>
@@ -285,7 +285,7 @@
 									xmlns="http://www.w3.org/2000/svg"
 									fill="none"
 									viewBox="0 0 24 24"
-									stroke-width="1.5"
+									stroke-width="2"
 									stroke="currentColor"
 									class="w-4 h-4"
 								>
@@ -310,7 +310,7 @@
 										xmlns="http://www.w3.org/2000/svg"
 										fill="none"
 										viewBox="0 0 24 24"
-										stroke-width="1.5"
+										stroke-width="2"
 										stroke="currentColor"
 										class="w-4 h-4"
 									>

+ 1 - 1
src/lib/components/chat/Settings/General.svelte

@@ -186,7 +186,7 @@
 			<div class=" my-2.5 text-sm font-medium">{$i18n.t('System Prompt')}</div>
 			<textarea
 				bind:value={system}
-				class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
+				class="w-full rounded-lg p-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
 				rows="4"
 			/>
 		</div>

+ 134 - 76
src/lib/components/chat/Settings/Images.svelte

@@ -5,16 +5,18 @@
 	import { config, user } from '$lib/stores';
 	import {
 		getAUTOMATIC1111Url,
-		getDefaultDiffusionModel,
-		getDiffusionModels,
-		getImageGenerationEnabledStatus,
+		getImageGenerationModels,
+		getDefaultImageGenerationModel,
+		updateDefaultImageGenerationModel,
 		getImageSize,
-		toggleImageGenerationEnabledStatus,
+		getImageGenerationConfig,
+		updateImageGenerationConfig,
 		updateAUTOMATIC1111Url,
-		updateDefaultDiffusionModel,
 		updateImageSize,
 		getImageSteps,
-		updateImageSteps
+		updateImageSteps,
+		getOpenAIKey,
+		updateOpenAIKey
 	} from '$lib/apis/images';
 	import { getBackendConfig } from '$lib/apis';
 	const dispatch = createEventDispatcher();
@@ -25,8 +27,11 @@
 
 	let loading = false;
 
+	let imageGenerationEngine = '';
 	let enableImageGeneration = false;
+
 	let AUTOMATIC1111_BASE_URL = '';
+	let OPENAI_API_KEY = '';
 
 	let selectedModel = '';
 	let models = null;
@@ -35,11 +40,11 @@
 	let steps = 50;
 
 	const getModels = async () => {
-		models = await getDiffusionModels(localStorage.token).catch((error) => {
+		models = await getImageGenerationModels(localStorage.token).catch((error) => {
 			toast.error(error);
 			return null;
 		});
-		selectedModel = await getDefaultDiffusionModel(localStorage.token).catch((error) => {
+		selectedModel = await getDefaultImageGenerationModel(localStorage.token).catch((error) => {
 			return '';
 		});
 	};
@@ -64,33 +69,45 @@
 			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
 		}
 	};
-	const toggleImageGeneration = async () => {
-		if (AUTOMATIC1111_BASE_URL) {
-			enableImageGeneration = await toggleImageGenerationEnabledStatus(localStorage.token).catch(
-				(error) => {
-					toast.error(error);
-					return false;
-				}
-			);
+	const updateImageGeneration = async () => {
+		const res = await updateImageGenerationConfig(
+			localStorage.token,
+			imageGenerationEngine,
+			enableImageGeneration
+		).catch((error) => {
+			toast.error(error);
+			return null;
+		});
 
-			if (enableImageGeneration) {
-				config.set(await getBackendConfig(localStorage.token));
-				getModels();
-			}
-		} else {
-			enableImageGeneration = false;
-			toast.error($i18n.t('{{item}} not provided', { item: 'AUTOMATIC1111_BASE_URL' }));
+		if (res) {
+			imageGenerationEngine = res.engine;
+			enableImageGeneration = res.enabled;
+		}
+
+		if (enableImageGeneration) {
+			config.set(await getBackendConfig(localStorage.token));
+			getModels();
 		}
 	};
 
 	onMount(async () => {
 		if ($user.role === 'admin') {
-			enableImageGeneration = await getImageGenerationEnabledStatus(localStorage.token);
+			const res = await getImageGenerationConfig(localStorage.token).catch((error) => {
+				toast.error(error);
+				return null;
+			});
+
+			if (res) {
+				imageGenerationEngine = res.engine;
+				enableImageGeneration = res.enabled;
+			}
 			AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
+			OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
+
+			imageSize = await getImageSize(localStorage.token);
+			steps = await getImageSteps(localStorage.token);
 
-			if (enableImageGeneration && AUTOMATIC1111_BASE_URL) {
-				imageSize = await getImageSize(localStorage.token);
-				steps = await getImageSteps(localStorage.token);
+			if (enableImageGeneration) {
 				getModels();
 			}
 		}
@@ -101,7 +118,11 @@
 	class="flex flex-col h-full justify-between space-y-3 text-sm"
 	on:submit|preventDefault={async () => {
 		loading = true;
-		await updateDefaultDiffusionModel(localStorage.token, selectedModel);
+		await updateOpenAIKey(localStorage.token, OPENAI_API_KEY);
+
+		await updateDefaultImageGenerationModel(localStorage.token, selectedModel);
+
+		await updateDefaultImageGenerationModel(localStorage.token, selectedModel);
 		await updateImageSize(localStorage.token, imageSize).catch((error) => {
 			toast.error(error);
 			return null;
@@ -119,6 +140,23 @@
 		<div>
 			<div class=" mb-1 text-sm font-medium">{$i18n.t('Image Settings')}</div>
 
+			<div class=" py-0.5 flex w-full justify-between">
+				<div class=" self-center text-xs font-medium">Image Generation Engine</div>
+				<div class="flex items-center relative">
+					<select
+						class="w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
+						bind:value={imageGenerationEngine}
+						placeholder="Select a mode"
+						on:change={async () => {
+							await updateImageGeneration();
+						}}
+					>
+						<option value="">Default (Automatic1111)</option>
+						<option value="openai">Open AI (Dall-E)</option>
+					</select>
+				</div>
+			</div>
+
 			<div>
 				<div class=" py-0.5 flex w-full justify-between">
 					<div class=" self-center text-xs font-medium">
@@ -128,7 +166,17 @@
 					<button
 						class="p-1 px-3 text-xs flex rounded transition"
 						on:click={() => {
-							toggleImageGeneration();
+							if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') {
+								toast.error('AUTOMATIC1111 Base URL is required.');
+								enableImageGeneration = false;
+							} else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') {
+								toast.error('OpenAI API Key is required.');
+								enableImageGeneration = false;
+							} else {
+								enableImageGeneration = !enableImageGeneration;
+							}
+
+							updateImageGeneration();
 						}}
 						type="button"
 					>
@@ -143,50 +191,62 @@
 		</div>
 		<hr class=" dark:border-gray-700" />
 
-		<div class=" mb-2.5 text-sm font-medium">{$i18n.t('AUTOMATIC1111 Base URL')}</div>
-		<div class="flex w-full">
-			<div class="flex-1 mr-2">
-				<input
-					class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-					placeholder="Enter URL (e.g. http://127.0.0.1:7860/)"
-					bind:value={AUTOMATIC1111_BASE_URL}
-				/>
+		{#if imageGenerationEngine === ''}
+			<div class=" mb-2.5 text-sm font-medium">{$i18n.t('AUTOMATIC1111 Base URL')}</div>
+			<div class="flex w-full">
+				<div class="flex-1 mr-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder="Enter URL (e.g. http://127.0.0.1:7860/)"
+						bind:value={AUTOMATIC1111_BASE_URL}
+					/>
+				</div>
+				<button
+					class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded-lg transition"
+					type="button"
+					on:click={() => {
+						// updateOllamaAPIUrlHandler();
+
+						updateAUTOMATIC1111UrlHandler();
+					}}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						viewBox="0 0 20 20"
+						fill="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							fill-rule="evenodd"
+							d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
+							clip-rule="evenodd"
+						/>
+					</svg>
+				</button>
 			</div>
-			<button
-				class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded transition"
-				type="button"
-				on:click={() => {
-					// updateOllamaAPIUrlHandler();
-
-					updateAUTOMATIC1111UrlHandler();
-				}}
-			>
-				<svg
-					xmlns="http://www.w3.org/2000/svg"
-					viewBox="0 0 20 20"
-					fill="currentColor"
-					class="w-4 h-4"
+
+			<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
+				Include `--api` flag when running stable-diffusion-webui
+				<a
+					class=" text-gray-300 font-medium"
+					href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/3734"
+					target="_blank"
 				>
-					<path
-						fill-rule="evenodd"
-						d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
-						clip-rule="evenodd"
+					(e.g. `sh webui.sh --api`)
+				</a>
+			</div>
+		{:else if imageGenerationEngine === 'openai'}
+			<div class=" mb-2.5 text-sm font-medium">OpenAI API Key</div>
+			<div class="flex w-full">
+				<div class="flex-1 mr-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder="Enter API Key"
+						bind:value={OPENAI_API_KEY}
 					/>
-				</svg>
-			</button>
-		</div>
-
-		<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
-			{$i18n.t('Include `--api` flag when running stable-diffusion-webui')}
-			<a
-				class=" text-gray-300 font-medium underline"
-				href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/3734"
-				target="_blank"
-			>
-				<br />
-				{$i18n.t('(e.g. `sh webui.sh --api`)')}
-			</a>
-		</div>
+				</div>
+			</div>
+		{/if}
 
 		{#if enableImageGeneration}
 			<hr class=" dark:border-gray-700" />
@@ -196,7 +256,7 @@
 				<div class="flex w-full">
 					<div class="flex-1 mr-2">
 						<select
-							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							bind:value={selectedModel}
 							placeholder={$i18n.t('Select a model')}
 						>
@@ -204,9 +264,7 @@
 								<option value="" disabled selected>{$i18n.t('Select a model')}</option>
 							{/if}
 							{#each models ?? [] as model}
-								<option value={model.title} class="bg-gray-100 dark:bg-gray-700"
-									>{model.model_name}</option
-								>
+								<option value={model.id} class="bg-gray-100 dark:bg-gray-700">{model.name}</option>
 							{/each}
 						</select>
 					</div>
@@ -218,7 +276,7 @@
 				<div class="flex w-full">
 					<div class="flex-1 mr-2">
 						<input
-							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							placeholder="Enter Image Size (e.g. 512x512)"
 							bind:value={imageSize}
 						/>
@@ -231,7 +289,7 @@
 				<div class="flex w-full">
 					<div class="flex-1 mr-2">
 						<input
-							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 							placeholder="Enter Number of Steps (e.g. 50)"
 							bind:value={steps}
 						/>

+ 20 - 1
src/lib/components/chat/Settings/Models.svelte

@@ -29,6 +29,7 @@
 	let liteLLMAPIBase = '';
 	let liteLLMAPIKey = '';
 	let liteLLMRPM = '';
+	let liteLLMMaxTokens = '';
 
 	let deleteLiteLLMModelId = '';
 
@@ -336,7 +337,8 @@
 				model: liteLLMModel,
 				api_base: liteLLMAPIBase,
 				api_key: liteLLMAPIKey,
-				rpm: liteLLMRPM
+				rpm: liteLLMRPM,
+				max_tokens: liteLLMMaxTokens
 			}).catch((error) => {
 				toast.error(error);
 				return null;
@@ -356,6 +358,7 @@
 		liteLLMAPIBase = '';
 		liteLLMAPIKey = '';
 		liteLLMRPM = '';
+		liteLLMMaxTokens = '';
 
 		liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
 		models.set(await getModels());
@@ -838,6 +841,22 @@
 									</div>
 								</div>
 							</div>
+
+							<div>
+								<div class="mb-1.5 text-sm font-medium">Max Tokens</div>
+								<div class="flex w-full">
+									<div class="flex-1">
+										<input
+											class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+											placeholder="Enter Max Tokens (litellm_params.max_tokens)"
+											bind:value={liteLLMMaxTokens}
+											type="number"
+											min="1"
+											autocomplete="off"
+										/>
+									</div>
+								</div>
+							</div>
 						{/if}
 					</div>
 

+ 7 - 2
src/lib/components/common/Image.svelte

@@ -1,18 +1,23 @@
 <script lang="ts">
+	import { WEBUI_BASE_URL } from '$lib/constants';
 	import ImagePreview from './ImagePreview.svelte';
 
 	export let src = '';
 	export let alt = '';
 
+	let _src = '';
+
+	$: _src = src.startsWith('/') ? `${WEBUI_BASE_URL}${src}` : src;
+
 	let showImagePreview = false;
 </script>
 
-<ImagePreview bind:show={showImagePreview} {src} {alt} />
+<ImagePreview bind:show={showImagePreview} src={_src} {alt} />
 <button
 	on:click={() => {
 		console.log('image preview');
 		showImagePreview = true;
 	}}
 >
-	<img {src} {alt} class=" max-h-96 rounded-lg" draggable="false" />
+	<img src={_src} {alt} class=" max-h-96 rounded-lg" draggable="false" />
 </button>

+ 12 - 7
src/lib/components/layout/Sidebar.svelte

@@ -64,12 +64,16 @@
 	};
 
 	const editChatTitle = async (id, _title) => {
-		title = _title;
-
-		await updateChatById(localStorage.token, id, {
-			title: _title
-		});
-		await chats.set(await getChatList(localStorage.token));
+		if (_title === '') {
+			toast.error('Title cannot be an empty string.');
+		} else {
+			title = _title;
+
+			await updateChatById(localStorage.token, id, {
+				title: _title
+			});
+			await chats.set(await getChatList(localStorage.token));
+		}
 	};
 
 	const deleteChat = async (id) => {
@@ -393,12 +397,13 @@
 										show = false;
 									}
 								}}
+								draggable="false"
 							>
 								<div class=" flex self-center flex-1 w-full">
 									<div
 										class=" text-left self-center overflow-hidden {chat.id === $chatId
 											? 'w-[160px]'
-											: 'w-full'} "
+											: 'w-full'}  h-[20px]"
 									>
 										{chat.title}
 									</div>