Browse Source

feat: Add RAG support for various programming languages

Enables RAG for golang, python, java, sh, bat, powershell, cmd, js, css, c/c++/c#, sql, logs, ini, perl, r, dart, docker, env, php, haskell, lua, conf, plsql, ruby, db2, scala, bash, swift, vue, html, xml, and other arbitrary text files.
Marclass 1 năm trước cách đây
mục cha
commit
43d8466677
1 tập tin đã thay đổi với 23 bổ sung và 7 xóa
  1. 23 7
      backend/apps/rag/main.py

+ 23 - 7
backend/apps/rag/main.py

@@ -21,6 +21,7 @@ from langchain_community.document_loaders import (
     Docx2txtLoader,
     Docx2txtLoader,
     UnstructuredWordDocumentLoader,
     UnstructuredWordDocumentLoader,
     UnstructuredMarkdownLoader,
     UnstructuredMarkdownLoader,
+    UnstructuredXMLLoader,
 )
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
 from langchain_community.vectorstores import Chroma
@@ -147,6 +148,9 @@ def store_doc(
         "application/pdf",
         "application/pdf",
         "text/plain",
         "text/plain",
         "text/csv",
         "text/csv",
+        "text/xml",
+        "text/html",
+        "text/x-python",
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/octet-stream",
         "application/octet-stream",
     ]:
     ]:
@@ -154,10 +158,17 @@ def store_doc(
             status_code=status.HTTP_400_BAD_REQUEST,
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
             detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
         )
         )
-
-    if file.content_type == "application/octet-stream" and file.filename.split(".")[
-        -1
-    ] not in ["md"]:
+    text_xml=["text/html", "text/xml"]
+    octet_markdown=["md"]
+    octet_plain=[
+        "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", 
+        "css", "cpp", "hpp","h", "c", "cs", "sql", "log", "ini",
+        "pl" "pm", "r", "dart", "dockerfile", "env", "php", "hs",
+        "hsc", "lua", "nginxconf", "conf", "m", "mm", "plsql", "perl",
+        "rb", "rs", "db2", "scala", "bash", "swift", "vue"
+        ]
+    file_ext=file.filename.split(".")[-1].lower()
+    if file.content_type == "application/octet-stream" and file_ext not in (octet_markdown + octet_plain):
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
             detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
@@ -183,13 +194,18 @@ def store_doc(
             == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
             == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
         ):
         ):
             loader = Docx2txtLoader(file_path)
             loader = Docx2txtLoader(file_path)
-        elif file.content_type == "text/plain":
-            loader = TextLoader(file_path)
+        
         elif file.content_type == "text/csv":
         elif file.content_type == "text/csv":
             loader = CSVLoader(file_path)
             loader = CSVLoader(file_path)
+        elif file.content_type in text_xml:
+            loader=UnstructuredXMLLoader(file_path)
+        elif file.content_type == "text/plain" or file.content_type.find("text/")>=0:
+            loader = TextLoader(file_path)
         elif file.content_type == "application/octet-stream":
         elif file.content_type == "application/octet-stream":
-            if file.filename.split(".")[-1] == "md":
+            if file_ext in octet_markdown:
                 loader = UnstructuredMarkdownLoader(file_path)
                 loader = UnstructuredMarkdownLoader(file_path)
+            if file_ext in octet_plain:
+                loader = TextLoader(file_path)
 
 
         data = loader.load()
         data = loader.load()
         result = store_data_in_vector_db(data, collection_name)
         result = store_data_in_vector_db(data, collection_name)