Browse Source

feat: continue generation

Timothy J. Baek 1 year ago
parent
commit
26438c29d5

+ 2 - 0
src/lib/components/chat/Messages.svelte

@@ -14,6 +14,7 @@
 
 	export let chatId = '';
 	export let sendPrompt: Function;
+	export let continueGeneration: Function;
 	export let regenerateResponse: Function;
 
 	export let processing = '';
@@ -296,6 +297,7 @@
 							{showNextMessage}
 							{rateMessage}
 							{copyToClipboard}
+							{continueGeneration}
 							{regenerateResponse}
 						/>
 					{/if}

+ 43 - 1
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -29,6 +29,7 @@
 	export let rateMessage: Function;
 
 	export let copyToClipboard: Function;
+	export let continueGeneration: Function;
 	export let regenerateResponse: Function;
 
 	let edit = false;
@@ -362,7 +363,7 @@
 								{/if}
 
 								{#if message.done}
-									<div class=" flex justify-start space-x-1 -mt-2">
+									<div class=" flex justify-start space-x-1 -mt-2 overflow-x-auto buttons">
 										{#if siblings.length > 1}
 											<div class="flex self-center">
 												<button
@@ -610,6 +611,36 @@
 										{/if}
 
 										{#if isLastMessage}
+											<button
+												type="button"
+												class="{isLastMessage
+													? 'visible'
+													: 'invisible group-hover:visible'} p-1 rounded dark:hover:text-white transition regenerate-response-button"
+												on:click={() => {
+													continueGeneration();
+												}}
+											>
+												<svg
+													xmlns="http://www.w3.org/2000/svg"
+													fill="none"
+													viewBox="0 0 24 24"
+													stroke-width="1.5"
+													stroke="currentColor"
+													class="w-4 h-4"
+												>
+													<path
+														stroke-linecap="round"
+														stroke-linejoin="round"
+														d="M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z"
+													/>
+													<path
+														stroke-linecap="round"
+														stroke-linejoin="round"
+														d="M15.91 11.672a.375.375 0 0 1 0 .656l-5.603 3.113a.375.375 0 0 1-.557-.328V8.887c0-.286.307-.466.557-.327l5.603 3.112Z"
+													/>
+												</svg>
+											</button>
+
 											<button
 												type="button"
 												class="{isLastMessage
@@ -643,3 +674,14 @@
 		</div>
 	</div>
 {/key}
+
+<style>
+	.buttons::-webkit-scrollbar {
+		display: none; /* for Chrome, Safari and Opera */
+	}
+
+	.buttons {
+		-ms-overflow-style: none; /* IE and Edge */
+		scrollbar-width: none; /* Firefox */
+	}
+</style>

+ 49 - 47
src/routes/(app)/+page.svelte

@@ -272,10 +272,34 @@
 				console.log(model);
 				const modelTag = $models.filter((m) => m.name === model).at(0);
 
+				// Create response message
+				let responseMessageId = uuidv4();
+				let responseMessage = {
+					parentId: parentId,
+					id: responseMessageId,
+					childrenIds: [],
+					role: 'assistant',
+					content: '',
+					model: model,
+					timestamp: Math.floor(Date.now() / 1000) // Unix epoch
+				};
+
+				// Add message to history and Set currentId to messageId
+				history.messages[responseMessageId] = responseMessage;
+				history.currentId = responseMessageId;
+
+				// Append messageId to childrenIds of parent message
+				if (parentId !== null) {
+					history.messages[parentId].childrenIds = [
+						...history.messages[parentId].childrenIds,
+						responseMessageId
+					];
+				}
+
 				if (modelTag?.external) {
-					await sendPromptOpenAI(model, prompt, parentId, _chatId);
+					await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
 				} else if (modelTag) {
-					await sendPromptOllama(model, prompt, parentId, _chatId);
+					await sendPromptOllama(model, prompt, responseMessageId, _chatId);
 				} else {
 					toast.error(`Model ${model} not found`);
 				}
@@ -285,30 +309,8 @@
 		await chats.set(await getChatList(localStorage.token));
 	};
 
-	const sendPromptOllama = async (model, userPrompt, parentId, _chatId) => {
-		// Create response message
-		let responseMessageId = uuidv4();
-		let responseMessage = {
-			parentId: parentId,
-			id: responseMessageId,
-			childrenIds: [],
-			role: 'assistant',
-			content: '',
-			model: model,
-			timestamp: Math.floor(Date.now() / 1000) // Unix epoch
-		};
-
-		// Add message to history and Set currentId to messageId
-		history.messages[responseMessageId] = responseMessage;
-		history.currentId = responseMessageId;
-
-		// Append messageId to childrenIds of parent message
-		if (parentId !== null) {
-			history.messages[parentId].childrenIds = [
-				...history.messages[parentId].childrenIds,
-				responseMessageId
-			];
-		}
+	const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {
+		const responseMessage = history.messages[responseMessageId];
 
 		// Wait until history/message have been updated
 		await tick();
@@ -515,27 +517,8 @@
 		}
 	};
 
-	const sendPromptOpenAI = async (model, userPrompt, parentId, _chatId) => {
-		let responseMessageId = uuidv4();
-
-		let responseMessage = {
-			parentId: parentId,
-			id: responseMessageId,
-			childrenIds: [],
-			role: 'assistant',
-			content: '',
-			model: model,
-			timestamp: Math.floor(Date.now() / 1000) // Unix epoch
-		};
-
-		history.messages[responseMessageId] = responseMessage;
-		history.currentId = responseMessageId;
-		if (parentId !== null) {
-			history.messages[parentId].childrenIds = [
-				...history.messages[parentId].childrenIds,
-				responseMessageId
-			];
-		}
+	const sendPromptOpenAI = async (model, userPrompt, responseMessageId, _chatId) => {
+		const responseMessage = history.messages[responseMessageId];
 
 		window.scrollTo({ top: document.body.scrollHeight });
 
@@ -716,6 +699,24 @@
 		}
 	};
 
+	const continueGeneration = async () => {
+		console.log('continueGeneration');
+		const _chatId = JSON.parse(JSON.stringify($chatId));
+
+		if (messages.length != 0 && messages.at(-1).done == true) {
+			const responseMessage = history.messages[history.currentId];
+			const modelTag = $models.filter((m) => m.name === responseMessage.model).at(0);
+
+			if (modelTag?.external) {
+				await sendPromptOpenAI(responseMessage.model, prompt, responseMessage.id, _chatId);
+			} else if (modelTag) {
+				await sendPromptOllama(responseMessage.model, prompt, responseMessage.id, _chatId);
+			} else {
+				toast.error(`Model ${model} not found`);
+			}
+		}
+	};
+
 	const generateChatTitle = async (_chatId, userPrompt) => {
 		if ($settings.titleAutoGenerate ?? true) {
 			const title = await generateTitle(
@@ -800,6 +801,7 @@
 				bind:autoScroll
 				bottomPadding={files.length > 0}
 				{sendPrompt}
+				{continueGeneration}
 				{regenerateResponse}
 			/>
 		</div>

+ 49 - 47
src/routes/(app)/c/[id]/+page.svelte

@@ -286,10 +286,34 @@
 				console.log(model);
 				const modelTag = $models.filter((m) => m.name === model).at(0);
 
+				// Create response message
+				let responseMessageId = uuidv4();
+				let responseMessage = {
+					parentId: parentId,
+					id: responseMessageId,
+					childrenIds: [],
+					role: 'assistant',
+					content: '',
+					model: model,
+					timestamp: Math.floor(Date.now() / 1000) // Unix epoch
+				};
+
+				// Add message to history and Set currentId to messageId
+				history.messages[responseMessageId] = responseMessage;
+				history.currentId = responseMessageId;
+
+				// Append messageId to childrenIds of parent message
+				if (parentId !== null) {
+					history.messages[parentId].childrenIds = [
+						...history.messages[parentId].childrenIds,
+						responseMessageId
+					];
+				}
+
 				if (modelTag?.external) {
-					await sendPromptOpenAI(model, prompt, parentId, _chatId);
+					await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
 				} else if (modelTag) {
-					await sendPromptOllama(model, prompt, parentId, _chatId);
+					await sendPromptOllama(model, prompt, responseMessageId, _chatId);
 				} else {
 					toast.error(`Model ${model} not found`);
 				}
@@ -299,30 +323,8 @@
 		await chats.set(await getChatList(localStorage.token));
 	};
 
-	const sendPromptOllama = async (model, userPrompt, parentId, _chatId) => {
-		// Create response message
-		let responseMessageId = uuidv4();
-		let responseMessage = {
-			parentId: parentId,
-			id: responseMessageId,
-			childrenIds: [],
-			role: 'assistant',
-			content: '',
-			model: model,
-			timestamp: Math.floor(Date.now() / 1000) // Unix epoch
-		};
-
-		// Add message to history and Set currentId to messageId
-		history.messages[responseMessageId] = responseMessage;
-		history.currentId = responseMessageId;
-
-		// Append messageId to childrenIds of parent message
-		if (parentId !== null) {
-			history.messages[parentId].childrenIds = [
-				...history.messages[parentId].childrenIds,
-				responseMessageId
-			];
-		}
+	const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {
+		const responseMessage = history.messages[responseMessageId];
 
 		// Wait until history/message have been updated
 		await tick();
@@ -529,27 +531,8 @@
 		}
 	};
 
-	const sendPromptOpenAI = async (model, userPrompt, parentId, _chatId) => {
-		let responseMessageId = uuidv4();
-
-		let responseMessage = {
-			parentId: parentId,
-			id: responseMessageId,
-			childrenIds: [],
-			role: 'assistant',
-			content: '',
-			model: model,
-			timestamp: Math.floor(Date.now() / 1000) // Unix epoch
-		};
-
-		history.messages[responseMessageId] = responseMessage;
-		history.currentId = responseMessageId;
-		if (parentId !== null) {
-			history.messages[parentId].childrenIds = [
-				...history.messages[parentId].childrenIds,
-				responseMessageId
-			];
-		}
+	const sendPromptOpenAI = async (model, userPrompt, responseMessageId, _chatId) => {
+		const responseMessage = history.messages[responseMessageId];
 
 		window.scrollTo({ top: document.body.scrollHeight });
 
@@ -717,6 +700,24 @@
 		console.log('stopResponse');
 	};
 
+	const continueGeneration = async () => {
+		console.log('continueGeneration');
+		const _chatId = JSON.parse(JSON.stringify($chatId));
+
+		if (messages.length != 0 && messages.at(-1).done == true) {
+			const responseMessage = history.messages[history.currentId];
+			const modelTag = $models.filter((m) => m.name === responseMessage.model).at(0);
+
+			if (modelTag?.external) {
+				await sendPromptOpenAI(responseMessage.model, prompt, responseMessage.id, _chatId);
+			} else if (modelTag) {
+				await sendPromptOllama(responseMessage.model, prompt, responseMessage.id, _chatId);
+			} else {
+				toast.error(`Model ${model} not found`);
+			}
+		}
+	};
+
 	const regenerateResponse = async () => {
 		console.log('regenerateResponse');
 		if (messages.length != 0 && messages.at(-1).done == true) {
@@ -832,6 +833,7 @@
 					bind:autoScroll
 					bottomPadding={files.length > 0}
 					{sendPrompt}
+					{continueGeneration}
 					{regenerateResponse}
 				/>
 			</div>