Timothy J. Baek 7 月之前
父節點
當前提交
2428878f42
共有 1 個文件被更改,包括 16 次插入80 次删除
  1. 16 80
      backend/open_webui/apps/retrieval/main.py

+ 16 - 80
backend/open_webui/apps/retrieval/main.py

@@ -246,10 +246,10 @@ app.add_middleware(
 
 
 class CollectionNameForm(BaseModel):
-    collection_name: Optional[str] = "test"
+    collection_name: Optional[str] = None
 
 
-class UrlForm(CollectionNameForm):
+class ProcessUrlForm(CollectionNameForm):
     url: str
 
 
@@ -636,7 +636,6 @@ def store_data_in_vector_db(
         chunk_overlap=app.state.config.CHUNK_OVERLAP,
         add_start_index=True,
     )
-
     docs = text_splitter.split_documents(data)
 
     if len(docs) > 0:
@@ -715,66 +714,6 @@ def store_docs_in_vector_db(
         return False
 
 
-@app.post("/doc")
-def store_doc(
-    collection_name: Optional[str] = Form(None),
-    file: UploadFile = File(...),
-    user=Depends(get_verified_user),
-):
-    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
-
-    log.info(f"file.content_type: {file.content_type}")
-    try:
-        unsanitized_filename = file.filename
-        filename = os.path.basename(unsanitized_filename)
-
-        file_path = f"{UPLOAD_DIR}/{filename}"
-
-        contents = file.file.read()
-        with open(file_path, "wb") as f:
-            f.write(contents)
-            f.close()
-
-        f = open(file_path, "rb")
-        if collection_name is None:
-            collection_name = calculate_sha256(f)[:63]
-        f.close()
-
-        loader = Loader(
-            engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
-            TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
-            PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
-        )
-        data = loader.load(filename, file.content_type, file_path)
-
-        try:
-            result = store_data_in_vector_db(data, collection_name)
-
-            if result:
-                return {
-                    "status": True,
-                    "collection_name": collection_name,
-                    "filename": filename,
-                }
-        except Exception as e:
-            raise HTTPException(
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                detail=e,
-            )
-    except Exception as e:
-        log.exception(e)
-        if "No pandoc was found" in str(e):
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
-            )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
-            )
-
-
 class ProcessFileForm(BaseModel):
     file_id: str
     collection_name: Optional[str] = None
@@ -796,11 +735,10 @@ def process_file(
         )
         data = loader.load(file.filename, file.meta.get("content_type"), file_path)
 
-        f = open(file_path, "rb")
         collection_name = form_data.collection_name
         if collection_name is None:
-            collection_name = calculate_sha256(f)[:63]
-        f.close()
+            with open(file_path, "rb") as f:
+                collection_name = calculate_sha256(f)[:63]
 
         try:
             result = store_data_in_vector_db(
@@ -813,11 +751,9 @@ def process_file(
             )
 
             if result:
-
                 return {
                     "status": True,
                     "collection_name": collection_name,
-                    "known_type": known_type,
                     "filename": file.meta.get("name", file.filename),
                 }
         except Exception as e:
@@ -839,15 +775,15 @@ def process_file(
             )
 
 
-class TextRAGForm(BaseModel):
+class ProcessTextForm(BaseModel):
     name: str
     content: str
     collection_name: Optional[str] = None
 
 
-@app.post("/text")
-def store_text(
-    form_data: TextRAGForm,
+@app.post("/process/text")
+def process_text(
+    form_data: ProcessTextForm,
     user=Depends(get_verified_user),
 ):
     collection_name = form_data.collection_name
@@ -878,9 +814,8 @@ def process_docs_dir(user=Depends(get_admin_user)):
                 filename = path.name
                 file_content_type = mimetypes.guess_type(path)
 
-                f = open(path, "rb")
-                collection_name = calculate_sha256(f)[:63]
-                f.close()
+                with open(path, "rb") as f:
+                    collection_name = calculate_sha256(f)[:63]
 
                 loader = Loader(
                     engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
@@ -933,7 +868,7 @@ def process_docs_dir(user=Depends(get_admin_user)):
 
 
 @app.post("/process/youtube")
-def process_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)):
+def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
     try:
         loader = YoutubeLoader.from_youtube_url(
             form_data.url,
@@ -944,10 +879,11 @@ def process_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)):
         data = loader.load()
 
         collection_name = form_data.collection_name
-        if collection_name == "":
+        if not collection_name:
             collection_name = calculate_sha256_string(form_data.url)[:63]
 
         store_data_in_vector_db(data, collection_name, overwrite=True)
+
         return {
             "status": True,
             "collection_name": collection_name,
@@ -962,8 +898,7 @@ def process_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)):
 
 
 @app.post("/process/web")
-def process_web(form_data: UrlForm, user=Depends(get_verified_user)):
-    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
+def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
     try:
         loader = get_web_loader(
             form_data.url,
@@ -973,10 +908,11 @@ def process_web(form_data: UrlForm, user=Depends(get_verified_user)):
         data = loader.load()
 
         collection_name = form_data.collection_name
-        if collection_name == "":
+        if not collection_name:
             collection_name = calculate_sha256_string(form_data.url)[:63]
 
         store_data_in_vector_db(data, collection_name, overwrite=True)
+
         return {
             "status": True,
             "collection_name": collection_name,