|
@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
-
|
|
|
+from starlette.responses import StreamingResponse
|
|
|
|
|
|
from apps.ollama.main import app as ollama_app
|
|
|
from apps.openai.main import app as openai_app
|
|
@@ -102,6 +102,8 @@ origins = ["*"]
|
|
|
|
|
|
class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
async def dispatch(self, request: Request, call_next):
|
|
|
+ return_citations = False
|
|
|
+
|
|
|
if request.method == "POST" and (
|
|
|
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
|
|
):
|
|
@@ -114,11 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
# Parse string to JSON
|
|
|
data = json.loads(body_str) if body_str else {}
|
|
|
|
|
|
+ return_citations = data.get("citations", False)
|
|
|
+ if "citations" in data:
|
|
|
+ del data["citations"]
|
|
|
+
|
|
|
# Example: Add a new key-value pair or modify existing ones
|
|
|
# data["modified"] = True # Example modification
|
|
|
if "docs" in data:
|
|
|
data = {**data}
|
|
|
- data["messages"] = rag_messages(
|
|
|
+ data["messages"], citations = rag_messages(
|
|
|
docs=data["docs"],
|
|
|
messages=data["messages"],
|
|
|
template=rag_app.state.RAG_TEMPLATE,
|
|
@@ -130,7 +136,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
)
|
|
|
del data["docs"]
|
|
|
|
|
|
- log.debug(f"data['messages']: {data['messages']}")
|
|
|
+ log.debug(
|
|
|
+ f"data['messages']: {data['messages']}, citations: {citations}"
|
|
|
+ )
|
|
|
|
|
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
|
|
|
@@ -148,11 +156,36 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
]
|
|
|
|
|
|
response = await call_next(request)
|
|
|
+
|
|
|
+ if return_citations:
|
|
|
+ # Inject the citations into the response
|
|
|
+ if isinstance(response, StreamingResponse):
|
|
|
+ # If it's a streaming response, inject it as SSE event or NDJSON line
|
|
|
+ content_type = response.headers.get("Content-Type")
|
|
|
+ if "text/event-stream" in content_type:
|
|
|
+ return StreamingResponse(
|
|
|
+ self.openai_stream_wrapper(response.body_iterator, citations),
|
|
|
+ )
|
|
|
+ if "application/x-ndjson" in content_type:
|
|
|
+ return StreamingResponse(
|
|
|
+ self.ollama_stream_wrapper(response.body_iterator, citations),
|
|
|
+ )
|
|
|
+
|
|
|
return response
|
|
|
|
|
|
async def _receive(self, body: bytes):
|
|
|
return {"type": "http.request", "body": body, "more_body": False}
|
|
|
|
|
|
+ async def openai_stream_wrapper(self, original_generator, citations):
|
|
|
+ yield f"data: {json.dumps({'citations': citations})}\n\n"
|
|
|
+ async for data in original_generator:
|
|
|
+ yield data
|
|
|
+
|
|
|
+ async def ollama_stream_wrapper(self, original_generator, citations):
|
|
|
+ yield f"{json.dumps({'citations': citations})}\n"
|
|
|
+ async for data in original_generator:
|
|
|
+ yield data
|
|
|
+
|
|
|
|
|
|
app.add_middleware(RAGMiddleware)
|
|
|
|