
refac: rag to backend

Timothy J. Baek, 1 year ago
Parent
Current commit: c49491e516
4 changed files with 113 additions and 50 deletions
  1. backend/apps/rag/utils.py (+8 −0)
  2. backend/main.py (+86 −0)
  3. src/lib/apis/rag/index.ts (+1 −1)
  4. src/routes/(app)/+page.svelte (+18 −49)

+ 8 - 0
backend/apps/rag/utils.py

@@ -1,3 +1,4 @@
+import re
 from typing import List
 
 from config import CHROMA_CLIENT
@@ -87,3 +88,10 @@ def query_collection(
             pass
 
     return merge_and_sort_query_results(results, k)
+
+
+def rag_template(template: str, context: str, query: str):
+    template = re.sub(r"\[context\]", context, template)
+    template = re.sub(r"\[query\]", query, template)
+
+    return template

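For context, a minimal sketch of how the new rag_template helper can be used. The template text and values below are hypothetical, not taken from this commit; the real default template is read from rag_app.state.RAG_TEMPLATE.

    import re

    # As added in backend/apps/rag/utils.py:
    def rag_template(template: str, context: str, query: str):
        template = re.sub(r"\[context\]", context, template)
        template = re.sub(r"\[query\]", query, template)
        return template

    # Hypothetical template string for illustration only.
    template = "Use the following context to answer the query:\n[context]\nQuery: [query]"

    print(rag_template(
        template=template,
        context="Open WebUI can retrieve passages from uploaded documents.",
        query="What can Open WebUI retrieve?",
    ))
    # The placeholders are swapped in with re.sub, so backslashes in the
    # substituted context/query are interpreted as replacement escapes.
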
+ 86 - 0
backend/main.py

@@ -12,6 +12,7 @@ from fastapi import HTTPException
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
+from starlette.middleware.base import BaseHTTPMiddleware
 
 
 from apps.ollama.main import app as ollama_app
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
 from apps.web.main import app as webui_app
 
 
+from apps.rag.utils import query_doc, query_collection, rag_template
+
 from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
 from constants import ERROR_MESSAGES
 
@@ -56,6 +59,89 @@ async def on_startup():
     await litellm_app_startup()
 
 
+class RAGMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+
+        print(request.url.path)
+        if request.method == "POST":
+            # Read the original request body
+            body = await request.body()
+            # Decode body to string
+            body_str = body.decode("utf-8")
+            # Parse string to JSON
+            data = json.loads(body_str) if body_str else {}
+
+            # Example: Add a new key-value pair or modify existing ones
+            # data["modified"] = True  # Example modification
+            if "docs" in data:
+                docs = data["docs"]
+                print(docs)
+
+                last_user_message_idx = None
+                for i in range(len(data["messages"]) - 1, -1, -1):
+                    if data["messages"][i]["role"] == "user":
+                        last_user_message_idx = i
+                        break
+
+                query = data["messages"][last_user_message_idx]["content"]
+
+                relevant_contexts = []
+
+                for doc in docs:
+                    context = None
+                    if doc["type"] == "collection":
+                        context = query_collection(
+                            collection_names=doc["collection_names"],
+                            query=query,
+                            k=rag_app.state.TOP_K,
+                            embedding_function=rag_app.state.sentence_transformer_ef,
+                        )
+                    else:
+                        context = query_doc(
+                            collection_name=doc["collection_name"],
+                            query=query,
+                            k=rag_app.state.TOP_K,
+                            embedding_function=rag_app.state.sentence_transformer_ef,
+                        )
+                    relevant_contexts.append(context)
+
+                context_string = ""
+                for context in relevant_contexts:
+                    if context:
+                        context_string += " ".join(context["documents"][0]) + "\n"
+
+                content = rag_template(
+                    template=rag_app.state.RAG_TEMPLATE,
+                    context=context_string,
+                    query=query,
+                )
+
+                new_user_message = {
+                    **data["messages"][last_user_message_idx],
+                    "content": content,
+                }
+                data["messages"][last_user_message_idx] = new_user_message
+                del data["docs"]
+
+            print("DATAAAAAAAAAAAAAAAAAA")
+            print(data)
+            modified_body_bytes = json.dumps(data).encode("utf-8")
+
+            # Create a new request with the modified body
+            scope = request.scope
+            scope["body"] = modified_body_bytes
+            request = Request(scope, receive=lambda: self._receive(modified_body_bytes))
+
+        response = await call_next(request)
+        return response
+
+    async def _receive(self, body: bytes):
+        return {"type": "http.request", "body": body, "more_body": False}
+
+
+app.add_middleware(RAGMiddleware)
+
+
 @app.middleware("http")
 async def check_url(request: Request, call_next):
     start_time = int(time.time())

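The key trick in RAGMiddleware is rewriting the request body before it reaches the proxied apps: read the POST body, inject the templated RAG context into the last user message, then hand call_next a Request whose receive callable replays the modified bytes. Below is a stripped-down sketch of that same pattern with a trivial transformation standing in for the RAG lookup; all names are illustrative, and whether call_next honours the swapped-in receive depends on the Starlette version in use.

    import json

    from starlette.applications import Starlette
    from starlette.middleware.base import BaseHTTPMiddleware
    from starlette.requests import Request
    from starlette.responses import JSONResponse
    from starlette.routing import Route


    class BodyRewriteMiddleware(BaseHTTPMiddleware):
        async def dispatch(self, request: Request, call_next):
            if request.method == "POST":
                body = await request.body()
                data = json.loads(body.decode("utf-8")) if body else {}

                # Stand-in for the RAG step: the real middleware replaces the last
                # user message with the rag_template output and deletes "docs".
                data["rewritten"] = True
                new_body = json.dumps(data).encode("utf-8")

                # Downstream handlers re-read the body through `receive`, so swap
                # in a callable that replays the modified bytes.
                async def receive():
                    return {"type": "http.request", "body": new_body, "more_body": False}

                request = Request(request.scope, receive=receive)

            return await call_next(request)


    async def echo(request: Request):
        # Echoes back whatever body the middleware let through.
        return JSONResponse(await request.json())


    app = Starlette(routes=[Route("/echo", echo, methods=["POST"])])
    app.add_middleware(BodyRewriteMiddleware)
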
+ 1 - 1
src/lib/apis/rag/index.ts

@@ -252,7 +252,7 @@ export const queryCollection = async (
 	token: string,
 	collection_names: string,
 	query: string,
-	k: number
+	k: number | null = null
 ) => {
 	let error = null;
 

+ 18 - 49
src/routes/(app)/+page.svelte

@@ -232,53 +232,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -368,6 +321,13 @@
 			}
 		});
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
 			model: model,
 			messages: messagesBody,
@@ -375,7 +335,8 @@
 				...($settings.options ?? {})
 			},
 			format: $settings.requestFormat ?? undefined,
-			keep_alive: $settings.keepAlive ?? undefined
+			keep_alive: $settings.keepAlive ?? undefined,
+			docs: docs
 		});
 
 		if (res && res.ok) {
@@ -535,6 +496,13 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -583,7 +551,8 @@
 				top_p: $settings?.options?.top_p ?? undefined,
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined
+				max_tokens: $settings?.options?.num_predict ?? undefined,
+				docs: docs
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);
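With this change the front end no longer runs retrieval itself; it simply attaches the selected files as `docs` on the completion request and lets the backend middleware do the lookup. An illustrative shape of the body the middleware now intercepts (field values are hypothetical):

    # RAGMiddleware strips "docs", runs query_doc/query_collection per entry, and
    # rewrites the last user message with the templated context before the request
    # is forwarded to the Ollama/OpenAI apps.
    payload = {
        "model": "llama2",
        "messages": [
            {"role": "user", "content": "Summarize the attached document."},
        ],
        "docs": [
            {"type": "doc", "collection_name": "1a2b3c"},
            {"type": "collection", "collection_names": ["notes-a", "notes-b"]},
        ],
    }
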