Prechádzať zdrojové kódy

enh: add to vector db support

Timothy J. Baek pred 7 mesiacmi
rodič
commit
d394f8b7be
1 zmenený súbor, 83 pridaní a 40 odobraní
  1. 83 40
      backend/open_webui/apps/retrieval/main.py

+ 83 - 40
backend/open_webui/apps/retrieval/main.py

@@ -637,6 +637,7 @@ def save_docs_to_vector_db(
     metadata: Optional[dict] = None,
     overwrite: bool = False,
     split: bool = True,
+    add: bool = False,
 ) -> bool:
     log.info(f"save_docs_to_vector_db {docs} {collection_name}")
 
@@ -662,42 +663,44 @@ def save_docs_to_vector_db(
                 metadata[key] = str(value)
 
     try:
-        if overwrite:
-            if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
-                log.info(f"deleting existing collection {collection_name}")
-                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
-
         if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
             log.info(f"collection {collection_name} already exists")
-            return True
-        else:
-            embedding_function = get_embedding_function(
-                app.state.config.RAG_EMBEDDING_ENGINE,
-                app.state.config.RAG_EMBEDDING_MODEL,
-                app.state.sentence_transformer_ef,
-                app.state.config.OPENAI_API_KEY,
-                app.state.config.OPENAI_API_BASE_URL,
-                app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
-            )
 
-            embeddings = embedding_function(
-                list(map(lambda x: x.replace("\n", " "), texts))
-            )
+            if overwrite:
+                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
+                log.info(f"deleting existing collection {collection_name}")
 
-            VECTOR_DB_CLIENT.insert(
-                collection_name=collection_name,
-                items=[
-                    {
-                        "id": str(uuid.uuid4()),
-                        "text": text,
-                        "vector": embeddings[idx],
-                        "metadata": metadatas[idx],
-                    }
-                    for idx, text in enumerate(texts)
-                ],
-            )
+            if add is False:
+                return True
+
+        log.info(f"adding to collection {collection_name}")
+        embedding_function = get_embedding_function(
+            app.state.config.RAG_EMBEDDING_ENGINE,
+            app.state.config.RAG_EMBEDDING_MODEL,
+            app.state.sentence_transformer_ef,
+            app.state.config.OPENAI_API_KEY,
+            app.state.config.OPENAI_API_BASE_URL,
+            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+        )
+
+        embeddings = embedding_function(
+            list(map(lambda x: x.replace("\n", " "), texts))
+        )
 
-            return True
+        VECTOR_DB_CLIENT.insert(
+            collection_name=collection_name,
+            items=[
+                {
+                    "id": str(uuid.uuid4()),
+                    "text": text,
+                    "vector": embeddings[idx],
+                    "metadata": metadatas[idx],
+                }
+                for idx, text in enumerate(texts)
+            ],
+        )
+
+        return True
     except Exception as e:
         log.exception(e)
         return False
@@ -715,37 +718,53 @@ def process_file(
 ):
     try:
         file = Files.get_file_by_id(form_data.file_id)
-        file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")
 
         collection_name = form_data.collection_name
         if collection_name is None:
-            with open(file_path, "rb") as f:
-                collection_name = calculate_sha256(f)[:63]
+            collection_name = file.id
 
         loader = Loader(
             engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
             TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
             PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
         )
-        docs = loader.load(file.filename, file.meta.get("content_type"), file_path)
+
+        file_path = file.meta.get("path", None)
+        if file_path:
+            docs = loader.load(file.filename, file.meta.get("content_type"), file_path)
+        else:
+            docs = [
+                Document(
+                    page_content=file.data.get("content", ""),
+                    metadata={
+                        "name": file.filename,
+                        "created_by": file.user_id,
+                        **file.meta,
+                    },
+                )
+            ]
+
         text_content = " ".join([doc.page_content for doc in docs])
         log.debug(f"text_content: {text_content}")
         hash = calculate_sha256_string(text_content)
 
-        Files.update_file_data_by_id(
-            form_data.file_id,
+        res = Files.update_file_data_by_id(
+            file.id,
             {"content": text_content},
         )
+        print(res)
         Files.update_file_hash_by_id(form_data.file_id, hash)
 
         try:
             result = save_docs_to_vector_db(
-                docs,
-                collection_name,
-                {
+                docs=docs,
+                collection_name=collection_name,
+                metadata={
                     "file_id": form_data.file_id,
                     "name": file.meta.get("name", file.filename),
+                    "hash": hash,
                 },
+                add=(True if form_data.collection_name else False),
             )
 
             if result:
@@ -1184,6 +1203,30 @@ def query_collection_handler(
 ####################################
 
 
+class DeleteForm(BaseModel):
+    collection_name: str
+    file_id: str
+
+
+@app.post("/delete")
+def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
+    try:
+        if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
+            file = Files.get_file_by_id(form_data.file_id)
+            hash = file.hash
+
+            VECTOR_DB_CLIENT.delete(
+                collection_name=form_data.collection_name,
+                metadata={"hash": hash},
+            )
+            return {"status": True}
+        else:
+            return {"status": False}
+    except Exception as e:
+        log.exception(e)
+        return {"status": False}
+
+
 @app.post("/reset/db")
 def reset_vector_db(user=Depends(get_admin_user)):
     VECTOR_DB_CLIENT.reset()