
feat: rag docs as payload field

Timothy J. Baek · 1 year ago
commit 6c58bb59be
3 changed files with 25 additions and 53 deletions
  1. backend/main.py (+0 −2)
  2. src/routes/(app)/+page.svelte (+4 −2)
  3. src/routes/(app)/c/[id]/+page.svelte (+21 −49)

+ 0 - 2
backend/main.py

@@ -123,8 +123,6 @@ class RAGMiddleware(BaseHTTPMiddleware):
                 data["messages"][last_user_message_idx] = new_user_message
                 data["messages"][last_user_message_idx] = new_user_message
                 del data["docs"]
                 del data["docs"]
 
 
-            print("DATAAAAAAAAAAAAAAAAAA")
-            print(data)
             modified_body_bytes = json.dumps(data).encode("utf-8")
             modified_body_bytes = json.dumps(data).encode("utf-8")
 
 
             # Create a new request with the modified body
             # Create a new request with the modified body

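For context, this hunk sits inside the middleware's request handling: per the surrounding lines, it replaces the last user message with the RAG-augmented one, drops the transport-only `docs` field, and re-serializes the body before forwarding the request. A minimal, framework-free sketch of that payload rewrite (the `rag_messages` helper is an assumption standing in for the project's actual retrieval code, not a real API):

```python
import json

def apply_rag_to_payload(body_bytes: bytes, rag_messages) -> bytes:
    """Sketch of the body rewrite RAGMiddleware performs.

    ``rag_messages`` is an injected callable; its name and signature are
    assumptions, not the project's real helpers.
    """
    data = json.loads(body_bytes)
    if "docs" in data:
        # Fold retrieved context for the attached docs into the messages,
        # then drop the transport-only field before forwarding upstream.
        data["messages"] = rag_messages(data["docs"], data["messages"])
        del data["docs"]
    # Re-serialize so a new request with the modified body can be created.
    return json.dumps(data).encode("utf-8")
```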
+ 4 - 2
src/routes/(app)/+page.svelte

@@ -336,7 +336,7 @@
 			},
 			format: $settings.requestFormat ?? undefined,
 			keep_alive: $settings.keepAlive ?? undefined,
-			docs: docs
+			docs: docs.length > 0 ? docs : undefined
 		});

 		if (res && res.ok) {
@@ -503,6 +503,8 @@
 			)
 			.flat(1);

+		console.log(docs);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -552,7 +554,7 @@
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
 				max_tokens: $settings?.options?.num_predict ?? undefined,
-				docs: docs
+				docs: docs.length > 0 ? docs : undefined
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);

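Both call sites now attach documents only when some are present, instead of always sending a `docs` key. Illustratively, the JSON body the backend middleware would then parse looks roughly like this for the Ollama path; all values are placeholders, and the exact field set depends on the caller and user settings:

```python
# Placeholder request body as RAGMiddleware would see it after json.loads;
# the model name, message text and collection_name are illustrative only.
payload = {
    "model": "llama2",
    "messages": [{"role": "user", "content": "Summarize the attached report."}],
    "format": None,
    "keep_alive": "5m",
    # Present only when the chat has doc/collection attachments
    # (the client sends `docs.length > 0 ? docs : undefined`).
    "docs": [{"type": "doc", "collection_name": "report-1234"}],
}
```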
+ 21 - 49
src/routes/(app)/c/[id]/+page.svelte

@@ -245,53 +245,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));

-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
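The block removed above is the client-side retrieval loop: per-document `queryDoc` / `queryCollection` calls plus the `RAGTemplate` prompt assembly. With docs sent as a payload field, that work moves behind the backend middleware. A rough Python sketch of the same loop as the server side would perform it (the callables are injected because the real backend helper names and signatures are assumptions here):

```python
def build_rag_prompt(docs, query, query_doc, query_collection, rag_template):
    """Reassemble the removed client-side loop on the server (sketch).

    ``query_doc``, ``query_collection`` and ``rag_template`` are injected
    callables; the project's actual helpers may differ in name and shape.
    """
    contexts = []
    for doc in docs:
        try:
            if doc["type"] == "collection":
                contexts.append(query_collection(doc["collection_names"], query))
            else:
                contexts.append(query_doc(doc["collection_name"], query))
        except Exception:
            # the removed client code logged failed lookups and skipped them
            continue
    contexts = [c for c in contexts if c]
    context_string = "".join(f"{' '.join(c['documents'])}\n" for c in contexts)
    return rag_template(context_string, query)
```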
@@ -381,6 +334,13 @@
 			}
 		});

+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
 			model: model,
 			messages: messagesBody,
@@ -388,7 +348,8 @@
 				...($settings.options ?? {})
 			},
 			format: $settings.requestFormat ?? undefined,
-			keep_alive: $settings.keepAlive ?? undefined
+			keep_alive: $settings.keepAlive ?? undefined,
+			docs: docs.length > 0 ? docs : undefined
 		});

 		if (res && res.ok) {
@@ -548,6 +509,15 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();

+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
+		console.log(docs);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -596,7 +566,8 @@
 				top_p: $settings?.options?.top_p ?? undefined,
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined
+				max_tokens: $settings?.options?.num_predict ?? undefined,
+				docs: docs.length > 0 ? docs : undefined
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);
@@ -710,6 +681,7 @@
 			await setChatTitle(_chatId, userPrompt);
 		}
 	};
+
 	const stopResponse = () => {
 		stopResponseFlag = true;
 		console.log('stopResponse');