浏览代码

feat: rag backend auth

Timothy J. Baek 1 年之前
父节点
当前提交
70d2571be1
共有 1 个文件被更改,包括 43 次插入20 次删除
  1. 43 20
      backend/apps/rag/main.py

+ 43 - 20
backend/apps/rag/main.py

@@ -24,6 +24,8 @@ from typing import Optional
 
 import uuid
 
+
+from utils.utils import get_current_user
 from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
 from constants import ERROR_MESSAGES
 
@@ -84,7 +86,12 @@ async def get_status():
 
 
 @app.get("/query/{collection_name}")
-def query_collection(collection_name: str, query: str, k: Optional[int] = 4):
+def query_collection(
+    collection_name: str,
+    query: str,
+    k: Optional[int] = 4,
+    user=Depends(get_current_user),
+):
     try:
         collection = CHROMA_CLIENT.get_collection(
             name=collection_name,
@@ -101,7 +108,7 @@ def query_collection(collection_name: str, query: str, k: Optional[int] = 4):
 
 
 @app.post("/web")
-def store_web(form_data: StoreWebForm):
+def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
     try:
         loader = WebBaseLoader(form_data.url)
@@ -117,7 +124,11 @@ def store_web(form_data: StoreWebForm):
 
 
 @app.post("/doc")
-def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
+def store_doc(
+    collection_name: str = Form(...),
+    file: UploadFile = File(...),
+    user=Depends(get_current_user),
+):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
     file.filename = f"{collection_name}-{file.filename}"
 
@@ -159,26 +170,38 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
 
 
 @app.get("/reset/db")
-def reset_vector_db():
-    CHROMA_CLIENT.reset()
+def reset_vector_db(user=Depends(get_current_user)):
+    if user.role == "admin":
+        CHROMA_CLIENT.reset()
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
 
 
 @app.get("/reset")
-def reset():
-    folder = f"{UPLOAD_DIR}"
-    for filename in os.listdir(folder):
-        file_path = os.path.join(folder, filename)
+def reset(user=Depends(get_current_user)):
+    if user.role == "admin":
+        folder = f"{UPLOAD_DIR}"
+        for filename in os.listdir(folder):
+            file_path = os.path.join(folder, filename)
+            try:
+                if os.path.isfile(file_path) or os.path.islink(file_path):
+                    os.unlink(file_path)
+                elif os.path.isdir(file_path):
+                    shutil.rmtree(file_path)
+            except Exception as e:
+                print("Failed to delete %s. Reason: %s" % (file_path, e))
+
         try:
-            if os.path.isfile(file_path) or os.path.islink(file_path):
-                os.unlink(file_path)
-            elif os.path.isdir(file_path):
-                shutil.rmtree(file_path)
+            CHROMA_CLIENT.reset()
         except Exception as e:
-            print("Failed to delete %s. Reason: %s" % (file_path, e))
+            print(e)
 
-    try:
-        CHROMA_CLIENT.reset()
-    except Exception as e:
-        print(e)
-
-    return {"status": True}
+        return {"status": True}
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )