Explorar o código

Merge pull request #1117 from open-webui/model-whitelist

feat: model filter (whitelist)
Timothy Jaeryang Baek hai 1 ano
pai
achega
bcabd3df84

+ 16 - 2
backend/apps/ollama/main.py

@@ -29,6 +29,10 @@ app.add_middleware(
     allow_headers=["*"],
     allow_headers=["*"],
 )
 )
 
 
+
+app.state.MODEL_FILTER_ENABLED = False
+app.state.MODEL_LIST = []
+
 app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}
 app.state.MODELS = {}
 
 
@@ -129,9 +133,19 @@ async def get_all_models():
 async def get_ollama_tags(
 async def get_ollama_tags(
     url_idx: Optional[int] = None, user=Depends(get_current_user)
     url_idx: Optional[int] = None, user=Depends(get_current_user)
 ):
 ):
-
     if url_idx == None:
     if url_idx == None:
-        return await get_all_models()
+        models = await get_all_models()
+
+        if app.state.MODEL_FILTER_ENABLED:
+            if user.role == "user":
+                models["models"] = list(
+                    filter(
+                        lambda model: model["name"] in app.state.MODEL_LIST,
+                        models["models"],
+                    )
+                )
+                return models
+        return models
     else:
     else:
         url = app.state.OLLAMA_BASE_URLS[url_idx]
         url = app.state.OLLAMA_BASE_URLS[url_idx]
         try:
         try:

+ 15 - 3
backend/apps/openai/main.py

@@ -34,6 +34,9 @@ app.add_middleware(
     allow_headers=["*"],
     allow_headers=["*"],
 )
 )
 
 
+app.state.MODEL_FILTER_ENABLED = False
+app.state.MODEL_LIST = []
+
 app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
 app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
 
 
@@ -186,12 +189,21 @@ async def get_all_models():
     return models
     return models
 
 
 
 
-# , user=Depends(get_current_user)
 @app.get("/models")
 @app.get("/models")
 @app.get("/models/{url_idx}")
 @app.get("/models/{url_idx}")
-async def get_models(url_idx: Optional[int] = None):
+async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
     if url_idx == None:
     if url_idx == None:
-        return await get_all_models()
+        models = await get_all_models()
+        if app.state.MODEL_FILTER_ENABLED:
+            if user.role == "user":
+                models["data"] = list(
+                    filter(
+                        lambda model: model["id"] in app.state.MODEL_LIST,
+                        models["data"],
+                    )
+                )
+                return models
+        return models
     else:
     else:
         url = app.state.OPENAI_API_BASE_URLS[url_idx]
         url = app.state.OPENAI_API_BASE_URLS[url_idx]
         try:
         try:

+ 34 - 0
backend/main.py

@@ -23,7 +23,11 @@ from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
 from apps.rag.main import app as rag_app
 from apps.web.main import app as webui_app
 from apps.web.main import app as webui_app
 
 
+from pydantic import BaseModel
+from typing import List
 
 
+
+from utils.utils import get_admin_user
 from apps.rag.utils import query_doc, query_collection, rag_template
 from apps.rag.utils import query_doc, query_collection, rag_template
 
 
 from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
 from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
@@ -43,6 +47,9 @@ class SPAStaticFiles(StaticFiles):
 
 
 app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
 app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
 
 
+app.state.MODEL_FILTER_ENABLED = False
+app.state.MODEL_LIST = []
+
 origins = ["*"]
 origins = ["*"]
 
 
 app.add_middleware(
 app.add_middleware(
@@ -213,6 +220,33 @@ async def get_app_config():
     }
     }
 
 
 
 
+@app.get("/api/config/model/filter")
+async def get_model_filter_config(user=Depends(get_admin_user)):
+    return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
+
+
+class ModelFilterConfigForm(BaseModel):
+    enabled: bool
+    models: List[str]
+
+
+@app.post("/api/config/model/filter")
+async def get_model_filter_config(
+    form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
+):
+
+    app.state.MODEL_FILTER_ENABLED = form_data.enabled
+    app.state.MODEL_LIST = form_data.models
+
+    ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
+    ollama_app.state.MODEL_LIST = app.state.MODEL_LIST
+
+    openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
+    openai_app.state.MODEL_LIST = app.state.MODEL_LIST
+
+    return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
+
+
 @app.get("/api/version")
 @app.get("/api/version")
 async def get_app_config():
 async def get_app_config():
 
 

+ 62 - 0
src/lib/apis/index.ts

@@ -77,3 +77,65 @@ export const getVersionUpdates = async () => {
 
 
 	return res;
 	return res;
 };
 };
