|
@@ -85,7 +85,24 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
last_user_message_idx = i
|
|
|
break
|
|
|
|
|
|
- query = data["messages"][last_user_message_idx]["content"]
|
|
|
+ user_message = data["messages"][last_user_message_idx]
|
|
|
+
|
|
|
+ if isinstance(user_message["content"], list):
|
|
|
+ # Handle list content input
|
|
|
+ content_type = "list"
|
|
|
+ query = ""
|
|
|
+ for content_item in user_message["content"]:
|
|
|
+ if content_item["type"] == "text":
|
|
|
+ query = content_item["text"]
|
|
|
+ break
|
|
|
+ elif isinstance(user_message["content"], str):
|
|
|
+ # Handle text content input
|
|
|
+ content_type = "text"
|
|
|
+ query = user_message["content"]
|
|
|
+ else:
|
|
|
+ # Fallback in case the input does not match expected types
|
|
|
+ content_type = None
|
|
|
+ query = ""
|
|
|
|
|
|
relevant_contexts = []
|
|
|
|
|
@@ -112,16 +129,28 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
if context:
|
|
|
context_string += " ".join(context["documents"][0]) + "\n"
|
|
|
|
|
|
- content = rag_template(
|
|
|
+ ra_content = rag_template(
|
|
|
template=rag_app.state.RAG_TEMPLATE,
|
|
|
context=context_string,
|
|
|
query=query,
|
|
|
)
|
|
|
|
|
|
- new_user_message = {
|
|
|
- **data["messages"][last_user_message_idx],
|
|
|
- "content": content,
|
|
|
- }
|
|
|
+ if content_type == "list":
|
|
|
+ new_content = []
|
|
|
+ for content_item in user_message["content"]:
|
|
|
+ if content_item["type"] == "text":
|
|
|
+ # Update the text item's content with ra_content
|
|
|
+ new_content.append({"type": "text", "text": ra_content})
|
|
|
+ else:
|
|
|
+ # Keep other types of content as they are
|
|
|
+ new_content.append(content_item)
|
|
|
+ new_user_message = {**user_message, "content": new_content}
|
|
|
+ else:
|
|
|
+ new_user_message = {
|
|
|
+ **user_message,
|
|
|
+ "content": ra_content,
|
|
|
+ }
|
|
|
+
|
|
|
data["messages"][last_user_message_idx] = new_user_message
|
|
|
del data["docs"]
|
|
|
|