|
@@ -21,6 +21,7 @@ from langchain_community.document_loaders import (
|
|
Docx2txtLoader,
|
|
Docx2txtLoader,
|
|
UnstructuredWordDocumentLoader,
|
|
UnstructuredWordDocumentLoader,
|
|
UnstructuredMarkdownLoader,
|
|
UnstructuredMarkdownLoader,
|
|
|
|
+ UnstructuredXMLLoader,
|
|
)
|
|
)
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
from langchain_community.vectorstores import Chroma
|
|
from langchain_community.vectorstores import Chroma
|
|
@@ -147,6 +148,9 @@ def store_doc(
|
|
"application/pdf",
|
|
"application/pdf",
|
|
"text/plain",
|
|
"text/plain",
|
|
"text/csv",
|
|
"text/csv",
|
|
|
|
+ "text/xml",
|
|
|
|
+ "text/html",
|
|
|
|
+ "text/x-python",
|
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
|
"application/octet-stream",
|
|
"application/octet-stream",
|
|
]:
|
|
]:
|
|
@@ -154,10 +158,17 @@ def store_doc(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
|
|
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
|
|
)
|
|
)
|
|
-
|
|
|
|
- if file.content_type == "application/octet-stream" and file.filename.split(".")[
|
|
|
|
- -1
|
|
|
|
- ] not in ["md"]:
|
|
|
|
|
|
+ text_xml=["text/html", "text/xml"]
|
|
|
|
+ octet_markdown=["md"]
|
|
|
|
+ octet_plain=[
|
|
|
|
+ "go", "py", "java", "sh", "bat", "ps1", "cmd", "js",
|
|
|
|
+ "css", "cpp", "hpp","h", "c", "cs", "sql", "log", "ini",
|
|
|
|
+ "pl", "pm", "r", "dart", "dockerfile", "env", "php", "hs",
|
|
|
|
+ "hsc", "lua", "nginxconf", "conf", "m", "mm", "plsql", "perl",
|
|
|
|
+ "rb", "rs", "db2", "scala", "bash", "swift", "vue"
|
|
|
|
+ ]
|
|
|
|
+ file_ext=file.filename.split(".")[-1].lower()
|
|
|
|
+ if file.content_type == "application/octet-stream" and file_ext not in (octet_markdown + octet_plain):
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
|
|
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
|
|
@@ -183,13 +194,18 @@ def store_doc(
|
|
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
|
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
|
):
|
|
):
|
|
loader = Docx2txtLoader(file_path)
|
|
loader = Docx2txtLoader(file_path)
|
|
- elif file.content_type == "text/plain":
|
|
|
|
- loader = TextLoader(file_path)
|
|
|
|
|
|
+
|
|
elif file.content_type == "text/csv":
|
|
elif file.content_type == "text/csv":
|
|
loader = CSVLoader(file_path)
|
|
loader = CSVLoader(file_path)
|
|
|
|
+ elif file.content_type in text_xml:
|
|
|
|
+ loader=UnstructuredXMLLoader(file_path)
|
|
|
|
+ elif file.content_type == "text/plain" or file.content_type.find("text/")>=0:
|
|
|
|
+ loader = TextLoader(file_path)
|
|
elif file.content_type == "application/octet-stream":
|
|
elif file.content_type == "application/octet-stream":
|
|
- if file.filename.split(".")[-1] == "md":
|
|
|
|
|
|
+ if file_ext in octet_markdown:
|
|
loader = UnstructuredMarkdownLoader(file_path)
|
|
loader = UnstructuredMarkdownLoader(file_path)
|
|
|
|
+ if file_ext in octet_plain:
|
|
|
|
+ loader = TextLoader(file_path)
|
|
|
|
|
|
data = loader.load()
|
|
data = loader.load()
|
|
result = store_data_in_vector_db(data, collection_name)
|
|
result = store_data_in_vector_db(data, collection_name)
|