Browse Source

enh: add/remove file from knowledge

Timothy J. Baek 7 tháng trước cách đây
mục cha
commit
78413d0c2e

+ 2 - 20
backend/open_webui/apps/webui/models/knowledge.py

@@ -126,28 +126,10 @@ class KnowledgeTable:
     ) -> Optional[KnowledgeModel]:
     ) -> Optional[KnowledgeModel]:
         try:
         try:
             with get_db() as db:
             with get_db() as db:
+                knowledge = self.get_knowledge_by_id(id=id)
                 db.query(Knowledge).filter_by(id=id).update(
                 db.query(Knowledge).filter_by(id=id).update(
                     {
                     {
-                        **({"name": form_data.name} if form_data.name else {}),
-                        **(
-                            {"description": form_data.description}
-                            if form_data.description
-                            else {}
-                        ),
-                        **(
-                            {
-                                "data": (
-                                    form_data.data
-                                    if overwrite
-                                    else {
-                                        **(self.get_knowledge_by_id(id=id)).data,
-                                        **form_data.data,
-                                    }
-                                )
-                            }
-                            if form_data.data
-                            else {}
-                        ),
+                        **form_data.model_dump(exclude_none=True),
                         "updated_at": int(time.time()),
                         "updated_at": int(time.time()),
                     }
                     }
                 )
                 )

+ 133 - 2
backend/open_webui/apps/webui/routers/knowledge.py

@@ -15,6 +15,9 @@ from open_webui.apps.webui.models.files import Files, FileModel
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.utils.utils import get_admin_user, get_verified_user
 from open_webui.utils.utils import get_admin_user, get_verified_user
 
 
+from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
+
+
 router = APIRouter()
 router = APIRouter()
 
 
 ############################
 ############################
@@ -96,7 +99,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
 ############################
 ############################
 
 
 
 
-@router.post("/{id}/update", response_model=Optional[KnowledgeResponse])
+@router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse])
 async def update_knowledge_by_id(
 async def update_knowledge_by_id(
     id: str,
     id: str,
     form_data: KnowledgeUpdateForm,
     form_data: KnowledgeUpdateForm,
@@ -105,7 +108,13 @@ async def update_knowledge_by_id(
     knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
     knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
 
 
     if knowledge:
     if knowledge:
-        return knowledge
+        file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
+        files = Files.get_files_by_ids(file_ids)
+
+        return KnowledgeFilesResponse(
+            **knowledge.model_dump(),
+            files=files,
+        )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -113,6 +122,128 @@ async def update_knowledge_by_id(
         )
         )
 
 
 
 
+############################
+# AddFileToKnowledge
+############################
+
+
+class KnowledgeFileIdForm(BaseModel):
+    file_id: str
+
+
+@router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse])
+async def add_file_to_knowledge_by_id(
+    id: str,
+    form_data: KnowledgeFileIdForm,
+    user=Depends(get_admin_user),
+):
+    knowledge = Knowledges.get_knowledge_by_id(id=id)
+    file = Files.get_file_by_id(form_data.file_id)
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge:
+        data = knowledge.data or {}
+        file_ids = data.get("file_ids", [])
+
+        if form_data.file_id not in file_ids:
+            file_ids.append(form_data.file_id)
+            data["file_ids"] = file_ids
+
+            knowledge = Knowledges.update_knowledge_by_id(
+                id=id, form_data=KnowledgeUpdateForm(data=data)
+            )
+
+            if knowledge:
+                files = Files.get_files_by_ids(file_ids)
+
+                return KnowledgeFilesResponse(
+                    **knowledge.model_dump(),
+                    files=files,
+                )
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT("knowledge"),
+                )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("file_id"),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# RemoveFileFromKnowledge
+############################
+
+
+class KnowledgeFileIdForm(BaseModel):
+    file_id: str
+
+
+@router.post("/{id}/file/remove", response_model=Optional[KnowledgeFilesResponse])
+async def remove_file_from_knowledge_by_id(
+    id: str,
+    form_data: KnowledgeFileIdForm,
+    user=Depends(get_admin_user),
+):
+    knowledge = Knowledges.get_knowledge_by_id(id=id)
+    file = Files.get_file_by_id(form_data.file_id)
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    VECTOR_DB_CLIENT.delete(
+        collection_name=knowledge.id, filter={"file_id": form_data.file_id}
+    )
+
+    if knowledge:
+        data = knowledge.data or {}
+        file_ids = data.get("file_ids", [])
+
+        if form_data.file_id in file_ids:
+            file_ids.remove(form_data.file_id)
+            data["file_ids"] = file_ids
+
+            knowledge = Knowledges.update_knowledge_by_id(
+                id=id, form_data=KnowledgeUpdateForm(data=data)
+            )
+
+            if knowledge:
+                files = Files.get_files_by_ids(file_ids)
+
+                return KnowledgeFilesResponse(
+                    **knowledge.model_dump(),
+                    files=files,
+                )
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT("knowledge"),
+                )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("file_id"),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
 ############################
 ############################
 # DeleteKnowledgeById
 # DeleteKnowledgeById
 ############################
 ############################