Browse Source

feat: show current running models

Timothy J. Baek 11 months ago
parent
commit
2be9c25ba7

+ 65 - 0
backend/apps/socket/main.py

@@ -1,4 +1,6 @@
 import socketio
 import socketio
+import asyncio
+
 
 
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 from utils.utils import decode_token
 from utils.utils import decode_token
@@ -10,6 +12,9 @@ app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
 
 
 
 
 USER_POOL = {}
 USER_POOL = {}
+USAGE_POOL = {}
+# Timeout duration in seconds
+TIMEOUT_DURATION = 3
 
 
 
 
 @sio.event
 @sio.event
@@ -57,6 +62,66 @@ async def user_count(sid):
     await sio.emit("user-count", {"count": len(set(USER_POOL))})
     await sio.emit("user-count", {"count": len(set(USER_POOL))})
 
 
 
 
+def get_models_in_use():
+    # Aggregate all models in use
+
+    models_in_use = []
+    for sid, data in USAGE_POOL.items():
+        models_in_use.extend(data["models"])
+    print(f"Models in use: {models_in_use}")
+
+    return models_in_use
+
+
+@sio.on("usage")
+async def usage(sid, data):
+    print(f'Received "usage" event from {sid}: {data}')
+
+    # Cancel previous task if there is one
+    if sid in USAGE_POOL:
+        USAGE_POOL[sid]["task"].cancel()
+
+    # Store the new usage data and task
+    model_id = data["model"]
+
+    if sid in USAGE_POOL and "models" in USAGE_POOL[sid]:
+
+        print(USAGE_POOL[sid])
+
+        models = USAGE_POOL[sid]["models"]
+        if model_id not in models:
+            models.append(model_id)
+            USAGE_POOL[sid] = {"models": models}
+
+    else:
+        USAGE_POOL[sid] = {"models": [model_id]}
+
+    # Schedule a task to remove the usage data after TIMEOUT_DURATION
+    USAGE_POOL[sid]["task"] = asyncio.create_task(remove_after_timeout(sid, model_id))
+
+    models_in_use = get_models_in_use()
+    # Broadcast the usage data to all clients
+    await sio.emit("usage", {"models": models_in_use})
+
+
+async def remove_after_timeout(sid, model_id):
+    try:
+        await asyncio.sleep(TIMEOUT_DURATION)
+        if sid in USAGE_POOL:
+            if model_id in USAGE_POOL[sid]["models"]:
+                USAGE_POOL[sid]["models"].remove(model_id)
+            if len(USAGE_POOL[sid]["models"]) == 0:
+                del USAGE_POOL[sid]
+            print(f"Removed usage data for {sid} due to timeout")
+
+            models_in_use = get_models_in_use()
+            # Broadcast the usage data to all clients
+            await sio.emit("usage", {"models": models_in_use})
+    except asyncio.CancelledError:
+        # Task was cancelled due to new 'usage' event
+        pass
+
+
 @sio.event
 @sio.event
 async def disconnect(sid):
 async def disconnect(sid):
     if sid in USER_POOL:
     if sid in USER_POOL:

+ 19 - 1
src/lib/components/chat/Chat.svelte

@@ -18,7 +18,8 @@
 		tags as _tags,
 		tags as _tags,
 		WEBUI_NAME,
 		WEBUI_NAME,
 		banners,
 		banners,
-		user
+		user,
+		socket
 	} from '$lib/stores';
 	} from '$lib/stores';
 	import {
 	import {
 		convertMessagesToHistory,
 		convertMessagesToHistory,
@@ -280,6 +281,16 @@
 		}
 		}
 	};
 	};
 
 
+	const getChatEventEmitter = async (modelId: string, chatId: string = '') => {
+		return setInterval(() => {
+			$socket?.emit('usage', {
+				action: 'chat',
+				model: modelId,
+				chat_id: chatId
+			});
+		}, 1000);
+	};
+
 	//////////////////////////
 	//////////////////////////
 	// Ollama functions
 	// Ollama functions
 	//////////////////////////
 	//////////////////////////
@@ -451,6 +462,8 @@
 					}
 					}
 					responseMessage.userContext = userContext;
 					responseMessage.userContext = userContext;
 
 
+					const chatEventEmitter = await getChatEventEmitter(model.id, _chatId);
+
 					if (webSearchEnabled) {
 					if (webSearchEnabled) {
 						await getWebSearchResults(model.id, parentId, responseMessageId);
 						await getWebSearchResults(model.id, parentId, responseMessageId);
 					}
 					}
@@ -460,6 +473,10 @@
 					} else if (model) {
 					} else if (model) {
 						await sendPromptOllama(model, prompt, responseMessageId, _chatId);
 						await sendPromptOllama(model, prompt, responseMessageId, _chatId);
 					}
 					}
+
+					console.log('chatEventEmitter', chatEventEmitter);
+
+					if (chatEventEmitter) clearInterval(chatEventEmitter);
 				} else {
 				} else {
 					toast.error($i18n.t(`Model {{modelId}} not found`, { modelId }));
 					toast.error($i18n.t(`Model {{modelId}} not found`, { modelId }));
 				}
 				}
@@ -542,6 +559,7 @@
 
 
 	const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {
 	const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {
 		model = model.id;
 		model = model.id;
+
 		const responseMessage = history.messages[responseMessageId];
 		const responseMessage = history.messages[responseMessageId];
 
 
 		// Wait until history/message have been updated
 		// Wait until history/message have been updated

+ 4 - 0
src/lib/components/common/Tooltip.svelte

@@ -21,6 +21,10 @@
 				touch: touch
 				touch: touch
 			});
 			});
 		}
 		}
