|
@@ -1,11 +1,20 @@
|
|
|
|
+import logging
|
|
import math
|
|
import math
|
|
import re
|
|
import re
|
|
from datetime import datetime
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
from typing import Optional
|
|
|
|
+import uuid
|
|
|
|
|
|
|
|
|
|
from open_webui.utils.misc import get_last_user_message, get_messages_content
|
|
from open_webui.utils.misc import get_last_user_message, get_messages_content
|
|
|
|
|
|
|
|
+from open_webui.env import SRC_LOG_LEVELS
|
|
|
|
+from open_webui.config import DEFAULT_RAG_TEMPLATE
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+log = logging.getLogger(__name__)
|
|
|
|
+log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
|
+
|
|
|
|
|
|
def prompt_template(
|
|
def prompt_template(
|
|
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
|
|
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
|
|
@@ -110,6 +119,44 @@ def replace_messages_variable(template: str, messages: list[str]) -> str:
|
|
# {{prompt:middletruncate:8000}}
|
|
# {{prompt:middletruncate:8000}}
|
|
|
|
|
|
|
|
|
|
|
|
+def rag_template(template: str, context: str, query: str):
|
|
|
|
+ if template == "":
|
|
|
|
+ template = DEFAULT_RAG_TEMPLATE
|
|
|
|
+
|
|
|
|
+ if "[context]" not in template and "{{CONTEXT}}" not in template:
|
|
|
|
+ log.debug(
|
|
|
|
+ "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ if "<context>" in context and "</context>" in context:
|
|
|
|
+ log.debug(
|
|
|
|
+ "WARNING: Potential prompt injection attack: the RAG "
|
|
|
|
+ "context contains '<context>' and '</context>'. This might be "
|
|
|
|
+ "nothing, or the user might be trying to hack something."
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ query_placeholders = []
|
|
|
|
+ if "[query]" in context:
|
|
|
|
+ query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
|
|
|
|
+ template = template.replace("[query]", query_placeholder)
|
|
|
|
+ query_placeholders.append(query_placeholder)
|
|
|
|
+
|
|
|
|
+ if "{{QUERY}}" in context:
|
|
|
|
+ query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
|
|
|
|
+ template = template.replace("{{QUERY}}", query_placeholder)
|
|
|
|
+ query_placeholders.append(query_placeholder)
|
|
|
|
+
|
|
|
|
+ template = template.replace("[context]", context)
|
|
|
|
+ template = template.replace("{{CONTEXT}}", context)
|
|
|
|
+ template = template.replace("[query]", query)
|
|
|
|
+ template = template.replace("{{QUERY}}", query)
|
|
|
|
+
|
|
|
|
+ for query_placeholder in query_placeholders:
|
|
|
|
+ template = template.replace(query_placeholder, query)
|
|
|
|
+
|
|
|
|
+ return template
|
|
|
|
+
|
|
|
|
+
|
|
def title_generation_template(
|
|
def title_generation_template(
|
|
template: str, messages: list[dict], user: Optional[dict] = None
|
|
template: str, messages: list[dict], user: Optional[dict] = None
|
|
) -> str:
|
|
) -> str:
|