
Merge pull request #1113 from open-webui/rag

feat: rag api
Timothy Jaeryang Baek, 1 year ago
commit d99114518c

+ 8 - 0
backend/apps/rag/utils.py

@@ -1,3 +1,4 @@
+import re
 from typing import List

 from config import CHROMA_CLIENT
@@ -87,3 +88,10 @@ def query_collection(
             pass

     return merge_and_sort_query_results(results, k)
+
+
+def rag_template(template: str, context: str, query: str):
+    template = re.sub(r"\[context\]", context, template)
+    template = re.sub(r"\[query\]", query, template)
+
+    return template
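
The new `rag_template` helper is a plain token substitution over the configured prompt template: every literal `[context]` and `[query]` placeholder is replaced with the retrieved context and the user's query. A minimal usage sketch (the template string below is illustrative, not the project's shipped default):

```python
import re


def rag_template(template: str, context: str, query: str):
    # Replace the literal [context] and [query] placeholders in the template.
    template = re.sub(r"\[context\]", context, template)
    template = re.sub(r"\[query\]", query, template)
    return template


# Hypothetical template, for illustration only.
template = "Answer using this context:\n[context]\n\nQuestion: [query]"
print(rag_template(template, "The sky is blue.", "What color is the sky?"))
# Answer using this context:
# The sky is blue.
#
# Question: What color is the sky?
```

One caveat worth noting: `re.sub` interprets backslashes in the replacement string, so a retrieved context containing sequences like `\1` or `\g` could raise or substitute incorrectly; since the placeholders are fixed strings, `str.replace` would sidestep that.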

+ 121 - 0
backend/main.py

@@ -12,6 +12,7 @@ from fastapi import HTTPException
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
+from starlette.middleware.base import BaseHTTPMiddleware


 from apps.ollama.main import app as ollama_app
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
 from apps.web.main import app as webui_app


+from apps.rag.utils import query_doc, query_collection, rag_template
+
 from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
 from constants import ERROR_MESSAGES

@@ -56,6 +59,124 @@ async def on_startup():
     await litellm_app_startup()


+class RAGMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if request.method == "POST" and (
+            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
+        ):
+            print(request.url.path)
+
+            # Read the original request body
+            body = await request.body()
+            # Decode body to string
+            body_str = body.decode("utf-8")
+            # Parse string to JSON
+            data = json.loads(body_str) if body_str else {}
+
+            # Example: Add a new key-value pair or modify existing ones
+            # data["modified"] = True  # Example modification
+            if "docs" in data:
+                docs = data["docs"]
+                print(docs)
+
+                last_user_message_idx = None
+                for i in range(len(data["messages"]) - 1, -1, -1):
+                    if data["messages"][i]["role"] == "user":
+                        last_user_message_idx = i
+                        break
+
+                user_message = data["messages"][last_user_message_idx]
+
+                if isinstance(user_message["content"], list):
+                    # Handle list content input
+                    content_type = "list"
+                    query = ""
+                    for content_item in user_message["content"]:
+                        if content_item["type"] == "text":
+                            query = content_item["text"]
+                            break
+                elif isinstance(user_message["content"], str):
+                    # Handle text content input
+                    content_type = "text"
+                    query = user_message["content"]
+                else:
+                    # Fallback in case the input does not match expected types
+                    content_type = None
+                    query = ""
+
+                relevant_contexts = []
+
+                for doc in docs:
+                    context = None
+
+                    try:
+                        if doc["type"] == "collection":
+                            context = query_collection(
+                                collection_names=doc["collection_names"],
+                                query=query,
+                                k=rag_app.state.TOP_K,
+                                embedding_function=rag_app.state.sentence_transformer_ef,
+                            )
+                        else:
+                            context = query_doc(
+                                collection_name=doc["collection_name"],
+                                query=query,
+                                k=rag_app.state.TOP_K,
+                                embedding_function=rag_app.state.sentence_transformer_ef,
+                            )
+                    except Exception as e:
+                        print(e)
+                        context = None
+
+                    relevant_contexts.append(context)
+
+                context_string = ""
+                for context in relevant_contexts:
+                    if context:
+                        context_string += " ".join(context["documents"][0]) + "\n"
+
+                ra_content = rag_template(
+                    template=rag_app.state.RAG_TEMPLATE,
+                    context=context_string,
+                    query=query,
+                )
+
+                if content_type == "list":
+                    new_content = []
+                    for content_item in user_message["content"]:
+                        if content_item["type"] == "text":
+                            # Update the text item's content with ra_content
+                            new_content.append({"type": "text", "text": ra_content})
+                        else:
+                            # Keep other types of content as they are
+                            new_content.append(content_item)
+                    new_user_message = {**user_message, "content": new_content}
+                else:
+                    new_user_message = {
+                        **user_message,
+                        "content": ra_content,
+                    }
+
+                data["messages"][last_user_message_idx] = new_user_message
+                del data["docs"]
+
+            modified_body_bytes = json.dumps(data).encode("utf-8")
+
+            # Create a new request with the modified body
+            scope = request.scope
+            scope["body"] = modified_body_bytes
+            request = Request(scope, receive=lambda: self._receive(modified_body_bytes))
+
+        response = await call_next(request)
+        return response
+
+    async def _receive(self, body: bytes):
+        return {"type": "http.request", "body": body, "more_body": False}
+
+
+app.add_middleware(RAGMiddleware)
+
+
 @app.middleware("http")
 async def check_url(request: Request, call_next):
     start_time = int(time.time())
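
Taken together, the middleware moves retrieval from the browser to the server: on any POST to `/api/chat` or `/chat/completions` it parses the JSON body, runs `query_doc`/`query_collection` for each entry in `docs`, renders the results into the last user message via `rag_template`, deletes `docs`, and replays the rewritten body to downstream handlers by overriding the ASGI `receive` callable. A hedged before/after sketch of that transformation (the payload values are hypothetical):

```python
import json

# Hypothetical incoming body: the frontend now ships doc references inline.
incoming = {
    "model": "llama2:latest",
    "messages": [{"role": "user", "content": "Summarize the report."}],
    "docs": [{"type": "doc", "collection_name": "report-2c3f1a"}],
}

# What dispatch() hands downstream: retrieval has run, the last user message
# carries the filled-in RAG template, and "docs" is gone.
outgoing = {
    "model": "llama2:latest",
    "messages": [
        {
            "role": "user",
            "content": "Answer using this context:\n<retrieved chunks>\n\nQuestion: Summarize the report.",
        }
    ],
}

# The rewritten bytes replace the original body; _receive() feeds them back
# as a single http.request event so downstream .body()/.json() still work.
modified_body_bytes = json.dumps(outgoing).encode("utf-8")
```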

+ 1 - 1
src/lib/apis/rag/index.ts

@@ -252,7 +252,7 @@ export const queryCollection = async (
 	token: string,
 	collection_names: string,
 	query: string,
-	k: number
+	k: number | null = null
 ) => {
 	let error = null;


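Making `k` nullable on the client lets callers omit it and defer to the server's configured `TOP_K` (the middleware above already passes `rag_app.state.TOP_K` explicitly). A sketch of the fallback pattern this enables, with illustrative names:

```python
from typing import Optional

DEFAULT_TOP_K = 4  # stand-in for rag_app.state.TOP_K


def resolve_k(k: Optional[int]) -> int:
    # A null/omitted k falls back to the server-side default.
    return k if k is not None else DEFAULT_TOP_K


assert resolve_k(None) == DEFAULT_TOP_K
assert resolve_k(10) == 10
```
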
+ 39 - 58
src/routes/(app)/+page.svelte

@@ -232,53 +232,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));

-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -342,15 +295,25 @@
 			...messages
 		]
 			.filter((message) => message)
-			.map((message, idx, arr) => ({
-				role: message.role,
-				content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content,
-				...(message.files && {
-					images: message.files
-						.filter((file) => file.type === 'image')
-						.map((file) => file.url.slice(file.url.indexOf(',') + 1))
-				})
-			}));
+			.map((message, idx, arr) => {
+				// Prepare the base message object
+				const baseMessage = {
+					role: message.role,
+					content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
+				};
+
+				// Extract and format image URLs if any exist
+				const imageUrls = message.files
+					?.filter((file) => file.type === 'image')
+					.map((file) => file.url.slice(file.url.indexOf(',') + 1));
+
+				// Add images array only if it contains elements
+				if (imageUrls && imageUrls.length > 0) {
+					baseMessage.images = imageUrls;
+				}
+
+				return baseMessage;
+			});

 		let lastImageIndex = -1;

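The refactored mapping only attaches an `images` array when a message actually has image files, and it strips the data-URL header before sending, via `file.url.slice(file.url.indexOf(',') + 1)`. The equivalent slice in Python, with a fabricated data URL:

```python
# Keep only the base64 payload of a data URL (URL value is made up).
url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg=="
b64 = url[url.index(",") + 1:]
assert b64 == "iVBORw0KGgoAAAANSUhEUg=="
```
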
@@ -368,6 +331,13 @@
 			}
 		});

+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
 			model: model,
 			messages: messagesBody,
@@ -375,7 +345,8 @@
 				...($settings.options ?? {})
 			},
 			format: $settings.requestFormat ?? undefined,
-			keep_alive: $settings.keepAlive ?? undefined
+			keep_alive: $settings.keepAlive ?? undefined,
+			docs: docs.length > 0 ? docs : undefined
 		});

 		if (res && res.ok) {
@@ -535,6 +506,15 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();

+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
+		console.log(docs);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -583,7 +563,8 @@
 				top_p: $settings?.options?.top_p ?? undefined,
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined
+				max_tokens: $settings?.options?.num_predict ?? undefined,
+				docs: docs.length > 0 ? docs : undefined
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);
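
With the client-side retrieval block deleted, the page now only collects `doc` and `collection` attachments and forwards them as a `docs` field on the completion request (omitted entirely when empty), leaving the heavy lifting to `RAGMiddleware`. A sketch of the forwarded payload shape, with made-up identifiers:

```python
# Shape of the docs array mirrored from the Svelte filter above
# (identifiers are hypothetical).
docs = [
    {"type": "doc", "collection_name": "2c3f1a"},
    {"type": "collection", "collection_names": ["docs-a", "docs-b"]},
]

# Matches `docs: docs.length > 0 ? docs : undefined` in the diff: the key
# is simply absent when there are no attachments.
payload = {"model": "llama2:latest", **({"docs": docs} if docs else {})}
```

The same change is applied verbatim to src/routes/(app)/c/[id]/+page.svelte below.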

+ 40 - 58
src/routes/(app)/c/[id]/+page.svelte

@@ -245,53 +245,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));

-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -355,15 +308,25 @@
 			...messages
 		]
 			.filter((message) => message)
-			.map((message, idx, arr) => ({
-				role: message.role,
-				content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content,
-				...(message.files && {
-					images: message.files
-						.filter((file) => file.type === 'image')
-						.map((file) => file.url.slice(file.url.indexOf(',') + 1))
-				})
-			}));
+			.map((message, idx, arr) => {
+				// Prepare the base message object
+				const baseMessage = {
+					role: message.role,
+					content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
+				};
+
+				// Extract and format image URLs if any exist
+				const imageUrls = message.files
+					?.filter((file) => file.type === 'image')
+					.map((file) => file.url.slice(file.url.indexOf(',') + 1));
+
+				// Add images array only if it contains elements
+				if (imageUrls && imageUrls.length > 0) {
+					baseMessage.images = imageUrls;
+				}
+
+				return baseMessage;
+			});

 		let lastImageIndex = -1;

@@ -381,6 +344,13 @@
 			}
 		});

+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
 			model: model,
 			messages: messagesBody,
@@ -388,7 +358,8 @@
 				...($settings.options ?? {})
 			},
 			format: $settings.requestFormat ?? undefined,
-			keep_alive: $settings.keepAlive ?? undefined
+			keep_alive: $settings.keepAlive ?? undefined,
+			docs: docs.length > 0 ? docs : undefined
 		});

 		if (res && res.ok) {
@@ -548,6 +519,15 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();

+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
+		console.log(docs);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -596,7 +576,8 @@
 				top_p: $settings?.options?.top_p ?? undefined,
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined
+				max_tokens: $settings?.options?.num_predict ?? undefined,
+				docs: docs.length > 0 ? docs : undefined
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);
@@ -710,6 +691,7 @@
 			await setChatTitle(_chatId, userPrompt);
 		}
 	};
+
 	const stopResponse = () => {
 		stopResponseFlag = true;
 		console.log('stopResponse');