浏览代码

refac: naming convention

Timothy J. Baek 1 年之前
父节点
当前提交
cebf733b9d
共有 4 个文件被更改，包括 28 次插入和 16 次删除
  1. 16 7
      backend/apps/rag/main.py
  2. 7 7
      backend/apps/rag/utils.py
  3. 4 1
      backend/config.py
  4. 1 1
      backend/main.py

+ 16 - 7
backend/apps/rag/main.py

@@ -70,7 +70,7 @@ from config import (
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    RAG_HYBRID,
+    ENABLE_RAG_HYBRID_SEARCH,
     RAG_RERANKING_MODEL,
     RAG_RERANKING_MODEL,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
@@ -92,7 +92,8 @@ app = FastAPI()
 
 
 app.state.TOP_K = RAG_TOP_K
 app.state.TOP_K = RAG_TOP_K
 app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
-app.state.HYBRID = RAG_HYBRID
+
+app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
 
 
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
@@ -324,7 +325,7 @@ async def get_query_settings(user=Depends(get_admin_user)):
         "template": app.state.RAG_TEMPLATE,
         "template": app.state.RAG_TEMPLATE,
         "k": app.state.TOP_K,
         "k": app.state.TOP_K,
         "r": app.state.RELEVANCE_THRESHOLD,
         "r": app.state.RELEVANCE_THRESHOLD,
-        "hybrid": app.state.HYBRID,
+        "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
     }
     }
 
 
 
 
@@ -342,13 +343,13 @@ async def update_query_settings(
     app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
     app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
     app.state.TOP_K = form_data.k if form_data.k else 4
     app.state.TOP_K = form_data.k if form_data.k else 4
     app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
     app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
-    app.state.HYBRID = form_data.hybrid if form_data.hybrid else False
+    app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
     return {
     return {
         "status": True,
         "status": True,
         "template": app.state.RAG_TEMPLATE,
         "template": app.state.RAG_TEMPLATE,
         "k": app.state.TOP_K,
         "k": app.state.TOP_K,
         "r": app.state.RELEVANCE_THRESHOLD,
         "r": app.state.RELEVANCE_THRESHOLD,
-        "hybrid": app.state.HYBRID,
+        "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
     }
     }
 
 
 
 
@@ -381,7 +382,11 @@ def query_doc_handler(
             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             embeddings_function=embeddings_function,
             embeddings_function=embeddings_function,
             reranking_function=app.state.sentence_transformer_rf,
             reranking_function=app.state.sentence_transformer_rf,
-            hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID,
+            hybrid_search=(
+                form_data.hybrid
+                if form_data.hybrid
+                else app.state.ENABLE_RAG_HYBRID_SEARCH
+            ),
         )
         )
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
@@ -420,7 +425,11 @@ def query_collection_handler(
             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             embeddings_function=embeddings_function,
             embeddings_function=embeddings_function,
             reranking_function=app.state.sentence_transformer_rf,
             reranking_function=app.state.sentence_transformer_rf,
-            hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID,
+            hybrid_search=(
+                form_data.hybrid
+                if form_data.hybrid
+                else app.state.ENABLE_RAG_HYBRID_SEARCH
+            ),
         )
         )
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)

+ 7 - 7
backend/apps/rag/utils.py

@@ -33,12 +33,12 @@ def query_embeddings_doc(
     reranking_function,
     reranking_function,
     k: int,
     k: int,
     r: int,
     r: int,
-    hybrid: bool,
+    hybrid_search: bool,
 ):
 ):
     try:
     try:
         collection = CHROMA_CLIENT.get_collection(name=collection_name)
         collection = CHROMA_CLIENT.get_collection(name=collection_name)
 
 
-        if hybrid:
+        if hybrid_search:
             documents = collection.get()  # get all documents
             documents = collection.get()  # get all documents
             bm25_retriever = BM25Retriever.from_texts(
             bm25_retriever = BM25Retriever.from_texts(
                 texts=documents.get("documents"),
                 texts=documents.get("documents"),
@@ -134,7 +134,7 @@ def query_embeddings_collection(
     r: float,
     r: float,
     embeddings_function,
     embeddings_function,
     reranking_function,
     reranking_function,
-    hybrid: bool,
+    hybrid_search: bool,
 ):
 ):
 
 
     results = []
     results = []
@@ -148,7 +148,7 @@ def query_embeddings_collection(
                 r=r,
                 r=r,
                 embeddings_function=embeddings_function,
                 embeddings_function=embeddings_function,
                 reranking_function=reranking_function,
                 reranking_function=reranking_function,
-                hybrid=hybrid,
+                hybrid_search=hybrid_search,
             )
             )
             results.append(result)
             results.append(result)
         except:
         except:
@@ -206,7 +206,7 @@ def rag_messages(
     template,
     template,
     k,
     k,
     r,
     r,
-    hybrid,
+    hybrid_search,
     embedding_engine,
     embedding_engine,
     embedding_model,
     embedding_model,
     embedding_function,
     embedding_function,
@@ -279,7 +279,7 @@ def rag_messages(
                     r=r,
                     r=r,
                     embeddings_function=embeddings_function,
                     embeddings_function=embeddings_function,
                     reranking_function=reranking_function,
                     reranking_function=reranking_function,
-                    hybrid=hybrid,
+                    hybrid_search=hybrid_search,
                 )
                 )
             else:
             else:
                 context = query_embeddings_doc(
                 context = query_embeddings_doc(
@@ -289,7 +289,7 @@ def rag_messages(
                     r=r,
                     r=r,
                     embeddings_function=embeddings_function,
                     embeddings_function=embeddings_function,
                     reranking_function=reranking_function,
                     reranking_function=reranking_function,
-                    hybrid=hybrid,
+                    hybrid_search=hybrid_search,
                 )
                 )
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)

+ 4 - 1
backend/config.py

@@ -422,7 +422,10 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 
 
 RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
 RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
 RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
 RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
-RAG_HYBRID = os.environ.get("RAG_HYBRID", "").lower() == "true"
+
+ENABLE_RAG_HYBRID_SEARCH = (
+    os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
+)
 
 
 RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
 RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
 
 

+ 1 - 1
backend/main.py

@@ -121,7 +121,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
                     rag_app.state.RAG_TEMPLATE,
                     rag_app.state.RAG_TEMPLATE,
                     rag_app.state.TOP_K,
                     rag_app.state.TOP_K,
                     rag_app.state.RELEVANCE_THRESHOLD,
                     rag_app.state.RELEVANCE_THRESHOLD,
-                    rag_app.state.HYBRID,
+                    rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
                     rag_app.state.RAG_EMBEDDING_ENGINE,
                     rag_app.state.RAG_EMBEDDING_ENGINE,
                     rag_app.state.RAG_EMBEDDING_MODEL,
                     rag_app.state.RAG_EMBEDDING_MODEL,
                     rag_app.state.sentence_transformer_ef,
                     rag_app.state.sentence_transformer_ef,