فهرست منبع

feat: user_location

Timothy J. Baek 10 ماه پیش
والد
کامیت
4b6b33b08b

+ 48 - 0
backend/apps/webui/internal/migrations/013_add_user_info.py

@@ -0,0 +1,48 @@
+"""Peewee migrations -- 002_add_local_sharing.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your migrations here."""
+
+    # Adding fields info to the 'user' table
+    migrator.add_fields("user", info=pw.TextField(null=True))
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    # Remove the settings field
+    migrator.remove_fields("user", "info")

+ 2 - 0
backend/apps/webui/models/users.py

@@ -26,6 +26,7 @@ class User(Model):
 
     api_key = CharField(null=True, unique=True)
     settings = JSONField(null=True)
+    info = JSONField(null=True)
 
     class Meta:
         database = DB
@@ -50,6 +51,7 @@ class UserModel(BaseModel):
 
     api_key: Optional[str] = None
     settings: Optional[UserSettings] = None
+    info: Optional[dict] = None
 
 
 ####################

+ 46 - 0
backend/apps/webui/routers/users.py

@@ -115,6 +115,52 @@ async def update_user_settings_by_session_user(
         )
 
 
+############################
+# GetUserInfoBySessionUser
+############################
+
+
+@router.get("/user/info", response_model=Optional[dict])
+async def get_user_info_by_session_user(user=Depends(get_verified_user)):
+    user = Users.get_user_by_id(user.id)
+    if user:
+        return user.info
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
+############################
+# UpdateUserInfoBySessionUser
+############################
+
+
+@router.post("/user/info/update", response_model=Optional[dict])
+async def update_user_settings_by_session_user(
+    form_data: dict, user=Depends(get_verified_user)
+):
+    user = Users.get_user_by_id(user.id)
+    if user:
+        if user.info is None:
+            user.info = {}
+
+        user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
+        if user:
+            return user.info
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.USER_NOT_FOUND,
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
 ############################
 # GetUserById
 ############################

+ 13 - 3
backend/main.py

