Browse Source

feat: global filter

Timothy J. Baek 10 months ago
parent
commit
edbd07f893

+ 49 - 0
backend/apps/webui/internal/migrations/018_add_function_is_global.py

@@ -0,0 +1,49 @@
+"""Peewee migrations -- 017_add_user_oauth_sub.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your migrations here."""
+
+    migrator.add_fields(
+        "function",
+        is_global=pw.BooleanField(default=False),
+    )
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    migrator.remove_fields("function", "is_global")

+ 13 - 0
backend/apps/webui/models/functions.py

@@ -30,6 +30,7 @@ class Function(Model):
     meta = JSONField()
     valves = JSONField()
     is_active = BooleanField(default=False)
+    is_global = BooleanField(default=False)
     updated_at = BigIntegerField()
     created_at = BigIntegerField()
 
@@ -50,6 +51,7 @@ class FunctionModel(BaseModel):
     content: str
     meta: FunctionMeta
     is_active: bool = False
+    is_global: bool = False
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
@@ -66,6 +68,7 @@ class FunctionResponse(BaseModel):
     name: str
     meta: FunctionMeta
     is_active: bool
+    is_global: bool
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
@@ -144,6 +147,16 @@ class FunctionsTable:
                 for function in Function.select().where(Function.type == type)
             ]
 
+    def get_global_filter_functions(self) -> List[FunctionModel]:
+        return [
+            FunctionModel(**model_to_dict(function))
+            for function in Function.select().where(
+                Function.type == "filter",
+                Function.is_active == True,
+                Function.is_global == True,
+            )
+        ]
+
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
         try:
             function = Function.get(Function.id == id)

+ 27 - 0
backend/apps/webui/routers/functions.py

@@ -147,6 +147,33 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
         )
 
 
+############################
+# ToggleGlobalById
+############################
+
+
+@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
+async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
+    function = Functions.get_function_by_id(id)
+    if function:
+        function = Functions.update_function_by_id(
+            id, {"is_global": not function.is_global}
+        )
+
+        if function:
+            return function
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
 ############################
 # UpdateFunctionById
 ############################

+ 24 - 25
backend/main.py

@@ -416,21 +416,23 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     )
                 return 0
 
-            filter_ids = []
+            filter_ids = [
+                function.id for function in Functions.get_global_filter_functions()
+            ]
             if "info" in model and "meta" in model["info"]:
-                enabled_filter_ids = [
-                    function.id
-                    for function in Functions.get_functions_by_type(
-                        "filter", active_only=True
-                    )
-                ]
-                filter_ids = [
-                    filter_id
-                    for filter_id in enabled_filter_ids
-                    if filter_id in model["info"]["meta"].get("filterIds", [])
-                ]
+                filter_ids.extend(model["info"]["meta"].get("filterIds", []))
                 filter_ids = list(set(filter_ids))
 
+            enabled_filter_ids = [
+                function.id
+                for function in Functions.get_functions_by_type(
+                    "filter", active_only=True
+                )
+            ]
+            filter_ids = [
+                filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+            ]
+
             filter_ids.sort(key=get_priority)
             for filter_id in filter_ids:
                 filter = Functions.get_function_by_id(filter_id)
@@ -919,7 +921,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         )
 
     model = app.state.MODELS[model_id]
-    print(model)
 
     pipe = model.get("pipe")
     if pipe:
@@ -1010,21 +1011,19 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0
 
-    filter_ids = []
+    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
     if "info" in model and "meta" in model["info"]:
-        enabled_filter_ids = [
-            function.id
-            for function in Functions.get_functions_by_type(
-                "filter", active_only=True
-            )
-        ]
-        filter_ids = [
-            filter_id
-            for filter_id in enabled_filter_ids
-            if filter_id in model["info"]["meta"].get("filterIds", [])
-        ]
+        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
         filter_ids = list(set(filter_ids))
 
+    enabled_filter_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("filter", active_only=True)
+    ]
+    filter_ids = [
+        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+    ]
+
     # Sort filter_ids by priority, using the get_priority function
     filter_ids.sort(key=get_priority)
 

+ 32 - 0
src/lib/apis/functions/index.ts

@@ -224,6 +224,38 @@ export const toggleFunctionById = async (token: string, id: string) => {
 	return res;
 };
 
