Browse Source

refac: tools & rag

Timothy J. Baek 10 months ago
parent
commit
bd5a8567ef
2 changed files with 26 additions and 34 deletions
  1. 2 12
      backend/apps/rag/utils.py
  2. 24 22
      backend/main.py

+ 2 - 12
backend/apps/rag/utils.py

@@ -236,10 +236,9 @@ def get_embedding_function(
         return lambda query: generate_multiple(query, func)
 
 
-def rag_messages(
+def get_rag_context(
     docs,
     messages,
-    template,
     embedding_function,
     k,
     reranking_function,
@@ -318,16 +317,7 @@ def rag_messages(
 
     context_string = context_string.strip()
 
-    ra_content = rag_template(
-        template=template,
-        context=context_string,
-        query=query,
-    )
-
-    log.debug(f"ra_content: {ra_content}")
-    messages = add_or_update_system_message(ra_content, messages)
-
-    return messages, citations
+    return context_string, citations
 
 
 def get_model_path(model: str, update_model: bool = False):

+ 24 - 22
backend/main.py

@@ -64,7 +64,7 @@ from utils.task import (
 )
 from utils.misc import get_last_user_message, add_or_update_system_message
 
-from apps.rag.utils import rag_messages, rag_template
+from apps.rag.utils import get_rag_context, rag_template
 
 from config import (
     CONFIG_DATA,
@@ -248,6 +248,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
 
+            user = get_current_user(
+                get_http_authorization_cred(request.headers.get("Authorization"))
+            )
+
             # Remove the citations from the body
             return_citations = data.get("citations", False)
             if "citations" in data:
@@ -276,13 +280,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 ):
                     task_model_id = app.state.config.TASK_MODEL_EXTERNAL
 
+            context = ""
+
+            # If tool_ids field is present, call the functions
             if "tool_ids" in data:
-                user = get_current_user(
-                    get_http_authorization_cred(request.headers.get("Authorization"))
-                )
                 prompt = get_last_user_message(data["messages"])
-                context = ""
-
                 for tool_id in data["tool_ids"]:
                     print(tool_id)
                     response = await get_function_call_response(
@@ -295,37 +297,37 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
                     if response:
                         context += ("\n" if context != "" else "") + response
-
-                if context != "":
-                    system_prompt = rag_template(
-                        rag_app.state.config.RAG_TEMPLATE, context, prompt
-                    )
-
-                    print(system_prompt)
-
-                    data["messages"] = add_or_update_system_message(
-                        f"\n{system_prompt}", data["messages"]
-                    )
-
                 del data["tool_ids"]
 
             # If docs field is present, generate RAG completions
             if "docs" in data:
                 data = {**data}
-                data["messages"], citations = rag_messages(
+                rag_context, citations = get_rag_context(
                     docs=data["docs"],
                     messages=data["messages"],
-                    template=rag_app.state.config.RAG_TEMPLATE,
                     embedding_function=rag_app.state.EMBEDDING_FUNCTION,
                     k=rag_app.state.config.TOP_K,
                     reranking_function=rag_app.state.sentence_transformer_rf,
                     r=rag_app.state.config.RELEVANCE_THRESHOLD,
                     hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                 )
+
+                if rag_context:
+                    context += ("\n" if context != "" else "") + rag_context
+
                 del data["docs"]
 
-                log.debug(
-                    f"data['messages']: {data['messages']}, citations: {citations}"
+                log.debug(f"rag_context: {rag_context}, citations: {citations}")
+
+            if context != "":
+                system_prompt = rag_template(
+                    rag_app.state.config.RAG_TEMPLATE, context, prompt
+                )
+
+                print(system_prompt)
+
+                data["messages"] = add_or_update_system_message(
+                    f"\n{system_prompt}", data["messages"]
                 )
 
             modified_body_bytes = json.dumps(data).encode("utf-8")