浏览代码

feat: chat auto tag

Timothy J. Baek 6 月之前
父节点
当前提交
d795940ced

+ 1 - 0
backend/open_webui/constants.py

@@ -108,6 +108,7 @@ class TASKS(str, Enum):
 
     DEFAULT = lambda task="": f"{task if task else 'generation'}"
     TITLE_GENERATION = "title_generation"
+    TAGS_GENERATION = "tags_generation"
     EMOJI_GENERATION = "emoji_generation"
     QUERY_GENERATION = "query_generation"
     FUNCTION_CALLING = "function_calling"

+ 67 - 0
backend/open_webui/main.py

@@ -134,6 +134,7 @@ from open_webui.utils.misc import (
 )
 from open_webui.utils.task import (
     moa_response_generation_template,
+    tags_generation_template,
     search_query_generation_template,
     title_generation_template,
     tools_function_calling_generation_template,
@@ -1545,6 +1546,72 @@ Prompt: {{prompt:middletruncate:8000}}"""
     return await generate_chat_completions(form_data=payload, user=user)
 
 
+@app.post("/api/task/tags/completions")
+async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
+    print("generate_chat_tags")
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    task_model_id = get_task_model_id(model_id)
+    print(task_model_id)
+
+    template = """### Task:
+Generate 1-3 broad tags categorizing the main themes of the chat history.
+
+### Guidelines:
+- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
+- Only add more specific subdomains if they are strongly represented throughout the conversation
+- If content is too short (less than 3 messages) or too diverse, use only ["General"]
+- Use the chat's primary language; default to English if multilingual
+- Prioritize accuracy over specificity
+
+### Output:
+JSON format: { "tags": ["tag1", "tag2", "tag3"] }
+
+### Chat History:
+<chat_history>
+{{MESSAGES:END:6}}
+</chat_history>"""
+
+    content = tags_generation_template(
+        template, form_data["messages"], {"name": user.name}
+    )
+
+    print("content", content)
+    payload = {
+        "model": task_model_id,
+        "messages": [{"role": "user", "content": content}],
+        "stream": False,
+        "metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data},
+    }
+    log.debug(payload)
+
+    # Handle pipeline filters
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        if len(e.args) > 1:
+            return JSONResponse(
+                status_code=e.args[0],
+                content={"detail": e.args[1]},
+            )
+        else:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
+    return await generate_chat_completions(form_data=payload, user=user)
+
+
 @app.post("/api/task/query/completions")
 async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
     print("generate_search_query")

+ 18 - 0
backend/open_webui/utils/task.py

@@ -123,6 +123,24 @@ def replace_messages_variable(template: str, messages: list[str]) -> str:
     return template
 
 
+def tags_generation_template(
+    template: str, messages: list[dict], user: Optional[dict] = None
+) -> str:
+    prompt = get_last_user_message(messages)
+    template = replace_prompt_variable(template, prompt)
+    template = replace_messages_variable(template, messages)
+
+    template = prompt_template(
+        template,
+        **(
+            {"user_name": user.get("name"), "user_location": user.get("location")}
+            if user
+            else {}
+        ),
+    )
+    return template
+
+
 def search_query_generation_template(
     template: str, messages: list[dict], user: Optional[dict] = None
 ) -> str:

+ 72 - 0
src/lib/apis/index.ts

@@ -245,6 +245,78 @@ export const generateTitle = async (
 	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
 };
 
+export const generateTags = async (
+	token: string = '',
+	model: string,
+	messages: string,
+	chat_id?: string
+) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/task/tags/completions`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			messages: messages,
+			...(chat_id && { chat_id: chat_id })
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	try {
+		// Step 1: Safely extract the response string
+		const response = res?.choices[0]?.message?.content ?? '';
+
+		// Step 2: Attempt to fix common JSON format issues like single quotes
+		const sanitizedResponse = response.replace(/['‘’`]/g, '"'); // Convert single quotes to double quotes for valid JSON
+
+		// Step 3: Find the relevant JSON block within the response
+		const jsonStartIndex = sanitizedResponse.indexOf('{');
+		const jsonEndIndex = sanitizedResponse.lastIndexOf('}');
+
+		// Step 4: Check if we found a valid JSON block (with both `{` and `}`)
+		if (jsonStartIndex !== -1 && jsonEndIndex !== -1) {
+			const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1);
+
+			// Step 5: Parse the JSON block
+			const parsed = JSON.parse(jsonResponse);
+
+			// Step 6: If there's a "tags" key, return the tags array; otherwise, return an empty array
+			if (parsed && parsed.tags) {
+				return Array.isArray(parsed.tags) ? parsed.tags : [];
+			} else {
+				return [];
+			}
+		}
+
+		// If no valid JSON block found, return an empty array
+		return [];
+	} catch (e) {
+		// Catch and safely return empty array on any parsing errors
+		console.error('Failed to parse response: ', e);
+		return [];
+	}
+};
+
 export const generateEmoji = async (
 	token: string = '',
 	model: string,

+ 45 - 9
src/lib/components/chat/Chat.svelte

@@ -10,7 +10,7 @@
 	import { goto } from '$app/navigation';
 	import { page } from '$app/stores';
 
-	import type { Unsubscriber, Writable } from 'svelte/store';
+	import { get, type Unsubscriber, type Writable } from 'svelte/store';
 	import type { i18n as i18nType } from 'i18next';
 	import { WEBUI_BASE_URL } from '$lib/constants';
 
@@ -20,6 +20,7 @@
 		config,
 		type Model,
 		models,
+		tags as allTags,
 		settings,
 		showSidebar,
 		WEBUI_NAME,
@@ -46,7 +47,9 @@
 
 	import { generateChatCompletion } from '$lib/apis/ollama';
 	import {
+		addTagById,
 		createNewChat,
+		getAllTags,
 		getChatById,
 		getChatList,
 		getTagsById,
@@ -62,7 +65,8 @@
 		generateTitle,
 		generateSearchQuery,
 		chatAction,
-		generateMoACompletion
+		generateMoACompletion,
+		generateTags
 	} from '$lib/apis';
 
 	import Banner from '../common/Banner.svelte';
@@ -537,7 +541,10 @@
 		});
 
 		if (chat) {
-			tags = await getTags();
+			tags = await getTagsById(localStorage.token, $chatId).catch(async (error) => {
+				return [];
+			});
+
 			const chatContent = chat.chat;
 
 			if (chatContent) {
@@ -1393,6 +1400,10 @@
 			window.history.replaceState(history.state, '', `/c/${_chatId}`);
 			const title = await generateChatTitle(userPrompt);
 			await setChatTitle(_chatId, title);
+
+			if ($settings?.autoTags ?? true) {
+				await setChatTags(messages);
+			}
 		}
 
 		return _response;
@@ -1707,6 +1718,10 @@
 			window.history.replaceState(history.state, '', `/c/${_chatId}`);
 			const title = await generateChatTitle(userPrompt);
 			await setChatTitle(_chatId, title);
+
+			if ($settings?.autoTags ?? true) {
+				await setChatTags(messages);
+			}
 		}
 
 		return _response;
@@ -1893,6 +1908,33 @@
 		}
 	};
 
+	const setChatTags = async (messages) => {
+		if (!$temporaryChatEnabled) {
+			let generatedTags = await generateTags(
+				localStorage.token,
+				selectedModels[0],
+				messages,
+				$chatId
+			).catch((error) => {
+				console.error(error);
+				return [];
+			});
+
+			const currentTags = await getTagsById(localStorage.token, $chatId);
+			generatedTags = generatedTags.filter(
+				(tag) => !currentTags.find((t) => t.id === tag.replaceAll(' ', '_').toLowerCase())
+			);
+			console.log(generatedTags);
+
+			for (const tag of generatedTags) {
+				await addTagById(localStorage.token, $chatId, tag);
+			}
+
+			chat = await getChatById(localStorage.token, $chatId);
+			allTags.set(await getAllTags(localStorage.token));
+		}
+	};
+
 	const getWebSearchResults = async (
 		model: string,
 		parentId: string,
@@ -1978,12 +2020,6 @@
 		}
 	};
 
-	const getTags = async () => {
-		return await getTagsById(localStorage.token, $chatId).catch(async (error) => {
-			return [];
-		});
-	};
-
 	const initChatHandler = async () => {
 		if (!$temporaryChatEnabled) {
 			chat = await createNewChat(localStorage.token, {

+ 28 - 0
src/lib/components/chat/Settings/Interface.svelte

@@ -19,6 +19,8 @@
 
 	// Addons
 	let titleAutoGenerate = true;
+	let autoTags = true;
+
 	let responseAutoCopy = false;
 	let widescreenMode = false;
 	let splitLargeChunks = false;
@@ -112,6 +114,11 @@
 		});
 	};
 
+	const toggleAutoTags = async () => {
+		autoTags = !autoTags;
+		saveSettings({ autoTags });
+	};
+
 	const toggleResponseAutoCopy = async () => {
 		const permission = await navigator.clipboard
 			.readText()
@@ -149,6 +156,7 @@
 
 	onMount(async () => {
 		titleAutoGenerate = $settings?.title?.auto ?? true;
+		autoTags = $settings.autoTags ?? true;
 
 		responseAutoCopy = $settings.responseAutoCopy ?? false;
 		showUsername = $settings.showUsername ?? false;
@@ -431,6 +439,26 @@
 				</div>
 			</div>
 
+			<div>
+				<div class=" py-0.5 flex w-full justify-between">
+					<div class=" self-center text-xs">{$i18n.t('Chat Tags Auto-Generation')}</div>
+
+					<button
+						class="p-1 px-3 text-xs flex rounded transition"
+						on:click={() => {
+							toggleAutoTags();
+						}}
+						type="button"
+					>
+						{#if autoTags === true}
+							<span class="ml-2 self-center">{$i18n.t('On')}</span>
+						{:else}
+							<span class="ml-2 self-center">{$i18n.t('Off')}</span>
+						{/if}
+					</button>
+				</div>
+			</div>
+
 			<div>
 				<div class=" py-0.5 flex w-full justify-between">
 					<div class=" self-center text-xs">

+ 2 - 2
src/lib/components/layout/Sidebar/SearchInput.svelte

@@ -144,7 +144,7 @@
 				{#if filteredTags.length > 0}
 					<div class="px-1 font-medium dark:text-gray-300 text-gray-700 mb-1">Tags</div>
 
-					<div class="">
+					<div class="max-h-60 overflow-auto">
 						{#each filteredTags as tag, tagIdx}
 							<button
 								class=" px-1.5 py-0.5 flex gap-1 hover:bg-gray-100 dark:hover:bg-gray-900 w-full rounded {selectedIdx ===
@@ -174,7 +174,7 @@
 				{:else if filteredOptions.length > 0}
 					<div class="px-1 font-medium dark:text-gray-300 text-gray-700 mb-1">Search options</div>
 
-					<div class="">
+					<div class=" max-h-60 overflow-auto">
 						{#each filteredOptions as option, optionIdx}
 							<button
 								class=" px-1.5 py-0.5 flex gap-1 hover:bg-gray-100 dark:hover:bg-gray-900 w-full rounded {selectedIdx ===