+	} else if (tooltipInstance && content === '') {
+		if (tooltipInstance) {
+			tooltipInstance.destroy();
+		}
 	}
 	}
 
 
 	onDestroy(() => {
 	onDestroy(() => {

+ 25 - 18
src/lib/components/layout/Sidebar/UserMenu.svelte

@@ -5,8 +5,9 @@
 	import { flyAndScale } from '$lib/utils/transitions';
 	import { flyAndScale } from '$lib/utils/transitions';
 	import { goto } from '$app/navigation';
 	import { goto } from '$app/navigation';
 	import ArchiveBox from '$lib/components/icons/ArchiveBox.svelte';
 	import ArchiveBox from '$lib/components/icons/ArchiveBox.svelte';
-	import { showSettings, activeUserCount } from '$lib/stores';
+	import { showSettings, activeUserCount, USAGE_POOL } from '$lib/stores';
 	import { fade, slide } from 'svelte/transition';
 	import { fade, slide } from 'svelte/transition';
+	import Tooltip from '$lib/components/common/Tooltip.svelte';
 
 
 	const i18n = getContext('i18n');
 	const i18n = getContext('i18n');
 
 
@@ -142,25 +143,31 @@
 			{#if $activeUserCount}
 			{#if $activeUserCount}
 				<hr class=" dark:border-gray-800 my-1.5 p-0" />
 				<hr class=" dark:border-gray-800 my-1.5 p-0" />
 
 
-				<div class="flex rounded-md py-1.5 px-3 text-xs gap-2.5 items-center">
-					<div class=" flex items-center">
-						<span class="relative flex size-2">
-							<span
-								class="animate-ping absolute inline-flex h-full w-full rounded-full bg-green-400 opacity-75"
-							/>
-							<span class="relative inline-flex rounded-full size-2 bg-green-500" />
-						</span>
-					</div>
+				<Tooltip
+					content={$USAGE_POOL && $USAGE_POOL.length > 0
+						? `Running: ${$USAGE_POOL.join(',')} ✨`
+						: ''}
+				>
+					<div class="flex rounded-md py-1.5 px-3 text-xs gap-2.5 items-center">
+						<div class=" flex items-center">
+							<span class="relative flex size-2">
+								<span
+									class="animate-ping absolute inline-flex h-full w-full rounded-full bg-green-400 opacity-75"
+								/>
+								<span class="relative inline-flex rounded-full size-2 bg-green-500" />
+							</span>
+						</div>
 
 
-					<div class=" translate-y-[0.25px]">
-						<span class=" font-medium">
-							{$i18n.t('Active Users')}:
-						</span>
-						<span class=" font-semibold">
-							{$activeUserCount}
-						</span>
+						<div class=" ">
+							<span class=" font-medium">
+								{$i18n.t('Active Users')}:
+							</span>
+							<span class=" font-semibold">
+								{$activeUserCount}
+							</span>
+						</div>
 					</div>
 					</div>
-				</div>
+				</Tooltip>
 			{/if}
 			{/if}
 
 
 			<!-- <DropdownMenu.Item class="flex items-center px-3 py-2 text-sm  font-medium">
 			<!-- <DropdownMenu.Item class="flex items-center px-3 py-2 text-sm  font-medium">

+ 1 - 0
src/lib/stores/index.ts

@@ -16,6 +16,7 @@ export const mobile = writable(false);
 
 
 export const socket: Writable<null | Socket> = writable(null);
 export const socket: Writable<null | Socket> = writable(null);
 export const activeUserCount: Writable<null | number> = writable(null);
 export const activeUserCount: Writable<null | number> = writable(null);
+export const USAGE_POOL: Writable<null | string[]> = writable(null);
 
 
 export const theme = writable('system');
 export const theme = writable('system');
 export const chatId = writable('');
 export const chatId = writable('');

+ 15 - 1
src/routes/+layout.svelte

@@ -2,7 +2,16 @@
 	import { io } from 'socket.io-client';
 	import { io } from 'socket.io-client';
 
 
 	import { onMount, tick, setContext } from 'svelte';
 	import { onMount, tick, setContext } from 'svelte';
-	import { config, user, theme, WEBUI_NAME, mobile, socket, activeUserCount } from '$lib/stores';
+	import {
+		config,
+		user,
+		theme,
+		WEBUI_NAME,
+		mobile,
+		socket,
+		activeUserCount,
+		USAGE_POOL
+	} from '$lib/stores';
 	import { goto } from '$app/navigation';
 	import { goto } from '$app/navigation';
 	import { Toaster, toast } from 'svelte-sonner';
 	import { Toaster, toast } from 'svelte-sonner';
 
 
@@ -76,6 +85,11 @@
 					activeUserCount.set(data.count);
 					activeUserCount.set(data.count);
 				});
 				});
 
 
+				_socket.on('usage', (data) => {
+					console.log('usage', data);
+					USAGE_POOL.set(data['models']);
+				});
+
 				if (localStorage.token) {
 				if (localStorage.token) {
 					// Get Session User Info
 					// Get Session User Info
 					const sessionUser = await getSessionUser(localStorage.token).catch((error) => {
 					const sessionUser = await getSessionUser(localStorage.token).catch((error) => {