Browse Source

enh: connection tags

Timothy Jaeryang Baek 1 month ago
parent
commit
c309412980

+ 18 - 8
backend/open_webui/main.py

@@ -965,14 +965,24 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
 
         return filtered_models
 
-    models = await get_all_models(request, user=user)
-
-    # Filter out filter pipelines
-    models = [
-        model
-        for model in models
-        if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
-    ]
+    all_models = await get_all_models(request, user=user)
+
+    models = []
+    for model in all_models:
+        # Filter out filter pipelines
+        if "pipeline" in model and model["pipeline"].get("type", None) == "filter":
+            continue
+
+        model_tags = [
+            tag.get("name")
+            for tag in model.get("info", {}).get("meta", {}).get("tags", [])
+        ]
+        tags = [tag.get("name") for tag in model.get("tags", [])]
+
+        tags = list(set(model_tags + tags))
+        model["tags"] = [{"name": tag} for tag in tags]
+
+        models.append(model)
 
     model_order_list = request.app.state.config.MODEL_ORDER_LIST
     if model_order_list:

+ 6 - 1
backend/open_webui/routers/ollama.py

@@ -295,7 +295,7 @@ async def update_config(
     }
 
 
-@cached(ttl=3)
+@cached(ttl=1)
 async def get_all_models(request: Request, user: UserModel = None):
     log.info("get_all_models()")
     if request.app.state.config.ENABLE_OLLAMA_API:
@@ -336,6 +336,7 @@ async def get_all_models(request: Request, user: UserModel = None):
                 )
 
                 prefix_id = api_config.get("prefix_id", None)
+                tags = api_config.get("tags", [])
                 model_ids = api_config.get("model_ids", [])
 
                 if len(model_ids) != 0 and "models" in response:
@@ -350,6 +351,10 @@ async def get_all_models(request: Request, user: UserModel = None):
                     for model in response.get("models", []):
                         model["model"] = f"{prefix_id}.{model['model']}"
 
+                if tags:
+                    for model in response.get("models", []):
+                        model["tags"] = tags
+
         def merge_models_lists(model_lists):
             merged_models = {}
 

+ 8 - 1
backend/open_webui/routers/openai.py

@@ -353,6 +353,7 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
             )
 
             prefix_id = api_config.get("prefix_id", None)
+            tags = api_config.get("tags", [])
 
             if prefix_id:
                 for model in (
@@ -360,6 +361,12 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
                 ):
                     model["id"] = f"{prefix_id}.{model['id']}"
 
+            if tags:
+                for model in (
+                    response if isinstance(response, list) else response.get("data", [])
+                ):
+                    model["tags"] = tags
+
     log.debug(f"get_all_models:responses() {responses}")
     return responses
 
@@ -377,7 +384,7 @@ async def get_filtered_models(models, user):
     return filtered_models
 
 
-@cached(ttl=3)
+@cached(ttl=1)
 async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
     log.info("get_all_models()")
 

+ 1 - 0
backend/open_webui/utils/models.py

@@ -49,6 +49,7 @@ async def get_all_base_models(request: Request, user: UserModel = None):
                 "created": int(time.time()),
                 "owned_by": "ollama",
                 "ollama": model,
+                "tags": model.get("tags", []),
             }
             for model in ollama_models["models"]
         ]

+ 7 - 0
src/lib/apis/index.ts

@@ -114,6 +114,13 @@ export const getModels = async (
 					}
 				}
 
+				const tags = apiConfig.tags;
+				if (tags) {
+					for (const model of models) {
+						model.tags = tags;
+					}
+				}
+
 				localModels = localModels.concat(models);
 			}
 		}

+ 28 - 0
src/lib/components/AddConnectionModal.svelte

@@ -14,6 +14,7 @@
 	import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
 	import Switch from '$lib/components/common/Switch.svelte';
+	import Tags from './common/Tags.svelte';
 
 	export let onSubmit: Function = () => {};
 	export let onDelete: Function = () => {};
@@ -31,6 +32,7 @@
 
 	let prefixId = '';
 	let enable = true;
+	let tags = [];
 
 	let modelId = '';
 	let modelIds = [];
@@ -88,6 +90,7 @@
 			key,
 			config: {
 				enable: enable,
+				tags: tags,
 				prefix_id: prefixId,
 				model_ids: modelIds
 			}
@@ -101,6 +104,7 @@
 		url = '';
 		key = '';
 		prefixId = '';
+		tags = [];
 		modelIds = [];
 	};
 
@@ -110,6 +114,7 @@
 			key = connection.key;
 
 			enable = connection.config?.enable ?? true;
+			tags = connection.config?.tags ?? [];
 			prefixId = connection.config?.prefix_id ?? '';
 			modelIds = connection.config?.model_ids ?? [];
 		}
@@ -244,6 +249,29 @@
 							</div>
 						</div>
 
