|
@@ -12,6 +12,7 @@ from fastapi import HTTPException
|
|
|
from fastapi.middleware.wsgi import WSGIMiddleware
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
|
+from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
|
|
|
|
|
from apps.ollama.main import app as ollama_app
|
|
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
|
|
|
from apps.web.main import app as webui_app
|
|
|
|
|
|
|
|
|
+from apps.rag.utils import query_doc, query_collection, rag_template
|
|
|
+
|
|
|
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
|
|
|
from constants import ERROR_MESSAGES
|
|
|
|
|
@@ -56,6 +59,124 @@ async def on_startup():
|
|
|
await litellm_app_startup()
|
|
|
|
|
|
|
|
|
class RAGMiddleware(BaseHTTPMiddleware):
    """Rewrite chat requests so the last user message carries retrieved context.

    Only POST requests whose path contains ``/api/chat`` or
    ``/chat/completions`` are touched.  When the JSON body has a ``docs`` key,
    each entry is looked up through ``query_collection`` / ``query_doc``, the
    retrieved text is rendered with ``rag_template`` and substituted into the
    last message whose role is ``"user"``; ``docs`` is then removed from the
    payload.  Every other request is passed through untouched.
    """

    async def dispatch(self, request: Request, call_next):
        """Intercept chat POSTs, augment the payload, and forward the request.

        Args:
            request: The incoming request.
            call_next: Callable that forwards a request to the next handler
                and returns its response.

        Returns:
            Whatever response ``call_next`` produces.
        """
        is_chat_post = request.method == "POST" and (
            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
        )
        if not is_chat_post:
            return await call_next(request)

        print(request.url.path)

        # Reading the body consumes it, so from here on the bytes must be
        # replayed to the downstream app whatever happens below.
        body = await request.body()
        outgoing_body = body

        try:
            data = json.loads(body.decode("utf-8")) if body else {}
        except (UnicodeDecodeError, json.JSONDecodeError):
            # Not a JSON payload: forward the original bytes unchanged and let
            # the downstream app decide how to reject it.
            data = None

        if isinstance(data, dict):
            if "docs" in data:
                docs = data.pop("docs")
                print(docs)
                self._augment_last_user_message(data, docs)
            outgoing_body = json.dumps(data).encode("utf-8")

        # Mutate the scope in place so the downstream app observes the changes.
        scope = request.scope
        # NOTE(review): "body" is not a standard ASGI scope key; kept because
        # the original code set it — confirm whether anything reads it.
        scope["body"] = outgoing_body
        # Re-serialization can change the payload size; keep Content-Length
        # truthful for anything downstream that trusts the header.
        scope["headers"] = self._with_content_length(
            scope.get("headers", []), len(outgoing_body)
        )
        # NOTE(review): this receive callable returns the body on every call
        # and never reports a disconnect — confirm that downstream handlers
        # which poll ``receive`` after reading the body tolerate this.
        request = Request(scope, receive=lambda: self._receive(outgoing_body))

        return await call_next(request)

    def _augment_last_user_message(self, data, docs):
        """Replace the last user message's text with the rendered RAG prompt.

        ``data`` is modified in place.  Nothing happens when ``data`` has no
        ``messages`` list or the list holds no message with role ``"user"``.
        """
        messages = data.get("messages")
        if not isinstance(messages, list):
            return

        last_user_idx = next(
            (
                idx
                for idx in range(len(messages) - 1, -1, -1)
                if isinstance(messages[idx], dict)
                and messages[idx].get("role") == "user"
            ),
            None,
        )
        if last_user_idx is None:
            return

        user_message = messages[last_user_idx]
        content = user_message.get("content")
        query = self._extract_query(content)

        doc_entries = docs if isinstance(docs, list) else []
        contexts = [self._retrieve_context(doc, query) for doc in doc_entries]
        # Only the first result row of each successful lookup is used.
        context_string = "".join(
            " ".join(ctx["documents"][0]) + "\n" for ctx in contexts if ctx
        )

        ra_content = rag_template(
            template=rag_app.state.RAG_TEMPLATE,
            context=context_string,
            query=query,
        )

        if isinstance(content, list):
            # Every text part is swapped for the rendered prompt; other parts
            # are kept as they are.
            new_content = [
                {"type": "text", "text": ra_content}
                if isinstance(item, dict) and item.get("type") == "text"
                else item
                for item in content
            ]
        else:
            new_content = ra_content

        messages[last_user_idx] = {**user_message, "content": new_content}

    @staticmethod
    def _extract_query(content):
        """Return the query text held in a message's ``content`` value.

        A string is returned as is; for a list the text of the first item of
        type ``"text"`` is used; anything else yields an empty string.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return next(
                (
                    item.get("text", "")
                    for item in content
                    if isinstance(item, dict) and item.get("type") == "text"
                ),
                "",
            )
        return ""

    @staticmethod
    def _retrieve_context(doc, query):
        """Look up one ``docs`` entry; return ``None`` when the lookup fails."""
        try:
            if doc["type"] == "collection":
                return query_collection(
                    collection_names=doc["collection_names"],
                    query=query,
                    k=rag_app.state.TOP_K,
                    embedding_function=rag_app.state.sentence_transformer_ef,
                )
            return query_doc(
                collection_name=doc["collection_name"],
                query=query,
                k=rag_app.state.TOP_K,
                embedding_function=rag_app.state.sentence_transformer_ef,
            )
        except Exception as e:
            # Best effort: a failed lookup must not fail the chat request.
            print(e)
            return None

    @staticmethod
    def _with_content_length(headers, length):
        """Return ``headers`` with any Content-Length value set to ``length``.

        ``headers`` is an iterable of ``(name, value)`` byte pairs.  A header
        that is not already present is not added.
        """
        encoded_length = str(length).encode("latin-1")
        return [
            (name, encoded_length if name.lower() == b"content-length" else value)
            for name, value in headers
        ]

    async def _receive(self, body: bytes):
        """Return a single, complete ``http.request`` message carrying ``body``."""
        return {"type": "http.request", "body": body, "more_body": False}


# Install the middleware on the main application.
app.add_middleware(RAGMiddleware)
|
|
|
+
|
|
|
+
|
|
|
@app.middleware("http")
|
|
|
async def check_url(request: Request, call_next):
|
|
|
start_time = int(time.time())
|