Browse Source

feat: save user settings to db

Timothy J. Baek 11 months ago
parent
commit
ccbafca74c

+ 3 - 3
backend/apps/openai/main.py

@@ -198,7 +198,7 @@ async def fetch_url(url, key):
 
 
 def merge_models_lists(model_lists):
-    log.info(f"merge_models_lists {model_lists}")
+    log.debug(f"merge_models_lists {model_lists}")
     merged_list = []
 
     for idx, models in enumerate(model_lists):
@@ -237,7 +237,7 @@ async def get_all_models():
         ]
 
         responses = await asyncio.gather(*tasks)
-        log.info(f"get_all_models:responses() {responses}")
+        log.debug(f"get_all_models:responses() {responses}")
 
         models = {
             "data": merge_models_lists(
@@ -254,7 +254,7 @@ async def get_all_models():
             )
         }
 
-        log.info(f"models: {models}")
+        log.debug(f"models: {models}")
         app.state.MODELS = {model["id"]: model for model in models["data"]}
 
     return models

+ 48 - 0
backend/apps/webui/internal/migrations/011_add_user_settings.py

@@ -0,0 +1,48 @@
+"""Peewee migrations -- 002_add_local_sharing.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your migrations here."""
+
+    # Adding fields settings to the 'user' table
+    migrator.add_fields("user", settings=pw.TextField(null=True))
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    # Remove the settings field
+    migrator.remove_fields("user", "settings")

+ 10 - 2
backend/apps/webui/models/users.py

@@ -1,11 +1,11 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from peewee import *
 from playhouse.shortcuts import model_to_dict
 from typing import List, Union, Optional
 import time
 from utils.misc import get_gravatar_url
 
-from apps.webui.internal.db import DB
+from apps.webui.internal.db import DB, JSONField
 from apps.webui.models.chats import Chats
 
 ####################
@@ -25,11 +25,18 @@ class User(Model):
     created_at = BigIntegerField()
 
     api_key = CharField(null=True, unique=True)
+    settings = JSONField(null=True)
 
     class Meta:
         database = DB
 
 
+class UserSettings(BaseModel):
+    ui: Optional[dict] = {}
+    model_config = ConfigDict(extra="allow")
+    pass
+
+
 class UserModel(BaseModel):
     id: str
     name: str
@@ -42,6 +49,7 @@ class UserModel(BaseModel):
     created_at: int  # timestamp in epoch
 
     api_key: Optional[str] = None
+    settings: Optional[UserSettings] = None
 
 
 ####################

+ 45 - 1
backend/apps/webui/routers/users.py

@@ -9,7 +9,13 @@ import time
 import uuid
 import logging
 
-from apps.webui.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
+from apps.webui.models.users import (
+    UserModel,
+    UserUpdateForm,
+    UserRoleUpdateForm,
+    UserSettings,
+    Users,
+)
 from apps.webui.models.auths import Auths
 from apps.webui.models.chats import Chats
 
@@ -68,6 +74,42 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
     )
 
 
+############################
+# GetUserSettingsBySessionUser
+############################
+
+
+@router.get("/user/settings", response_model=Optional[UserSettings])
+async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
+    user = Users.get_user_by_id(user.id)
+    if user:
+        return user.settings
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
+############################
+# UpdateUserSettingsBySessionUser
+############################
+
+
+@router.post("/user/settings/update", response_model=UserSettings)
+async def update_user_settings_by_session_user(
+    form_data: UserSettings, user=Depends(get_verified_user)
+):
+    user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
+    if user:
+        return user.settings
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
 ############################
 # GetUserById
 ############################
@@ -81,6 +123,8 @@ class UserResponse(BaseModel):
 @router.get("/{user_id}", response_model=UserResponse)
 async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
 
+    # Check if user_id is a shared chat
+    # If it is, get the user_id from the chat
     if user_id.startswith("shared-"):
         chat_id = user_id.replace("shared-", "")
         chat = Chats.get_chat_by_id(chat_id)

