浏览代码

refac: many model chat

Timothy J. Baek 8 月之前
父节点
当前提交
28e3e6e8cb
共有 2 个文件被更改,包括 42 次插入36 次删除
  1. 16 7
      src/lib/components/chat/Chat.svelte
  2. 26 29
      src/lib/components/chat/Messages/CompareMessages.svelte

+ 16 - 7
src/lib/components/chat/Chat.svelte

@@ -562,7 +562,7 @@
 				content: userPrompt,
 				files: _files.length > 0 ? _files : undefined,
 				timestamp: Math.floor(Date.now() / 1000), // Unix epoch
-				models: selectedModels.filter((m, mIdx) => selectedModels.indexOf(m) === mIdx)
+				models: selectedModels
 			};
 
 			// Add message to history and Set currentId to messageId
@@ -582,7 +582,11 @@
 		return _responses;
 	};
 
-	const sendPrompt = async (prompt, parentId, { modelId = null, newChat = false } = {}) => {
+	const sendPrompt = async (
+		prompt,
+		parentId,
+		{ modelId = null, modelIdx = null, newChat = false } = {}
+	) => {
 		let _responses = [];
 
 		// If modelId is provided, use it, else use selected model
@@ -594,7 +598,7 @@
 
 		// Create response messages for each selected model
 		const responseMessageIds = {};
-		for (const modelId of selectedModelIds) {
+		for (const [_modelIdx, modelId] of selectedModelIds.entries()) {
 			const model = $models.filter((m) => m.id === modelId).at(0);
 
 			if (model) {
@@ -607,6 +611,7 @@
 					content: '',
 					model: model.id,
 					modelName: model.name ?? model.id,
+					modelIdx: modelIdx ? modelIdx : _modelIdx,
 					userContext: null,
 					timestamp: Math.floor(Date.now() / 1000) // Unix epoch
 				};
@@ -623,7 +628,7 @@
 					];
 				}
 
-				responseMessageIds[modelId] = responseMessageId;
+				responseMessageIds[`${modelId}-${modelIdx ? modelIdx : _modelIdx}`] = responseMessageId;
 			}
 		}
 		await tick();
@@ -655,7 +660,7 @@
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
 		await Promise.all(
-			selectedModelIds.map(async (modelId) => {
+			selectedModelIds.map(async (modelId, _modelIdx) => {
 				console.log('modelId', modelId);
 				const model = $models.filter((m) => m.id === modelId).at(0);
 
@@ -673,7 +678,8 @@
 						);
 					}
 
-					let responseMessageId = responseMessageIds[modelId];
+					let responseMessageId =
+						responseMessageIds[`${modelId}-${modelIdx ? modelIdx : _modelIdx}`];
 					let responseMessage = history.messages[responseMessageId];
 
 					let userContext = null;
@@ -1350,7 +1356,10 @@
 			} else {
 				// If there are multiple models selected, use the model of the response message for regeneration
 				// e.g. many model chat
-				await sendPrompt(userPrompt, userMessage.id, { modelId: message.model });
+				await sendPrompt(userPrompt, userMessage.id, {
+					modelId: message.model,
+					modelIdx: message.modelIdx
+				});
 			}
 		}
 	};

+ 26 - 29
src/lib/components/chat/Messages/CompareMessages.svelte

@@ -26,24 +26,24 @@
 	const dispatch = createEventDispatcher();
 
 	let currentMessageId;
-
-	let groupedMessagesIdx = {};
 	let groupedMessages = {};
+	let groupedMessagesIdx = {};
 
-	$: groupedMessages = parentMessage?.models.reduce((a, model) => {
+	$: groupedMessages = parentMessage?.models.reduce((a, model, modelIdx) => {
+		// Find all messages that are children of the parent message and have the same model
 		const modelMessages = parentMessage?.childrenIds
 			.map((id) => history.messages[id])
-			.filter((m) => m.model === model);
+			.filter((m) => m.modelIdx === modelIdx);
 
 		return {
 			...a,
-			[model]: { messages: modelMessages }
+			[modelIdx]: { messages: modelMessages }
 		};
 	}, {});
 
-	const showPreviousMessage = (model) => {
-		groupedMessagesIdx[model] = Math.max(0, groupedMessagesIdx[model] - 1);
-		let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
+	const showPreviousMessage = (modelIdx) => {
+		groupedMessagesIdx[modelIdx] = Math.max(0, groupedMessagesIdx[modelIdx] - 1);
+		let messageId = groupedMessages[modelIdx].messages[groupedMessagesIdx[modelIdx]].id;
 
 		console.log(messageId);
 		let messageChildrenIds = history.messages[messageId].childrenIds;
@@ -54,17 +54,16 @@
 		}
 
 		history.currentId = messageId;
-
 		dispatch('change');
 	};
 
-	const showNextMessage = (model) => {
-		groupedMessagesIdx[model] = Math.min(
-			groupedMessages[model].messages.length - 1,
-			groupedMessagesIdx[model] + 1
+	const showNextMessage = (modelIdx) => {
+		groupedMessagesIdx[modelIdx] = Math.min(
+			groupedMessages[modelIdx].messages.length - 1,
+			groupedMessagesIdx[modelIdx] + 1
 		);
 
-		let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
+		let messageId = groupedMessages[modelIdx].messages[groupedMessagesIdx[modelIdx]].id;
 		console.log(messageId);
 
 		let messageChildrenIds = history.messages[messageId].childrenIds;
@@ -75,7 +74,6 @@
 		}
 
 		history.currentId = messageId;
-
 		dispatch('change');
 	};
 
@@ -83,13 +81,12 @@
 		await tick();
 		currentMessageId = messages[messageIdx].id;
 
-		for (const model of parentMessage?.models) {
-			const idx = groupedMessages[model].messages.findIndex((m) => m.id === currentMessageId);
-
+		for (const [modelIdx, model] of parentMessage?.models.entries()) {
+			const idx = groupedMessages[modelIdx].messages.findIndex((m) => m.id === currentMessageId);
 			if (idx !== -1) {
-				groupedMessagesIdx[model] = idx;
+				groupedMessagesIdx[modelIdx] = idx;
 			} else {
-				groupedMessagesIdx[model] = 0;
+				groupedMessagesIdx[modelIdx] = 0;
 			}
 		}
 	});
@@ -101,16 +98,16 @@
 		id="responses-container-{parentMessage.id}"
 	>
 		{#key currentMessageId}
-			{#each Object.keys(groupedMessages) as model}
-				{#if groupedMessagesIdx[model] !== undefined && groupedMessages[model].messages.length > 0}
+			{#each Object.keys(groupedMessages) as modelIdx}
+				{#if groupedMessagesIdx[modelIdx] !== undefined && groupedMessages[modelIdx].messages.length > 0}
 					<!-- svelte-ignore a11y-no-static-element-interactions -->
 					<!-- svelte-ignore a11y-click-events-have-key-events -->
-					{@const message = groupedMessages[model].messages[groupedMessagesIdx[model]]}
+					{@const message = groupedMessages[modelIdx].messages[groupedMessagesIdx[modelIdx]]}
 
 					<div
 						class=" snap-center min-w-80 w-full max-w-full m-1 border {history.messages[
 							currentMessageId
-						].model === model
+						].modelIdx === modelIdx
 							? 'border-gray-100 dark:border-gray-800 border-[1.5px]'
 							: 'border-gray-50 dark:border-gray-850 '} transition p-5 rounded-3xl"
 						on:click={() => {
@@ -131,13 +128,13 @@
 					>
 						{#key history.currentId}
 							<ResponseMessage
-								message={groupedMessages[model].messages[groupedMessagesIdx[model]]}
-								siblings={groupedMessages[model].messages.map((m) => m.id)}
+								message={groupedMessages[modelIdx].messages[groupedMessagesIdx[modelIdx]]}
+								siblings={groupedMessages[modelIdx].messages.map((m) => m.id)}
 								isLastMessage={true}
 								{updateChatMessages}
 								{confirmEditResponseMessage}
-								showPreviousMessage={() => showPreviousMessage(model)}
-								showNextMessage={() => showNextMessage(model)}
+								showPreviousMessage={() => showPreviousMessage(modelIdx)}
+								showNextMessage={() => showNextMessage(modelIdx)}
 								{readOnly}
 								{rateMessage}
 								{copyToClipboard}
@@ -145,7 +142,7 @@
 								regenerateResponse={async (message) => {
 									regenerateResponse(message);
 									await tick();
-									groupedMessagesIdx[model] = groupedMessages[model].messages.length - 1;
+									groupedMessagesIdx[modelIdx] = groupedMessages[modelIdx].messages.length - 1;
 								}}
 								on:save={async (e) => {
 									console.log('save', e);