|
@@ -64,7 +64,7 @@ from utils.task import (
|
|
)
|
|
)
|
|
from utils.misc import get_last_user_message, add_or_update_system_message
|
|
from utils.misc import get_last_user_message, add_or_update_system_message
|
|
|
|
|
|
-from apps.rag.utils import rag_messages, rag_template
|
|
|
|
|
|
+from apps.rag.utils import get_rag_context, rag_template
|
|
|
|
|
|
from config import (
|
|
from config import (
|
|
CONFIG_DATA,
|
|
CONFIG_DATA,
|
|
@@ -248,6 +248,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
# Parse string to JSON
|
|
# Parse string to JSON
|
|
data = json.loads(body_str) if body_str else {}
|
|
data = json.loads(body_str) if body_str else {}
|
|
|
|
|
|
|
|
+ user = get_current_user(
|
|
|
|
+ get_http_authorization_cred(request.headers.get("Authorization"))
|
|
|
|
+ )
|
|
|
|
+
|
|
# Remove the citations from the body
|
|
# Remove the citations from the body
|
|
return_citations = data.get("citations", False)
|
|
return_citations = data.get("citations", False)
|
|
if "citations" in data:
|
|
if "citations" in data:
|
|
@@ -276,13 +280,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
):
|
|
):
|
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
|
|
|
|
+ context = ""
|
|
|
|
+
|
|
|
|
+ # If tool_ids field is present, call the functions
|
|
if "tool_ids" in data:
|
|
if "tool_ids" in data:
|
|
- user = get_current_user(
|
|
|
|
- get_http_authorization_cred(request.headers.get("Authorization"))
|
|
|
|
- )
|
|
|
|
prompt = get_last_user_message(data["messages"])
|
|
prompt = get_last_user_message(data["messages"])
|
|
- context = ""
|
|
|
|
-
|
|
|
|
for tool_id in data["tool_ids"]:
|
|
for tool_id in data["tool_ids"]:
|
|
print(tool_id)
|
|
print(tool_id)
|
|
response = await get_function_call_response(
|
|
response = await get_function_call_response(
|
|
@@ -295,37 +297,37 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
|
if response:
|
|
if response:
|
|
context += ("\n" if context != "" else "") + response
|
|
context += ("\n" if context != "" else "") + response
|
|
-
|
|
|
|
- if context != "":
|
|
|
|
- system_prompt = rag_template(
|
|
|
|
- rag_app.state.config.RAG_TEMPLATE, context, prompt
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- print(system_prompt)
|
|
|
|
-
|
|
|
|
- data["messages"] = add_or_update_system_message(
|
|
|
|
- f"\n{system_prompt}", data["messages"]
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
del data["tool_ids"]
|
|
del data["tool_ids"]
|
|
|
|
|
|
# If docs field is present, generate RAG completions
|
|
# If docs field is present, generate RAG completions
|
|
if "docs" in data:
|
|
if "docs" in data:
|
|
data = {**data}
|
|
data = {**data}
|
|
- data["messages"], citations = rag_messages(
|
|
|
|
|
|
+ rag_context, citations = get_rag_context(
|
|
docs=data["docs"],
|
|
docs=data["docs"],
|
|
messages=data["messages"],
|
|
messages=data["messages"],
|
|
- template=rag_app.state.config.RAG_TEMPLATE,
|
|
|
|
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
|
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
|
k=rag_app.state.config.TOP_K,
|
|
k=rag_app.state.config.TOP_K,
|
|
reranking_function=rag_app.state.sentence_transformer_rf,
|
|
reranking_function=rag_app.state.sentence_transformer_rf,
|
|
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
|
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
|
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
|
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
|
)
|
|
)
|
|
|
|
+
|
|
|
|
+ if rag_context:
|
|
|
|
+ context += ("\n" if context != "" else "") + rag_context
|
|
|
|
+
|
|
del data["docs"]
|
|
del data["docs"]
|
|
|
|
|
|
- log.debug(
|
|
|
|
- f"data['messages']: {data['messages']}, citations: {citations}"
|
|
|
|
|
|
+ log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
|
|
|
+
|
|
|
|
+ if context != "":
|
|
|
|
+ system_prompt = rag_template(
|
|
|
|
+ rag_app.state.config.RAG_TEMPLATE, context, prompt
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ print(system_prompt)
|
|
|
|
+
|
|
|
|
+ data["messages"] = add_or_update_system_message(
|
|
|
|
+ f"\n{system_prompt}", data["messages"]
|
|
)
|
|
)
|
|
|
|
|
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|