+export const toggleGlobalById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/toggle/global`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getFunctionValvesById = async (token: string, id: string) => {
 	let error = null;
 

+ 19 - 0
src/lib/components/icons/GlobeAlt.svelte

@@ -0,0 +1,19 @@
+<script lang="ts">
+	export let className = 'w-4 h-4';
+	export let strokeWidth = '1.5';
+</script>
+
+<svg
+	xmlns="http://www.w3.org/2000/svg"
+	fill="none"
+	viewBox="0 0 24 24"
+	stroke-width={strokeWidth}
+	stroke="currentColor"
+	class={className}
+>
+	<path
+		stroke-linecap="round"
+		stroke-linejoin="round"
+		d="M12 21a9.004 9.004 0 0 0 8.716-6.747M12 21a9.004 9.004 0 0 1-8.716-6.747M12 21c2.485 0 4.5-4.03 4.5-9S14.485 3 12 3m0 18c-2.485 0-4.5-4.03-4.5-9S9.515 3 12 3m0 0a8.997 8.997 0 0 1 7.843 4.582M12 3a8.997 8.997 0 0 0-7.843 4.582m15.686 0A11.953 11.953 0 0 1 12 10.5c-2.998 0-5.74-1.1-7.843-2.918m15.686 0A8.959 8.959 0 0 1 21 12c0 .778-.099 1.533-.284 2.253m0 0A17.919 17.919 0 0 1 12 16.5c-3.162 0-6.133-.815-8.716-2.247m0 0A9.015 9.015 0 0 1 3 12c0-1.605.42-3.113 1.157-4.418"
+	/>
+</svg>

+ 24 - 1
src/lib/components/workspace/Functions.svelte

@@ -14,7 +14,8 @@
 		exportFunctions,
 		getFunctionById,
 		getFunctions,
-		toggleFunctionById
+		toggleFunctionById,
+		toggleGlobalById
 	} from '$lib/apis/functions';
 
 	import ArrowDownTray from '../icons/ArrowDownTray.svelte';
@@ -113,6 +114,22 @@
 			models.set(await getModels(localStorage.token));
 		}
 	};
+
+	const toggleGlobalHandler = async (func) => {
+		const res = await toggleGlobalById(localStorage.token, func.id).catch((error) => {
+			toast.error(error);
+		});
+
+		if (res) {
+			if (func.is_global) {
+				toast.success($i18n.t('Filter is now globally enabled'));
+			} else {
+				toast.success($i18n.t('Filter is now globally disabled'));
+			}
+
+			functions.set(await getFunctions(localStorage.token));
+		}
+	};
 </script>
 
 <svelte:head>
@@ -259,6 +276,7 @@
 				</Tooltip>
 
 				<FunctionMenu
+					{func}
 					editHandler={() => {
 						goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`);
 					}}
@@ -275,6 +293,11 @@
 						selectedFunction = func;
 						showDeleteConfirm = true;
 					}}
+					toggleGlobalHandler={() => {
+						if (func.type === 'filter') {
+							toggleGlobalHandler(func);
+						}
+					}}
 					onClose={() => {}}
 				>
 					<button

+ 24 - 3
src/lib/components/workspace/Functions/FunctionMenu.svelte

@@ -5,21 +5,24 @@
 
 	import Dropdown from '$lib/components/common/Dropdown.svelte';
 	import GarbageBin from '$lib/components/icons/GarbageBin.svelte';
-	import Pencil from '$lib/components/icons/Pencil.svelte';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
-	import Tags from '$lib/components/chat/Tags.svelte';
 	import Share from '$lib/components/icons/Share.svelte';
-	import ArchiveBox from '$lib/components/icons/ArchiveBox.svelte';
 	import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte';
 	import ArrowDownTray from '$lib/components/icons/ArrowDownTray.svelte';
+	import Switch from '$lib/components/common/Switch.svelte';
+	import GlobeAlt from '$lib/components/icons/GlobeAlt.svelte';
 
 	const i18n = getContext('i18n');
 
+	export let func;
+
 	export let editHandler: Function;
 	export let shareHandler: Function;
 	export let cloneHandler: Function;
 	export let exportHandler: Function;
 	export let deleteHandler: Function;
+	export let toggleGlobalHandler: Function;
+
 	export let onClose: Function;
 
 	let show = false;
@@ -45,6 +48,24 @@
 			align="start"
 			transition={flyAndScale}
 		>
+			{#if func.type === 'filter'}
+				<div
+					class="flex gap-2 justify-between items-center px-3 py-2 text-sm font-medium cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
+				>
+					<div class="flex gap-2 items-center">
+						<GlobeAlt />
+
+						<div class="flex items-center">{$i18n.t('Global')}</div>
+					</div>
+
+					<div>
+						<Switch on:change={toggleGlobalHandler} bind:state={func.is_global} />
+					</div>
+				</div>
+
+				<hr class="border-gray-100 dark:border-gray-800 my-1" />
+			{/if}
+
 			<DropdownMenu.Item
 				class="flex gap-2 items-center px-3 py-2 text-sm  font-medium cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800  rounded-md"
 				on:click={() => {