|
@@ -1,5 +1,6 @@
|
|
|
import logging
|
|
|
import os
|
|
|
+import uuid
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
import requests
|
|
@@ -197,8 +198,15 @@ def rag_template(template: str, context: str, query: str):
|
|
|
f"RAG template contains an unexpected number of '[context]' : {count}"
|
|
|
)
|
|
|
assert "[context]" in template, "RAG template does not contain '[context]'"
|
|
|
- template = template.replace("[context]", context)
|
|
|
- template = template.replace("[query]", query)
|
|
|
+
|
|
|
+ if "[query]" in context:
|
|
|
+ query_placeholder = str(uuid.uuid4())
|
|
|
+ template = template.replace("[QUERY]", query_placeholder)
|
|
|
+ template = template.replace("[context]", context)
|
|
|
+ template = template.replace(query_placeholder, query)
|
|
|
+ else:
|
|
|
+ template = template.replace("[context]", context)
|
|
|
+ template = template.replace("[query]", query)
|
|
|
return template
|
|
|
|
|
|
|