+ 56 - 0
src/lib/apis/users/index.ts

@@ -115,6 +115,62 @@ export const getUsers = async (token: string) => {
 	return res ? res : [];
 };
 
+export const getUserSettings = async (token: string) => {
+	let error = null;
+	const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/settings`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateUserSettings = async (token: string, settings: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/settings/update`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...settings
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getUserById = async (token: string, userId: string) => {
 	let error = null;
 

+ 19 - 8
src/lib/components/chat/Chat.svelte

@@ -42,6 +42,7 @@
 	import type { Writable } from 'svelte/store';
 	import type { i18n as i18nType } from 'i18next';
 	import Banner from '../common/Banner.svelte';
+	import { getUserSettings } from '$lib/apis/users';
 
 	const i18n: Writable<i18nType> = getContext('i18n');
 
@@ -154,10 +155,13 @@
 			$models.map((m) => m.id).includes(modelId) ? modelId : ''
 		);
 
-		let _settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
-		settings.set({
-			..._settings
-		});
+		const userSettings = await getUserSettings(localStorage.token);
+
+		if (userSettings) {
+			settings.set(userSettings.ui);
+		} else {
+			settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
+		}
 
 		const chatInput = document.getElementById('chat-textarea');
 		setTimeout(() => chatInput?.focus(), 0);
@@ -187,11 +191,18 @@
 						: convertMessagesToHistory(chatContent.messages);
 				title = chatContent.title;
 
-				let _settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
+				const userSettings = await getUserSettings(localStorage.token);
+
+				if (userSettings) {
+					await settings.set(userSettings.ui);
+				} else {
+					await settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
+				}
+
 				await settings.set({
-					..._settings,
-					system: chatContent.system ?? _settings.system,
-					params: chatContent.options ?? _settings.params
+					...$settings,
+					system: chatContent.system ?? $settings.system,
+					params: chatContent.options ?? $settings.params
 				});
 				autoScroll = true;
 				await tick();

+ 5 - 3
src/lib/components/chat/ModelSelector.svelte

@@ -1,13 +1,13 @@
 <script lang="ts">
-	import { Collapsible } from 'bits-ui';
-
-	import { setDefaultModels } from '$lib/apis/configs';
 	import { models, showSettings, settings, user, mobile } from '$lib/stores';
 	import { onMount, tick, getContext } from 'svelte';
 	import { toast } from 'svelte-sonner';
 	import Selector from './ModelSelector/Selector.svelte';
 	import Tooltip from '../common/Tooltip.svelte';
 
+	import { setDefaultModels } from '$lib/apis/configs';
+	import { updateUserSettings } from '$lib/apis/users';
+
 	const i18n = getContext('i18n');
 
 	export let selectedModels = [''];
@@ -22,7 +22,9 @@
 			return;
 		}
 		settings.set({ ...$settings, models: selectedModels });
+
 		localStorage.setItem('settings', JSON.stringify($settings));
