Browse Source

refac: knowledge access control

Timothy Jaeryang Baek 5 months ago
parent
commit
c35f8c9673
1 changed files with 105 additions and 20 deletions
  1. 105 20
      backend/open_webui/apps/webui/routers/knowledge.py

+ 105 - 20
backend/open_webui/apps/webui/routers/knowledge.py

@@ -17,6 +17,9 @@ from open_webui.apps.retrieval.main import process_file, ProcessFileForm
 
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.access_control import has_access
+
+
 from open_webui.env import SRC_LOG_LEVELS
 
 
@@ -154,13 +157,20 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
     knowledge = Knowledges.get_knowledge_by_id(id=id)
 
     if knowledge:
-        file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
-        files = Files.get_files_by_ids(file_ids)
 
-        return KnowledgeFilesResponse(
-            **knowledge.model_dump(),
-            files=files,
-        )
+        if (
+            user.role == "admin"
+            or knowledge.user_id == user.id
+            or has_access(user.id, "read", knowledge.access_control)
+        ):
+
+            file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
+            files = Files.get_files_by_ids(file_ids)
+
+            return KnowledgeFilesResponse(
+                **knowledge.model_dump(),
+                files=files,
+            )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -179,8 +189,20 @@ async def update_knowledge_by_id(
     form_data: KnowledgeUpdateForm,
     user=Depends(get_verified_user),
 ):
-    knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
+    knowledge = Knowledges.get_knowledge_by_id(id=id)
+    if not knowledge:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
 
+    knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
     if knowledge:
         file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
         files = Files.get_files_by_ids(file_ids)
@@ -212,6 +234,19 @@ def add_file_to_knowledge_by_id(
     user=Depends(get_verified_user),
 ):
     knowledge = Knowledges.get_knowledge_by_id(id=id)
+
+    if not knowledge:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
     file = Files.get_file_by_id(form_data.file_id)
     if not file:
         raise HTTPException(
@@ -277,6 +312,18 @@ def update_file_from_knowledge_by_id(
     user=Depends(get_verified_user),
 ):
     knowledge = Knowledges.get_knowledge_by_id(id=id)
+    if not knowledge:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
     file = Files.get_file_by_id(form_data.file_id)
     if not file:
         raise HTTPException(
@@ -327,6 +374,18 @@ def remove_file_from_knowledge_by_id(
     user=Depends(get_verified_user),
 ):
     knowledge = Knowledges.get_knowledge_by_id(id=id)
+    if not knowledge:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
     file = Files.get_file_by_id(form_data.file_id)
     if not file:
         raise HTTPException(
@@ -383,35 +442,61 @@ def remove_file_from_knowledge_by_id(
 
 
 ############################
-# ResetKnowledgeById
+# DeleteKnowledgeById
 ############################
 
 
-@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
-async def reset_knowledge_by_id(id: str, user=Depends(get_admin_user)):
+@router.delete("/{id}/delete", response_model=bool)
+async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
+    knowledge = Knowledges.get_knowledge_by_id(id=id)
+    if not knowledge:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
     try:
         VECTOR_DB_CLIENT.delete_collection(collection_name=id)
     except Exception as e:
         log.debug(e)
         pass
-
-    knowledge = Knowledges.update_knowledge_by_id(
-        id=id, form_data=KnowledgeUpdateForm(data={"file_ids": []})
-    )
-    return knowledge
+    result = Knowledges.delete_knowledge_by_id(id=id)
+    return result
 
 
 ############################
-# DeleteKnowledgeById
+# ResetKnowledgeById
 ############################
 
 
-@router.delete("/{id}/delete", response_model=bool)
-async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
+@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
+async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
+    knowledge = Knowledges.get_knowledge_by_id(id=id)
+    if not knowledge:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
     try:
         VECTOR_DB_CLIENT.delete_collection(collection_name=id)
     except Exception as e:
         log.debug(e)
         pass
-    result = Knowledges.delete_knowledge_by_id(id=id)
-    return result
+
+    knowledge = Knowledges.update_knowledge_by_id(
+        id=id, form_data=KnowledgeUpdateForm(data={"file_ids": []})
+    )
+    return knowledge