Explorar el Código

feat: convo tag filtering

Timothy J. Baek hace 1 año
padre
commit
220530c450

+ 29 - 18
backend/apps/web/models/chats.py

@@ -60,23 +60,23 @@ class ChatTitleIdResponse(BaseModel):
 
 
 
 
 class ChatTable:
 class ChatTable:
-
     def __init__(self, db):
     def __init__(self, db):
         self.db = db
         self.db = db
         db.create_tables([Chat])
         db.create_tables([Chat])
 
 
-    def insert_new_chat(self, user_id: str,
-                        form_data: ChatForm) -> Optional[ChatModel]:
+    def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
         id = str(uuid.uuid4())
         id = str(uuid.uuid4())
         chat = ChatModel(
         chat = ChatModel(
             **{
             **{
                 "id": id,
                 "id": id,
                 "user_id": user_id,
                 "user_id": user_id,
-                "title": form_data.chat["title"] if "title" in
-                form_data.chat else "New Chat",
+                "title": form_data.chat["title"]
+                if "title" in form_data.chat
+                else "New Chat",
                 "chat": json.dumps(form_data.chat),
                 "chat": json.dumps(form_data.chat),
                 "timestamp": int(time.time()),
                 "timestamp": int(time.time()),
-            })
+            }
+        )
 
 
         result = Chat.create(**chat.model_dump())
         result = Chat.create(**chat.model_dump())
         return chat if result else None
         return chat if result else None
@@ -109,25 +109,37 @@ class ChatTable:
         except:
         except:
             return None
             return None
 
 
-    def get_chat_lists_by_user_id(self,
-                                  user_id: str,
-                                  skip: int = 0,
-                                  limit: int = 50) -> List[ChatModel]:
+    def get_chat_lists_by_user_id(
+        self, user_id: str, skip: int = 0, limit: int = 50
+    ) -> List[ChatModel]:
         return [
         return [
-            ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
-                Chat.user_id == user_id).order_by(Chat.timestamp.desc())
+            ChatModel(**model_to_dict(chat))
+            for chat in Chat.select()
+            .where(Chat.user_id == user_id)
+            .order_by(Chat.timestamp.desc())
             # .limit(limit)
             # .limit(limit)
             # .offset(skip)
             # .offset(skip)
         ]
         ]
 
 
+    def get_chat_lists_by_chat_ids(
+        self, chat_ids: List[str], skip: int = 0, limit: int = 50
+    ) -> List[ChatModel]:
+        return [
+            ChatModel(**model_to_dict(chat))
+            for chat in Chat.select()
+            .where(Chat.id.in_(chat_ids))
+            .order_by(Chat.timestamp.desc())
+        ]
+
     def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
     def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
         return [
         return [
-            ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
-                Chat.user_id == user_id).order_by(Chat.timestamp.desc())
+            ChatModel(**model_to_dict(chat))
+            for chat in Chat.select()
+            .where(Chat.user_id == user_id)
+            .order_by(Chat.timestamp.desc())
         ]
         ]
 
 
-    def get_chat_by_id_and_user_id(self, id: str,
-                                   user_id: str) -> Optional[ChatModel]:
+    def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
         try:
         try:
             chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
             chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
             return ChatModel(**model_to_dict(chat))
             return ChatModel(**model_to_dict(chat))
@@ -142,8 +154,7 @@ class ChatTable:
 
 
     def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
     def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         try:
         try:
-            query = Chat.delete().where((Chat.id == id)
-                                        & (Chat.user_id == user_id))
+            query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
             query.execute()  # Remove the rows, return number of rows removed.
             query.execute()  # Remove the rows, return number of rows removed.
 
 
             return True
             return True

+ 13 - 0
backend/apps/web/models/tags.py

@@ -120,6 +120,19 @@ class TagTable:
         except:
         except:
             return None
             return None
 
 
+    def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
+        tag_names = [
+            ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
+            for chat_id_tag in ChatIdTag.select()
+            .where(ChatIdTag.user_id == user_id)
+            .order_by(ChatIdTag.timestamp.desc())
+        ]
+
+        return [
+            TagModel(**model_to_dict(tag))
+            for tag in Tag.select().where(Tag.name.in_(tag_names))
+        ]
+
     def get_tags_by_chat_id_and_user_id(
     def get_tags_by_chat_id_and_user_id(
         self, chat_id: str, user_id: str
         self, chat_id: str, user_id: str
     ) -> List[TagModel]:
     ) -> List[TagModel]:

+ 36 - 0
backend/apps/web/routers/chats.py

@@ -74,6 +74,42 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
         )
         )
 
 
 
 
+############################
+# GetAllTags
+############################
+
+
+@router.get("/tags/all", response_model=List[TagModel])
+async def get_all_tags(user=Depends(get_current_user)):
+    try:
+        tags = Tags.get_tags_by_user_id(user.id)
+        return tags
+    except Exception as e:
+        print(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
+############################
+# GetChatsByTags
+############################
+
+
+@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
+async def get_user_chats_by_tag_name(
+    tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
+):
+    chat_ids = [
+        chat_id_tag.chat_id
+        for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
+    ]
+
+    print(chat_ids)
+
+    return Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit)
+
+
 ############################
 ############################
 # GetChatById
 # GetChatById
 ############################
 ############################

+ 62 - 0
src/lib/apis/chats/index.ts

@@ -93,6 +93,68 @@ export const getAllChats = async (token: string) => {
 	return res;
 	return res;
 };
 };
 
 
+export const getAllChatTags = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getChatListByTagName = async (token: string = '', tagName: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/tag/${tagName}`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getChatById = async (token: string, id: string) => {
 export const getChatById = async (token: string, id: string) => {
 	let error = null;
 	let error = null;
 
 

+ 0 - 1
src/lib/components/layout/Navbar.svelte

@@ -6,7 +6,6 @@
 	import { getChatById } from '$lib/apis/chats';
 	import { getChatById } from '$lib/apis/chats';
 	import { chatId, modelfiles } from '$lib/stores';
 	import { chatId, modelfiles } from '$lib/stores';
 	import ShareChatModal from '../chat/ShareChatModal.svelte';
 	import ShareChatModal from '../chat/ShareChatModal.svelte';
-	import { stringify } from 'postcss';
 
 
 	export let initNewChat: Function;
 	export let initNewChat: Function;
 	export let title: string = 'Ollama Web UI';
 	export let title: string = 'Ollama Web UI';

+ 36 - 2
src/lib/components/layout/Sidebar.svelte

@@ -6,9 +6,14 @@
 
 
 	import { goto, invalidateAll } from '$app/navigation';
 	import { goto, invalidateAll } from '$app/navigation';
 	import { page } from '$app/stores';
 	import { page } from '$app/stores';
-	import { user, chats, settings, showSettings, chatId } from '$lib/stores';
+	import { user, chats, settings, showSettings, chatId, tags } from '$lib/stores';
 	import { onMount } from 'svelte';
 	import { onMount } from 'svelte';
-	import { deleteChatById, getChatList, updateChatById } from '$lib/apis/chats';
+	import {
+		deleteChatById,
+		getChatList,
+		getChatListByTagName,
+		updateChatById
+	} from '$lib/apis/chats';
 
 
 	let show = false;
 	let show = false;
 	let navElement;
 	let navElement;
@@ -28,6 +33,12 @@
 		}
 		}
 
 
 		await chats.set(await getChatList(localStorage.token));
 		await chats.set(await getChatList(localStorage.token));
+
+		tags.subscribe(async (value) => {
+			if (value.length === 0) {
+				await chats.set(await getChatList(localStorage.token));
+			}
+		});
 	});
 	});
 
 
 	const loadChat = async (id) => {
 	const loadChat = async (id) => {
@@ -281,6 +292,29 @@
 				</div>
 				</div>
 			</div>
 			</div>
 
 
+			{#if $tags.length > 0}
+				<div class="px-2.5 mt-0.5 mb-2 flex gap-1 flex-wrap">
+					<button
+						class="px-2.5 text-xs font-medium bg-gray-900 hover:bg-gray-800 transition rounded-full"
+						on:click={async () => {
+							await chats.set(await getChatList(localStorage.token));
+						}}
+					>
+						all
+					</button>
+					{#each $tags as tag}
+						<button
+							class="px-2.5 text-xs font-medium bg-gray-900 hover:bg-gray-800 transition rounded-full"
+							on:click={async () => {
+								await chats.set(await getChatListByTagName(localStorage.token, tag.name));
+							}}
+						>
+							{tag.name}
+						</button>
+					{/each}
+				</div>
+			{/if}
+
 			<div class="pl-2.5 my-2 flex-1 flex flex-col space-y-1 overflow-y-auto">
 			<div class="pl-2.5 my-2 flex-1 flex flex-col space-y-1 overflow-y-auto">
 				{#each $chats.filter((chat) => {
 				{#each $chats.filter((chat) => {
 					if (search === '') {
 					if (search === '') {

+ 1 - 0
src/lib/stores/index.ts

@@ -10,6 +10,7 @@ export const theme = writable('dark');
 export const chatId = writable('');
 export const chatId = writable('');
 
 
 export const chats = writable([]);
 export const chats = writable([]);
+export const tags = writable([]);
 export const models = writable([]);
 export const models = writable([]);
 
 
 export const modelfiles = writable([]);
 export const modelfiles = writable([]);

+ 4 - 1
src/routes/(app)/+layout.svelte

@@ -20,7 +20,8 @@
 		models,
 		models,
 		modelfiles,
 		modelfiles,
 		prompts,
 		prompts,
-		documents
+		documents,
+		tags
 	} from '$lib/stores';
 	} from '$lib/stores';
 	import { REQUIRED_OLLAMA_VERSION, WEBUI_API_BASE_URL } from '$lib/constants';
 	import { REQUIRED_OLLAMA_VERSION, WEBUI_API_BASE_URL } from '$lib/constants';
 
 
@@ -29,6 +30,7 @@
 	import { checkVersion } from '$lib/utils';
 	import { checkVersion } from '$lib/utils';
 	import ShortcutsModal from '$lib/components/chat/ShortcutsModal.svelte';
 	import ShortcutsModal from '$lib/components/chat/ShortcutsModal.svelte';
 	import { getDocs } from '$lib/apis/documents';
 	import { getDocs } from '$lib/apis/documents';
+	import { getAllChatTags } from '$lib/apis/chats';
 
 
 	let ollamaVersion = '';
 	let ollamaVersion = '';
 	let loaded = false;
 	let loaded = false;
@@ -106,6 +108,7 @@
 			await modelfiles.set(await getModelfiles(localStorage.token));
 			await modelfiles.set(await getModelfiles(localStorage.token));
 			await prompts.set(await getPrompts(localStorage.token));
 			await prompts.set(await getPrompts(localStorage.token));
 			await documents.set(await getDocs(localStorage.token));
 			await documents.set(await getDocs(localStorage.token));
+			await tags.set(await getAllChatTags(localStorage.token));
 
 
 			modelfiles.subscribe(async () => {
 			modelfiles.subscribe(async () => {
 				// should fetch models
 				// should fetch models

+ 15 - 1
src/routes/(app)/+page.svelte

@@ -6,7 +6,16 @@
 	import { goto } from '$app/navigation';
 	import { goto } from '$app/navigation';
 	import { page } from '$app/stores';
 	import { page } from '$app/stores';
 
 
-	import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores';
+	import {
+		models,
+		modelfiles,
+		user,
+		settings,
+		chats,
+		chatId,
+		config,
+		tags as _tags
+	} from '$lib/stores';
 	import { copyToClipboard, splitStream } from '$lib/utils';
 	import { copyToClipboard, splitStream } from '$lib/utils';
 
 
 	import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
 	import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
@@ -14,6 +23,7 @@
 		addTagById,
 		addTagById,
 		createNewChat,
 		createNewChat,
 		deleteTagById,
 		deleteTagById,
+		getAllChatTags,
 		getChatList,
 		getChatList,
 		getTagsById,
 		getTagsById,
 		updateChatById
 		updateChatById
@@ -695,6 +705,8 @@
 		chat = await updateChatById(localStorage.token, $chatId, {
 		chat = await updateChatById(localStorage.token, $chatId, {
 			tags: tags
 			tags: tags
 		});
 		});
+
+		_tags.set(await getAllChatTags(localStorage.token));
 	};
 	};
 
 
 	const deleteTag = async (tagName) => {
 	const deleteTag = async (tagName) => {
@@ -704,6 +716,8 @@
 		chat = await updateChatById(localStorage.token, $chatId, {
 		chat = await updateChatById(localStorage.token, $chatId, {
 			tags: tags
 			tags: tags
 		});
 		});
+
+		_tags.set(await getAllChatTags(localStorage.token));
 	};
 	};
 
 
 	const setChatTitle = async (_chatId, _title) => {
 	const setChatTitle = async (_chatId, _title) => {

+ 17 - 3
src/routes/(app)/c/[id]/+page.svelte

@@ -6,7 +6,16 @@
 	import { goto } from '$app/navigation';
 	import { goto } from '$app/navigation';
 	import { page } from '$app/stores';
 	import { page } from '$app/stores';
 
 
-	import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores';
+	import {
+		models,
+		modelfiles,
+		user,
+		settings,
+		chats,
+		chatId,
+		config,
+		tags as _tags
+	} from '$lib/stores';
 	import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils';
 	import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils';
 
 
 	import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
 	import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
@@ -14,6 +23,7 @@
 		addTagById,
 		addTagById,
 		createNewChat,
 		createNewChat,
 		deleteTagById,
 		deleteTagById,
+		getAllChatTags,
 		getChatById,
 		getChatById,
 		getChatList,
 		getChatList,
 		getTagsById,
 		getTagsById,
@@ -709,8 +719,10 @@
 		tags = await getTags();
 		tags = await getTags();
 
 
 		chat = await updateChatById(localStorage.token, $chatId, {
 		chat = await updateChatById(localStorage.token, $chatId, {
-			tags: tags.map((tag) => tag.name)
+			tags: tags
 		});
 		});
+
+		_tags.set(await getAllChatTags(localStorage.token));
 	};
 	};
 
 
 	const deleteTag = async (tagName) => {
 	const deleteTag = async (tagName) => {
@@ -718,8 +730,10 @@
 		tags = await getTags();
 		tags = await getTags();
 
 
 		chat = await updateChatById(localStorage.token, $chatId, {
 		chat = await updateChatById(localStorage.token, $chatId, {
-			tags: tags.map((tag) => tag.name)
+			tags: tags
 		});
 		});
+
+		_tags.set(await getAllChatTags(localStorage.token));
 	};
 	};
 
 
 	onMount(async () => {
 	onMount(async () => {