@@ -764,7 +764,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
     template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
 
     content = title_generation_template(
-        template, form_data["prompt"], user.model_dump()
+        template,
+        form_data["prompt"],
+        {
+            "name": user.name,
+            "location": user.info.get("location") if user.info else None,
+        },
     )
 
     payload = {
@@ -830,7 +835,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
     template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
 
     content = search_query_generation_template(
-        template, form_data["prompt"], user.model_dump()
+        template, form_data["prompt"], {"name": user.name}
     )
 
     payload = {
@@ -893,7 +898,12 @@ Message: """{{prompt}}"""
 '''
 
     content = title_generation_template(
-        template, form_data["prompt"], user.model_dump()
+        template,
+        form_data["prompt"],
+        {
+            "name": user.name,
+            "location": user.info.get("location") if user.info else None,
+        },
     )
 
     payload = {

+ 6 - 6
backend/utils/task.py

@@ -6,7 +6,7 @@ from typing import Optional
 
 
 def prompt_template(
-    template: str, user_name: str = None, current_location: str = None
+    template: str, user_name: str = None, user_location: str = None
 ) -> str:
     # Get the current date
     current_date = datetime.now()
@@ -25,9 +25,9 @@ def prompt_template(
         # Replace {{USER_NAME}} in the template with the user's name
         template = template.replace("{{USER_NAME}}", user_name)
 
-    if current_location:
-        # Replace {{CURRENT_LOCATION}} in the template with the current location
-        template = template.replace("{{CURRENT_LOCATION}}", current_location)
+    if user_location:
+        # Replace {{USER_LOCATION}} in the template with the current location
+        template = template.replace("{{USER_LOCATION}}", user_location)
 
     return template
 
@@ -65,7 +65,7 @@ def title_generation_template(
     template = prompt_template(
         template,
         **(
-            {"user_name": user.get("name"), "current_location": user.get("location")}
+            {"user_name": user.get("name"), "user_location": user.get("location")}
             if user
             else {}
         ),
@@ -108,7 +108,7 @@ def search_query_generation_template(
     template = prompt_template(
         template,
         **(
-            {"user_name": user.get("name"), "current_location": user.get("location")}
+            {"user_name": user.get("name"), "user_location": user.get("location")}
             if user
             else {}
         ),

+ 70 - 0
src/lib/apis/users/index.ts

@@ -1,4 +1,5 @@
 import { WEBUI_API_BASE_URL } from '$lib/constants';
+import { getUserPosition } from '$lib/utils';
 
 export const getUserPermissions = async (token: string) => {
 	let error = null;
@@ -198,6 +199,75 @@ export const getUserById = async (token: string, userId: string) => {
 	return res;
 };
 
+export const getUserInfo = async (token: string) => {
+	let error = null;
+	const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/info`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateUserInfo = async (token: string, info: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/info/update`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...info
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getAndUpdateUserLocation = async (token: string) => {
+	const location = await getUserPosition().catch((err) => {
+		throw err;
+	});
+
+	if (location) {
+		await updateUserInfo(token, { location: location });
+		return location;
+	} else {
+		throw new Error('Failed to get user location');
+	}
+};
+
 export const deleteUserById = async (token: string, userId: string) => {
 	let error = null;
 

+ 16 - 3
src/lib/components/chat/Chat.svelte

@@ -31,6 +31,7 @@
 		convertMessagesToHistory,
 		copyToClipboard,
 		extractSentencesForAudio,
+		getUserPosition,
 		promptTemplate,
 		splitStream
 	} from '$lib/utils';
@@ -50,7 +51,7 @@
 	import { runWebSearch } from '$lib/apis/rag';
 	import { createOpenAITextStream } from '$lib/apis/streaming';
 	import { queryMemory } from '$lib/apis/memories';
-	import { getUserSettings } from '$lib/apis/users';
+	import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users';
 	import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis';
 
 	import Banner from '../common/Banner.svelte';
@@ -533,7 +534,13 @@
 			$settings.system || (responseMessage?.userContext ?? null)
 				? {
 						role: 'system',
-						content: `${promptTemplate($settings?.system ?? '', $user.name)}${
+						content: `${promptTemplate(
+							$settings?.system ?? '',
+							$user.name,
+							$settings?.userLocation
+								? await getAndUpdateUserLocation(localStorage.token)
+								: undefined
+						)}${
 							responseMessage?.userContext ?? null
 								? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
 								: ''
@@ -871,7 +878,13 @@
 						$settings.system || (responseMessage?.userContext ?? null)
 							? {
 									role: 'system',
-									content: `${promptTemplate($settings?.system ?? '', $user.name)}${
+									content: `${promptTemplate(
+										$settings?.system ?? '',
+										$user.name,
+										$settings?.userLocation
+											? await getAndUpdateUserLocation(localStorage.token)
+											: undefined
+									)}${
 										responseMessage?.userContext ?? null
 											? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
 											: ''

+ 47 - 3
src/lib/components/chat/Settings/Interface.svelte

@@ -5,6 +5,8 @@
 	import { createEventDispatcher, onMount, getContext } from 'svelte';
 	import { toast } from 'svelte-sonner';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
+	import { updateUserInfo } from '$lib/apis/users';
+	import { getUserPosition } from '$lib/utils';
 	const dispatch = createEventDispatcher();
 
 	const i18n = getContext('i18n');
@@ -16,6 +18,7 @@
 	let responseAutoCopy = false;
 	let widescreenMode = false;
 	let splitLargeChunks = false;
+	let userLocation = false;
 
 	// Interface
 	let defaultModelId = '';
@@ -51,6 +54,26 @@
 		saveSettings({ showEmojiInCall: showEmojiInCall });
 	};
 
+	const toggleUserLocation = async () => {
+		userLocation = !userLocation;
+
+		if (userLocation) {
+			const position = await getUserPosition().catch((error) => {
+				toast.error(error.message);
+				return null;
+			});
+
+			if (position) {
+				await updateUserInfo(localStorage.token, { location: position });
+				toast.success('User location successfully retrieved.');
+			} else {
+				userLocation = false;
+			}
+		}
+
+		saveSettings({ userLocation });
+	};
+
 	const toggleTitleAutoGenerate = async () => {
 		titleAutoGenerate = !titleAutoGenerate;
 		saveSettings({
@@ -106,6 +129,7 @@
 		widescreenMode = $settings.widescreenMode ?? false;
 		splitLargeChunks = $settings.splitLargeChunks ?? false;
 		chatDirection = $settings.chatDirection ?? 'LTR';
+		userLocation = $settings.userLocation ?? false;
 
 		defaultModelId = ($settings?.models ?? ['']).at(0);
 	});
@@ -142,6 +166,26 @@
 				</div>
 			</div>
 
+			<div>
+				<div class=" py-0.5 flex w-full justify-between">
+					<div class=" self-center text-xs font-medium">{$i18n.t('Widescreen Mode')}</div>
+
+					<button
+						class="p-1 px-3 text-xs flex rounded transition"
+						on:click={() => {
+							togglewidescreenMode();
+						}}
+						type="button"
+					>
+						{#if widescreenMode === true}
+							<span class="ml-2 self-center">{$i18n.t('On')}</span>
+						{:else}
+							<span class="ml-2 self-center">{$i18n.t('Off')}</span>
+						{/if}
+					</button>
+				</div>
+			</div>
+
 			<div>
 				<div class=" py-0.5 flex w-full justify-between">
 					<div class=" self-center text-xs font-medium">{$i18n.t('Title Auto-Generation')}</div>
@@ -186,16 +230,16 @@
 
 			<div>
 				<div class=" py-0.5 flex w-full justify-between">
-					<div class=" self-center text-xs font-medium">{$i18n.t('Widescreen Mode')}</div>
+					<div class=" self-center text-xs font-medium">{$i18n.t('Allow User Location')}</div>
 
 					<button
 						class="p-1 px-3 text-xs flex rounded transition"
 						on:click={() => {
-							togglewidescreenMode();
+							toggleUserLocation();
 						}}
 						type="button"
 					>
-						{#if widescreenMode === true}
+						{#if userLocation === true}
 							<span class="ml-2 self-center">{$i18n.t('On')}</span>
 						{:else}
 							<span class="ml-2 self-center">{$i18n.t('Off')}</span>

+ 27 - 4
src/lib/utils/index.ts

@@ -302,6 +302,29 @@ export const getImportOrigin = (_chats) => {
 	return 'webui';
 };
 
+export const getUserPosition = async (raw = false) => {
+	// Get the user's location using the Geolocation API
+	const position = await new Promise((resolve, reject) => {
+		navigator.geolocation.getCurrentPosition(resolve, reject);
+	}).catch((error) => {
+		console.error('Error getting user location:', error);
+		throw error;
+	});
+
+	if (!position) {
+		return 'Location not available';
+	}
+
+	// Extract the latitude and longitude from the position
+	const { latitude, longitude } = position.coords;
+
+	if (raw) {
+		return { latitude, longitude };
+	} else {
+		return `${latitude.toFixed(3)}, ${longitude.toFixed(3)} (lat, long)`;
+	}
+};
+
 const convertOpenAIMessages = (convo) => {
 	// Parse OpenAI chat messages and create chat dictionary for creating new chats
 	const mapping = convo['mapping'];
@@ -474,7 +497,7 @@ export const blobToFile = (blob, fileName) => {
 export const promptTemplate = (
 	template: string,
 	user_name?: string,
-	current_location?: string
+	user_location?: string
 ): string => {
 	// Get the current date
 	const currentDate = new Date();
@@ -509,9 +532,9 @@ export const promptTemplate = (
 		template = template.replace('{{USER_NAME}}', user_name);
 	}
 
-	if (current_location) {
-		// Replace {{CURRENT_LOCATION}} in the template with the current location
-		template = template.replace('{{CURRENT_LOCATION}}', current_location);
+	if (user_location) {
+		// Replace {{USER_LOCATION}} in the template with the current location
+		template = template.replace('{{USER_LOCATION}}', user_location);
 	}
 
 	return template;