+
+export const getModelFilterConfig = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateModelFilterConfig = async (
+	token: string,
+	enabled: boolean,
+	models: string[]
+) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			enabled: enabled,
+			models: models
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 113 - 0
src/lib/components/admin/Settings/Users.svelte

@@ -1,10 +1,14 @@
 <script lang="ts">
 <script lang="ts">
+	import { getModelFilterConfig, updateModelFilterConfig } from '$lib/apis';
 	import { getSignUpEnabledStatus, toggleSignUpEnabledStatus } from '$lib/apis/auths';
 	import { getSignUpEnabledStatus, toggleSignUpEnabledStatus } from '$lib/apis/auths';
 	import { getUserPermissions, updateUserPermissions } from '$lib/apis/users';
 	import { getUserPermissions, updateUserPermissions } from '$lib/apis/users';
+	import { models } from '$lib/stores';
 	import { onMount } from 'svelte';
 	import { onMount } from 'svelte';
 
 
 	export let saveHandler: Function;
 	export let saveHandler: Function;
 
 
+	let whitelistEnabled = false;
+	let whitelistModels = [''];
 	let permissions = {
 	let permissions = {
 		chat: {
 		chat: {
 			deletion: true
 			deletion: true
@@ -13,6 +17,13 @@
 
 
 	onMount(async () => {
 	onMount(async () => {
 		permissions = await getUserPermissions(localStorage.token);
 		permissions = await getUserPermissions(localStorage.token);
+
+		const res = await getModelFilterConfig(localStorage.token);
+		if (res) {
+			whitelistEnabled = res.enabled;
+
+			whitelistModels = res.models.length > 0 ? res.models : [''];
+		}
 	});
 	});
 </script>
 </script>
 
 
@@ -21,6 +32,8 @@
 	on:submit|preventDefault={async () => {
 	on:submit|preventDefault={async () => {
 		// console.log('submit');
 		// console.log('submit');
 		await updateUserPermissions(localStorage.token, permissions);
 		await updateUserPermissions(localStorage.token, permissions);
+
+		await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels);
 		saveHandler();
 		saveHandler();
 	}}
 	}}
 >
 >
@@ -69,6 +82,106 @@
 				</button>
 				</button>
 			</div>
 			</div>
 		</div>
 		</div>
+
+		<hr class=" dark:border-gray-700 my-2" />
+
+		<div class="mt-2 space-y-3 pr-1.5">
+			<div>
+				<div class="mb-2">
+					<div class="flex justify-between items-center text-xs">
+						<div class=" text-sm font-medium">Manage Models</div>
+					</div>
+				</div>
+
+				<div class=" space-y-3">
+					<div>
+						<div class="flex justify-between items-center text-xs">
+							<div class=" text-xs font-medium">Model Whitelisting</div>
+
+							<button
+								class=" text-xs font-medium text-gray-500"
+								type="button"
+								on:click={() => {
+									whitelistEnabled = !whitelistEnabled;
+								}}>{whitelistEnabled ? 'On' : 'Off'}</button
+							>
+						</div>
+					</div>
+
+					{#if whitelistEnabled}
+						<div>
+							<div class=" space-y-1.5">
+								{#each whitelistModels as modelId, modelIdx}
+									<div class="flex w-full">
+										<div class="flex-1 mr-2">
+											<select
+												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+												bind:value={modelId}
+												placeholder="Select a model"
+											>
+												<option value="" disabled selected>Select a model</option>
+												{#each $models.filter((model) => model.id) as model}
+													<option value={model.id} class="bg-gray-100 dark:bg-gray-700"
+														>{model.name}</option
+													>
+												{/each}
+											</select>
+										</div>
+
+										{#if modelIdx === 0}
+											<button
+												class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition"
+												type="button"
+												on:click={() => {
+													if (whitelistModels.at(-1) !== '') {
+														whitelistModels = [...whitelistModels, ''];
+													}
+												}}
+											>
+												<svg
+													xmlns="http://www.w3.org/2000/svg"
+													viewBox="0 0 16 16"
+													fill="currentColor"
+													class="w-4 h-4"
+												>
+													<path
+														d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
+													/>
+												</svg>
+											</button>
+										{:else}
+											<button
+												class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition"
+												type="button"
+												on:click={() => {
+													whitelistModels.splice(modelIdx, 1);
+													whitelistModels = whitelistModels;
+												}}
+											>
+												<svg
+													xmlns="http://www.w3.org/2000/svg"
+													viewBox="0 0 16 16"
+													fill="currentColor"
+													class="w-4 h-4"
+												>
+													<path d="M3.75 7.25a.75.75 0 0 0 0 1.5h8.5a.75.75 0 0 0 0-1.5h-8.5Z" />
+												</svg>
+											</button>
+										{/if}
+									</div>
+								{/each}
+							</div>
+
+							<div class="flex justify-end items-center text-xs mt-1.5 text-right">
+								<div class=" text-xs font-medium">
+									{whitelistModels.length} Model(s) Whitelisted
+								</div>
+							</div>
+						</div>
+					{/if}
+				</div>
+			</div>
+		</div>
 	</div>
 	</div>
 
 
 	<div class="flex justify-end pt-3 text-sm font-medium">
 	<div class="flex justify-end pt-3 text-sm font-medium">

+ 0 - 82
src/lib/components/chat/Settings/Models.svelte

@@ -912,88 +912,6 @@
 					{/if}
 					{/if}
 				</div>
 				</div>
 			</div>
 			</div>
-
-			<!-- <div class="mt-2 space-y-3 pr-1.5">
-				<div>
-					<div class=" mb-2.5 text-sm font-medium">Add LiteLLM Model</div>
-					<div class="flex w-full mb-2">
-						<div class="flex-1">
-							<input
-								class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-								placeholder="Enter LiteLLM Model (e.g. ollama/mistral)"
-								bind:value={liteLLMModel}
-								autocomplete="off"
-							/>
-						</div>
-					</div>
-
-					<div class="flex justify-between items-center text-sm">
-						<div class="  font-medium">Advanced Model Params</div>
-						<button
-							class=" text-xs font-medium text-gray-500"
-							type="button"
-							on:click={() => {
-								showLiteLLMParams = !showLiteLLMParams;
-							}}>{showLiteLLMParams ? 'Hide' : 'Show'}</button
-						>
-					</div>
-
-					{#if showLiteLLMParams}
-						<div>
-							<div class=" mb-2.5 text-sm font-medium">LiteLLM API Key</div>
-							<div class="flex w-full">
-								<div class="flex-1">
-									<input
-										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-										placeholder="Enter LiteLLM API Key (e.g. os.environ/AZURE_API_KEY_CA)"
-										bind:value={liteLLMAPIKey}
-										autocomplete="off"
-									/>
-								</div>
-							</div>
-						</div>
-
-						<div>
-							<div class=" mb-2.5 text-sm font-medium">LiteLLM API Base URL</div>
-							<div class="flex w-full">
-								<div class="flex-1">
-									<input
-										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-										placeholder="Enter LiteLLM API Base URL"
-										bind:value={liteLLMAPIBase}
-										autocomplete="off"
-									/>
-								</div>
-							</div>
-						</div>
-
-						<div>
-							<div class=" mb-2.5 text-sm font-medium">LiteLLM API RPM</div>
-							<div class="flex w-full">
-								<div class="flex-1">
-									<input
-										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-										placeholder="Enter LiteLLM API RPM"
-										bind:value={liteLLMRPM}
-										autocomplete="off"
-									/>
-								</div>
-							</div>
-						</div>
-					{/if}
-
-					<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
-						Not sure what to add?
-						<a
-							class=" text-gray-300 font-medium underline"
-							href="https://litellm.vercel.app/docs/proxy/configs#quick-start"
-							target="_blank"
-						>
-							Click here for help.
-						</a>
-					</div>
-				</div>
-			</div> -->
 		</div>
 		</div>
 	</div>
 	</div>
 </div>
 </div>

+ 1 - 1
src/routes/(app)/playground/+page.svelte

@@ -267,7 +267,7 @@
 
 
 <div class="min-h-screen max-h-[100dvh] w-full flex justify-center dark:text-white">
 <div class="min-h-screen max-h-[100dvh] w-full flex justify-center dark:text-white">
 	<div class=" flex flex-col justify-between w-full overflow-y-auto h-[100dvh]">
 	<div class=" flex flex-col justify-between w-full overflow-y-auto h-[100dvh]">
-		<div class="max-w-2xl mx-auto w-full px-3 p-3 md:px-0 h-full">
+		<div class="max-w-2xl mx-auto w-full px-3 md:px-0 my-10 h-full">
 			<div class=" flex flex-col h-full">
 			<div class=" flex flex-col h-full">
 				<div class="flex flex-col justify-between mb-2.5 gap-1">
 				<div class="flex flex-col justify-between mb-2.5 gap-1">
 					<div class="flex justify-between items-center gap-2">
 					<div class="flex justify-between items-center gap-2">