+						<div class="flex gap-2 mt-2">
+							<div class="flex flex-col w-full">
+								<div class=" mb-1.5 text-xs text-gray-500">{$i18n.t('Tags')}</div>
+
+								<div class="flex-1">
+									<Tags
+										bind:tags
+										on:add={(e) => {
+											tags = [
+												...tags,
+												{
+													name: e.detail
+												}
+											];
+										}}
+										on:delete={(e) => {
+											tags = tags.filter((tag) => tag.name !== e.detail);
+										}}
+									/>
+								</div>
+							</div>
+						</div>
+
 						<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
 
 						<div class="flex flex-col w-full">

+ 21 - 20
src/lib/components/chat/ModelSelector/Selector.svelte

@@ -77,7 +77,7 @@
 			const _item = {
 				...item,
 				modelName: item.model?.name,
-				tags: item.model?.info?.meta?.tags?.map((tag) => tag.name).join(' '),
+				tags: (item.model?.tags ?? []).map((tag) => tag.name).join(' '),
 				desc: item.model?.info?.meta?.description
 			};
 			return _item;
@@ -98,7 +98,7 @@
 					if (selectedTag === '') {
 						return true;
 					}
-					return item.model?.info?.meta?.tags?.map((tag) => tag.name).includes(selectedTag);
+					return (item.model?.tags ?? []).map((tag) => tag.name).includes(selectedTag);
 				})
 				.filter((item) => {
 					if (selectedConnectionType === '') {
@@ -116,7 +116,7 @@
 					if (selectedTag === '') {
 						return true;
 					}
-					return item.model?.info?.meta?.tags?.map((tag) => tag.name).includes(selectedTag);
+					return (item.model?.tags ?? []).map((tag) => tag.name).includes(selectedTag);
 				})
 				.filter((item) => {
 					if (selectedConnectionType === '') {
@@ -262,7 +262,7 @@
 		ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => false);
 
 		if (items) {
-			tags = items.flatMap((item) => item.model?.info?.meta?.tags ?? []).map((tag) => tag.name);
+			tags = items.flatMap((item) => item.model?.tags ?? []).map((tag) => tag.name);
 
 			// Remove duplicates and sort
 			tags = Array.from(new Set(tags)).sort((a, b) => a.localeCompare(b));
@@ -291,12 +291,12 @@
 	onOpenChange={async () => {
 		searchValue = '';
 		// Do NOT reset filters - keep the previously selected tag/connection type
-		
+
 		await tick();
-		
+
 		// First check if the currently selected model is visible in the filtered list
-		const selectedInFiltered = filteredItems.findIndex(item => item.value === value);
-		
+		const selectedInFiltered = filteredItems.findIndex((item) => item.value === value);
+
 		if (selectedInFiltered >= 0) {
 			// The selected model is visible in the current filter
 			selectedModelIdx = selectedInFiltered;
@@ -304,22 +304,23 @@
 			// The selected model is not visible, default to first item in filtered list
 			selectedModelIdx = 0;
 		}
-		
+
 		await tick();
-		
+
 		// Scroll to the selected item if it exists in the current filtered view
-		const itemToScrollTo = selectedInFiltered >= 0
-			? document.querySelector(`[data-value="${value}"]`)
-			: document.querySelector('[data-arrow-selected="true"]');
-			
+		const itemToScrollTo =
+			selectedInFiltered >= 0
+				? document.querySelector(`[data-value="${value}"]`)
+				: document.querySelector('[data-arrow-selected="true"]');
+
 		if (itemToScrollTo) {
 			const container = itemToScrollTo.closest('.overflow-y-auto');
 			if (container) {
 				const itemTop = itemToScrollTo.offsetTop;
 				const containerHeight = container.clientHeight;
 				const itemHeight = itemToScrollTo.clientHeight;
-				
-				container.scrollTop = itemTop - (containerHeight / 2) + (itemHeight / 2);
+
+				container.scrollTop = itemTop - containerHeight / 2 + itemHeight / 2;
 			}
 		}
 	}}
@@ -483,9 +484,9 @@
 						}}
 					>
 						<div class="flex flex-col">
-							{#if $mobile && (item?.model?.info?.meta?.tags ?? []).length > 0}
+							{#if $mobile && (item?.model?.tags ?? []).length > 0}
 								<div class="flex gap-0.5 self-start h-full mb-1.5 -translate-x-1">
-									{#each item.model?.info?.meta.tags as tag}
+									{#each item.model?.tags as tag}
 										<div
 											class=" text-xs font-bold px-1 rounded-sm uppercase line-clamp-1 bg-gray-500/20 text-gray-700 dark:text-gray-200"
 										>
@@ -605,11 +606,11 @@
 									</Tooltip>
 								{/if}
 
-								{#if !$mobile && (item?.model?.info?.meta?.tags ?? []).length > 0}
+								{#if !$mobile && (item?.model?.tags ?? []).length > 0}
 									<div
 										class="flex gap-0.5 self-center items-center h-full translate-y-[0.5px] overflow-x-auto scrollbar-none"
 									>
-										{#each item.model?.info?.meta.tags as tag}
+										{#each item.model?.tags as tag}
 											<Tooltip content={tag.name} className="flex-shrink-0">
 												<div
 													class=" text-xs font-bold px-1 rounded-sm uppercase bg-gray-500/20 text-gray-700 dark:text-gray-200"