|
@@ -12,6 +12,7 @@ from fastapi import HTTPException
|
|
from fastapi.middleware.wsgi import WSGIMiddleware
|
|
from fastapi.middleware.wsgi import WSGIMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
|
|
+from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
|
|
|
|
|
|
from apps.ollama.main import app as ollama_app
|
|
from apps.ollama.main import app as ollama_app
|
|
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
|
|
from apps.web.main import app as webui_app
|
|
from apps.web.main import app as webui_app
|
|
|
|
|
|
|
|
|
|
|
|
+from apps.rag.utils import query_doc, query_collection, rag_template
|
|
|
|
+
|
|
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
|
|
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
|
|
from constants import ERROR_MESSAGES
|
|
from constants import ERROR_MESSAGES
|
|
|
|
|
|
@@ -56,6 +59,89 @@ async def on_startup():
|
|
await litellm_app_startup()
|
|
await litellm_app_startup()
|
|
|
|
|
|
|
|
|
|
|
|
+class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
|
+ async def dispatch(self, request: Request, call_next):
|
|
|
|
+
|
|
|
|
+ print(request.url.path)
|
|
|
|
+ if request.method == "POST":
|
|
|
|
+ # Read the original request body
|
|
|
|
+ body = await request.body()
|
|
|
|
+ # Decode body to string
|
|
|
|
+ body_str = body.decode("utf-8")
|
|
|
|
+ # Parse string to JSON
|
|
|
|
+ data = json.loads(body_str) if body_str else {}
|
|
|
|
+
|
|
|
|
+ # Example: Add a new key-value pair or modify existing ones
|
|
|
|
+ # data["modified"] = True # Example modification
|
|
|
|
+ if "docs" in data:
|
|
|
|
+ docs = data["docs"]
|
|
|
|
+ print(docs)
|
|
|
|
+
|
|
|
|
+ last_user_message_idx = None
|
|
|
|
+ for i in range(len(data["messages"]) - 1, -1, -1):
|
|
|
|
+ if data["messages"][i]["role"] == "user":
|
|
|
|
+ last_user_message_idx = i
|
|
|
|
+ break
|
|
|
|
+
|
|
|
|
+ query = data["messages"][last_user_message_idx]["content"]
|
|
|
|
+
|
|
|
|
+ relevant_contexts = []
|
|
|
|
+
|
|
|
|
+ for doc in docs:
|
|
|
|
+ context = None
|
|
|
|
+ if doc["type"] == "collection":
|
|
|
|
+ context = query_collection(
|
|
|
|
+ collection_names=doc["collection_names"],
|
|
|
|
+ query=query,
|
|
|
|
+ k=rag_app.state.TOP_K,
|
|
|
|
+ embedding_function=rag_app.state.sentence_transformer_ef,
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ context = query_doc(
|
|
|
|
+ collection_name=doc["collection_name"],
|
|
|
|
+ query=query,
|
|
|
|
+ k=rag_app.state.TOP_K,
|
|
|
|
+ embedding_function=rag_app.state.sentence_transformer_ef,
|
|
|
|
+ )
|
|
|
|
+ relevant_contexts.append(context)
|
|
|
|
+
|
|
|
|
+ context_string = ""
|
|
|
|
+ for context in relevant_contexts:
|
|
|
|
+ if context:
|
|
|
|
+ context_string += " ".join(context["documents"][0]) + "\n"
|
|
|
|
+
|
|
|
|
+ content = rag_template(
|
|
|
|
+ template=rag_app.state.RAG_TEMPLATE,
|
|
|
|
+ context=context_string,
|
|
|
|
+ query=query,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ new_user_message = {
|
|
|
|
+ **data["messages"][last_user_message_idx],
|
|
|
|
+ "content": content,
|
|
|
|
+ }
|
|
|
|
+ data["messages"][last_user_message_idx] = new_user_message
|
|
|
|
+ del data["docs"]
|
|
|
|
+
|
|
|
|
+ print("DATAAAAAAAAAAAAAAAAAA")
|
|
|
|
+ print(data)
|
|
|
|
+ modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
|
|
+
|
|
|
|
+ # Create a new request with the modified body
|
|
|
|
+ scope = request.scope
|
|
|
|
+ scope["body"] = modified_body_bytes
|
|
|
|
+ request = Request(scope, receive=lambda: self._receive(modified_body_bytes))
|
|
|
|
+
|
|
|
|
+ response = await call_next(request)
|
|
|
|
+ return response
|
|
|
|
+
|
|
|
|
+ async def _receive(self, body: bytes):
|
|
|
|
+ return {"type": "http.request", "body": body, "more_body": False}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+app.add_middleware(RAGMiddleware)
|
|
|
|
+
|
|
|
|
+
|
|
@app.middleware("http")
|
|
@app.middleware("http")
|
|
async def check_url(request: Request, call_next):
|
|
async def check_url(request: Request, call_next):
|
|
start_time = int(time.time())
|
|
start_time = int(time.time())
|