+		await updateUserSettings(localStorage.token, { ui: $settings });
 
 		if ($user.role === 'admin') {
 			console.log('setting default models globally');

+ 9 - 11
src/lib/components/chat/Settings/Audio.svelte

@@ -1,6 +1,6 @@
 <script lang="ts">
 	import { getAudioConfig, updateAudioConfig } from '$lib/apis/audio';
-	import { user } from '$lib/stores';
+	import { user, settings } from '$lib/stores';
 	import { createEventDispatcher, onMount, getContext } from 'svelte';
 	import { toast } from 'svelte-sonner';
 	const dispatch = createEventDispatcher();
@@ -99,16 +99,14 @@
 	};
 
 	onMount(async () => {
-		let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
-
-		conversationMode = settings.conversationMode ?? false;
-		speechAutoSend = settings.speechAutoSend ?? false;
-		responseAutoPlayback = settings.responseAutoPlayback ?? false;
-
-		STTEngine = settings?.audio?.STTEngine ?? '';
-		TTSEngine = settings?.audio?.TTSEngine ?? '';
-		speaker = settings?.audio?.speaker ?? '';
-		model = settings?.audio?.model ?? '';
+		conversationMode = $settings.conversationMode ?? false;
+		speechAutoSend = $settings.speechAutoSend ?? false;
+		responseAutoPlayback = $settings.responseAutoPlayback ?? false;
+
+		STTEngine = $settings?.audio?.STTEngine ?? '';
+		TTSEngine = $settings?.audio?.TTSEngine ?? '';
+		speaker = $settings?.audio?.speaker ?? '';
+		model = $settings?.audio?.model ?? '';
 
 		if (TTSEngine === 'openai') {
 			getOpenAIVoices();

+ 2 - 4
src/lib/components/chat/Settings/Chats.svelte

@@ -2,7 +2,7 @@
 	import fileSaver from 'file-saver';
 	const { saveAs } = fileSaver;
 
-	import { chats, user, config } from '$lib/stores';
+	import { chats, user, settings } from '$lib/stores';
 
 	import {
 		archiveAllChats,
@@ -99,9 +99,7 @@
 	};
 
 	onMount(async () => {
-		let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
-
-		saveChatHistory = settings.saveChatHistory ?? true;
+		saveChatHistory = $settings.saveChatHistory ?? true;
 	});
 </script>
 

+ 13 - 14
src/lib/components/chat/Settings/General.svelte

@@ -4,7 +4,7 @@
 	import { getLanguages } from '$lib/i18n';
 	const dispatch = createEventDispatcher();
 
-	import { models, user, theme } from '$lib/stores';
+	import { models, settings, theme } from '$lib/stores';
 
 	const i18n = getContext('i18n');
 
@@ -71,23 +71,22 @@
 	onMount(async () => {
 		selectedTheme = localStorage.theme ?? 'system';
 
-		let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
 		languages = await getLanguages();
 
-		notificationEnabled = settings.notificationEnabled ?? false;
-		system = settings.system ?? '';
+		notificationEnabled = $settings.notificationEnabled ?? false;
+		system = $settings.system ?? '';
 
-		requestFormat = settings.requestFormat ?? '';
-		keepAlive = settings.keepAlive ?? null;
+		requestFormat = $settings.requestFormat ?? '';
+		keepAlive = $settings.keepAlive ?? null;
 
-		params.seed = settings.seed ?? 0;
-		params.temperature = settings.temperature ?? '';
-		params.frequency_penalty = settings.frequency_penalty ?? '';
-		params.top_k = settings.top_k ?? '';
-		params.top_p = settings.top_p ?? '';
-		params.num_ctx = settings.num_ctx ?? '';
-		params = { ...params, ...settings.params };
-		params.stop = settings?.params?.stop ? (settings?.params?.stop ?? []).join(',') : null;
+		params.seed = $settings.seed ?? 0;
+		params.temperature = $settings.temperature ?? '';
+		params.frequency_penalty = $settings.frequency_penalty ?? '';
+		params.top_k = $settings.top_k ?? '';
+		params.top_p = $settings.top_p ?? '';
+		params.num_ctx = $settings.num_ctx ?? '';
+		params = { ...params, ...$settings.params };
+		params.stop = $settings?.params?.stop ? ($settings?.params?.stop ?? []).join(',') : null;
 	});
 
 	const applyTheme = (_theme: string) => {

+ 11 - 16
src/lib/components/chat/Settings/Interface.svelte

@@ -104,23 +104,18 @@
 			promptSuggestions = $config?.default_prompt_suggestions;
 		}
 
-		let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
-
-		titleAutoGenerate = settings?.title?.auto ?? true;
-		titleAutoGenerateModel = settings?.title?.model ?? '';
-		titleAutoGenerateModelExternal = settings?.title?.modelExternal ?? '';
+		titleAutoGenerate = $settings?.title?.auto ?? true;
+		titleAutoGenerateModel = $settings?.title?.model ?? '';
+		titleAutoGenerateModelExternal = $settings?.title?.modelExternal ?? '';
 		titleGenerationPrompt =
-			settings?.title?.prompt ??
-			$i18n.t(
-				"Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':"
-			) + ' {{prompt}}';
-
-		responseAutoCopy = settings.responseAutoCopy ?? false;
-		showUsername = settings.showUsername ?? false;
-		chatBubble = settings.chatBubble ?? true;
-		fullScreenMode = settings.fullScreenMode ?? false;
-		splitLargeChunks = settings.splitLargeChunks ?? false;
-		chatDirection = settings.chatDirection ?? 'LTR';
+			$settings?.title?.prompt ??
+			`Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title': {{prompt}}`;
+		responseAutoCopy = $settings.responseAutoCopy ?? false;
+		showUsername = $settings.showUsername ?? false;
+		chatBubble = $settings.chatBubble ?? true;
+		fullScreenMode = $settings.fullScreenMode ?? false;
+		splitLargeChunks = $settings.splitLargeChunks ?? false;
+		chatDirection = $settings.chatDirection ?? 'LTR';
 	});
 </script>
 

+ 1 - 2
src/lib/components/chat/Settings/Personalization.svelte

@@ -19,8 +19,7 @@
 	let enableMemory = false;
 
 	onMount(async () => {
-		let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
-		enableMemory = settings?.memory ?? false;
+		enableMemory = $settings?.memory ?? false;
 	});
 </script>
 

+ 3 - 0
src/lib/components/chat/SettingsModal.svelte

@@ -17,6 +17,7 @@
 	import Images from './Settings/Images.svelte';
 	import User from '../icons/User.svelte';
 	import Personalization from './Settings/Personalization.svelte';
+	import { updateUserSettings } from '$lib/apis/users';
 
 	const i18n = getContext('i18n');
 
@@ -26,7 +27,9 @@
 		console.log(updated);
 		await settings.set({ ...$settings, ...updated });
 		await models.set(await getModels());
+
 		localStorage.setItem('settings', JSON.stringify($settings));
+		await updateUserSettings(localStorage.token, { ui: $settings });
 	};
 
 	const getModels = async () => {

+ 3 - 0
src/lib/components/layout/Sidebar.svelte

@@ -33,6 +33,7 @@
 	import ArchiveBox from '../icons/ArchiveBox.svelte';
 	import ArchivedChatsModal from './Sidebar/ArchivedChatsModal.svelte';
 	import UserMenu from './Sidebar/UserMenu.svelte';
+	import { updateUserSettings } from '$lib/apis/users';
 
 	const BREAKPOINT = 768;
 
@@ -184,6 +185,8 @@
 	const saveSettings = async (updated) => {
 		await settings.set({ ...$settings, ...updated });
 		localStorage.setItem('settings', JSON.stringify($settings));
+		await updateUserSettings(localStorage.token, { ui: $settings });
+
 		location.href = '/';
 	};
 

+ 8 - 1
src/routes/(app)/+layout.svelte

@@ -35,6 +35,7 @@
 	import ChangelogModal from '$lib/components/ChangelogModal.svelte';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
 	import { getBanners } from '$lib/apis/configs';
+	import { getUserSettings } from '$lib/apis/users';
 
 	const i18n = getContext('i18n');
 
@@ -72,7 +73,13 @@
 				// IndexedDB Not Found
 			}
 
-			settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
+			const userSettings = await getUserSettings(localStorage.token);
+
+			if (userSettings) {
+				await settings.set(userSettings.ui);
+			} else {
+				await settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
+			}
 
 			await Promise.all([
 				(async () => {

+ 0 - 6
src/routes/s/[id]/+page.svelte

@@ -98,12 +98,6 @@
 						: convertMessagesToHistory(chatContent.messages);
 				title = chatContent.title;
 
-				let _settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
-				await settings.set({
-					..._settings,
-					system: chatContent.system ?? _settings.system,
-					options: chatContent.options ?? _settings.options
-				});
 				autoScroll = true;
 				await tick();