Przeglądaj źródła

feat: RAG text ingestion(store) api

Timothy J. Baek 1 rok temu
rodzic
commit
7e0ea8f77d
2 zmienionych plików z 79 dodań i 33 usunięć
  1. 77 33
      backend/apps/rag/main.py
  2. 2 0
      backend/apps/rag/utils.py

+ 77 - 33
backend/apps/rag/main.py

@@ -111,39 +111,6 @@ class StoreWebForm(CollectionNameForm):
     url: str
 
 
-def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
-    )
-    docs = text_splitter.split_documents(data)
-
-    texts = [doc.page_content for doc in docs]
-    metadatas = [doc.metadata for doc in docs]
-
-    try:
-        if overwrite:
-            for collection in CHROMA_CLIENT.list_collections():
-                if collection_name == collection.name:
-                    print(f"deleting existing collection {collection_name}")
-                    CHROMA_CLIENT.delete_collection(name=collection_name)
-
-        collection = CHROMA_CLIENT.create_collection(
-            name=collection_name,
-            embedding_function=app.state.sentence_transformer_ef,
-        )
-
-        collection.add(
-            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
-        )
-        return True
-    except Exception as e:
-        print(e)
-        if e.__class__.__name__ == "UniqueConstraintError":
-            return True
-
-        return False
-
-
 @app.get("/")
 async def get_status():
     return {
@@ -325,6 +292,56 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
         )
 
 
+def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=app.state.CHUNK_SIZE,
+        chunk_overlap=app.state.CHUNK_OVERLAP,
+        add_start_index=True,
+    )
+    docs = text_splitter.split_documents(data)
+    return store_docs_in_vector_db(docs, collection_name, overwrite)
+
+
+def store_text_in_vector_db(
+    text, name, collection_name, overwrite: bool = False
+) -> bool:
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=app.state.CHUNK_SIZE,
+        chunk_overlap=app.state.CHUNK_OVERLAP,
+        add_start_index=True,
+    )
+    docs = text_splitter.create_documents([text], metadatas=[{"name": name}])
+    return store_docs_in_vector_db(docs, collection_name, overwrite)
+
+
+def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
+    texts = [doc.page_content for doc in docs]
+    metadatas = [doc.metadata for doc in docs]
+
+    try:
+        if overwrite:
+            for collection in CHROMA_CLIENT.list_collections():
+                if collection_name == collection.name:
+                    print(f"deleting existing collection {collection_name}")
+                    CHROMA_CLIENT.delete_collection(name=collection_name)
+
+        collection = CHROMA_CLIENT.create_collection(
+            name=collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
+        )
+
+        collection.add(
+            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
+        )
+        return True
+    except Exception as e:
+        print(e)
+        if e.__class__.__name__ == "UniqueConstraintError":
+            return True
+
+        return False
+
+
 def get_loader(filename: str, file_content_type: str, file_path: str):
     file_ext = filename.split(".")[-1].lower()
     known_type = True
@@ -460,6 +477,33 @@ def store_doc(
             )
 
 
+class TextRAGForm(BaseModel):
+    name: str
+    content: str
+    collection_name: Optional[str] = None
+
+
+@app.post("/text")
+def store_text(
+    form_data: TextRAGForm,
+    user=Depends(get_current_user),
+):
+
+    collection_name = form_data.collection_name
+    if collection_name == None:
+        collection_name = calculate_sha256_string(form_data.content)
+
+    result = store_text_in_vector_db(form_data.content, form_data.name, collection_name)
+
+    if result:
+        return {"status": True, "collection_name": collection_name}
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=ERROR_MESSAGES.DEFAULT(),
+        )
+
+
 @app.get("/scan")
 def scan_docs_dir(user=Depends(get_admin_user)):
     for path in Path(DOCS_DIR).rglob("./**/*"):

+ 2 - 0
backend/apps/rag/utils.py

@@ -137,6 +137,8 @@ def rag_messages(docs, messages, template, k, embedding_function):
                     k=k,
                     embedding_function=embedding_function,
                 )
+            elif doc["type"] == "text":
+                context = doc["content"]
             else:
                 context = query_doc(
                     collection_name=doc["collection_name"],