123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314 |
- from fastapi import (
- FastAPI,
- Request,
- Depends,
- HTTPException,
- status,
- UploadFile,
- File,
- Form,
- )
- from fastapi.middleware.cors import CORSMiddleware
- import os, shutil
- # from chromadb.utils import embedding_functions
- from langchain_community.document_loaders import (
- WebBaseLoader,
- TextLoader,
- PyPDFLoader,
- CSVLoader,
- Docx2txtLoader,
- UnstructuredEPubLoader,
- UnstructuredWordDocumentLoader,
- UnstructuredMarkdownLoader,
- UnstructuredXMLLoader,
- UnstructuredRSTLoader,
- UnstructuredExcelLoader,
- )
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.vectorstores import Chroma
- from langchain.chains import RetrievalQA
- from pydantic import BaseModel
- from typing import Optional
- import uuid
- import time
- from utils.misc import calculate_sha256, calculate_sha256_string
- from utils.utils import get_current_user
- from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
- from constants import ERROR_MESSAGES
- # EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
- # model_name=EMBED_MODEL
- # )
# FastAPI application object; the route decorators in this module attach to it.
app = FastAPI()

# Origins accepted by the CORS middleware.
# NOTE(review): a wildcard origin is combined with allow_credentials=True
# below — confirm this is intended for the target deployment.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],
    allow_methods=["*"],
)
class CollectionNameForm(BaseModel):
    """Request body that carries an optional target collection name."""

    # Defaults to the literal "test" when the client omits the field.
    collection_name: Optional[str] = "test"
class StoreWebForm(CollectionNameForm):
    """Request body for ``POST /web``: a URL plus the inherited collection name."""

    # Address of the page to load; required (no default).
    url: str
def store_data_in_vector_db(data, collection_name) -> bool:
    """Split ``data`` into chunks and add them to a newly created collection.

    Args:
        data: Documents accepted by
            ``RecursiveCharacterTextSplitter.split_documents``; each resulting
            chunk must expose ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create via ``CHROMA_CLIENT``.

    Returns:
        ``True`` when the chunks were added, and also when the operation
        failed with an exception whose class is named
        ``UniqueConstraintError`` (presumably the collection already exists —
        verify against the Chroma client in use). ``False`` for any other
        failure.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
    )
    chunks = splitter.split_documents(data)

    texts = []
    for chunk in chunks:
        texts.append(chunk.page_content)
    metadatas = []
    for chunk in chunks:
        metadatas.append(chunk.metadata)

    try:
        target = CHROMA_CLIENT.create_collection(name=collection_name)
        # One freshly generated id per chunk.
        chunk_ids = [str(uuid.uuid1()) for _ in range(len(texts))]
        target.add(documents=texts, metadatas=metadatas, ids=chunk_ids)
    except Exception as err:
        print(err)
        # Only this one error class is reported as success.
        return type(err).__name__ == "UniqueConstraintError"
    else:
        return True
@app.get("/")
async def get_status():
    """Report that the service is up; always returns ``{"status": True}``."""
    return dict(status=True)
@app.get("/query/{collection_name}")
def query_collection(
    collection_name: str,
    query: str,
    k: Optional[int] = 4,
    user=Depends(get_current_user),
):
    """Query one collection and return the raw result from the client.

    Args:
        collection_name: Collection to look up via ``CHROMA_CLIENT``.
        query: Text passed as the single entry of ``query_texts``.
        k: Value passed as ``n_results``; defaults to 4.
        user: Resolved by the ``get_current_user`` dependency; not otherwise
            used in this handler.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.DEFAULT(error)`` when the
            lookup or the query raises.
    """
    try:
        return CHROMA_CLIENT.get_collection(
            name=collection_name,
        ).query(query_texts=[query], n_results=k)
    except Exception as err:
        print(err)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(err),
        )
@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
    """Load a web page and store its chunks in the vector database.

    Args:
        form_data: Carries the ``url`` to load and an optional
            ``collection_name``. When the name is empty or ``None`` the first
            63 characters of the URL's SHA-256 string are used instead
            (presumably to fit the vector store's collection-name length
            limit — verify).
        user: Resolved by the ``get_current_user`` dependency; not otherwise
            used in this handler.

    Returns:
        A dict with ``status``, the ``collection_name`` actually used, and
        ``filename`` set to the submitted URL.

    Raises:
        HTTPException: 500 when ``store_data_in_vector_db`` reports failure;
            400 with ``ERROR_MESSAGES.DEFAULT(error)`` when loading or any
            other step raises.
    """
    try:
        # Example input: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
        loader = WebBaseLoader(form_data.url)
        data = loader.load()

        collection_name = form_data.collection_name
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        # Surface storage failures instead of reporting success regardless,
        # matching what the /doc endpoint does.
        if not store_data_in_vector_db(data, collection_name):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except HTTPException:
        # Already a deliberate HTTP response; do not rewrap it as a 400.
        raise
    except Exception as e:
        print(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
# Extensions (lower-case, without the dot) that are loaded as plain text.
_KNOWN_SOURCE_EXT = frozenset(
    {
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
    }
)


def get_loader(file, file_path):
    """Pick a document loader for an uploaded file.

    The choice is driven by the text after the last ``.`` in
    ``file.filename`` (lower-cased; the whole name when there is no dot) and
    by ``file.content_type``. A missing filename or content type is treated
    as an empty string rather than raising.

    Args:
        file: Upload object exposing ``filename`` and ``content_type``.
        file_path: Path of the saved file, handed to the loader constructor.

    Returns:
        A ``(loader, known_type)`` tuple. ``known_type`` is ``False`` only
        when no rule matched and ``TextLoader`` is used as a fallback.
    """
    filename = file.filename or ""
    content_type = file.content_type or ""
    file_ext = filename.split(".")[-1].lower()
    known_type = True

    if file_ext == "pdf":
        loader = PyPDFLoader(file_path)
    elif file_ext == "csv":
        loader = CSVLoader(file_path)
    elif file_ext == "rst":
        loader = UnstructuredRSTLoader(file_path, mode="elements")
    elif file_ext == "xml":
        loader = UnstructuredXMLLoader(file_path)
    elif file_ext == "md":
        loader = UnstructuredMarkdownLoader(file_path)
    elif content_type == "application/epub+zip":
        loader = UnstructuredEPubLoader(file_path)
    elif (
        content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        or file_ext in ("doc", "docx")
    ):
        loader = Docx2txtLoader(file_path)
    elif content_type in (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ) or file_ext in ("xls", "xlsx"):
        loader = UnstructuredExcelLoader(file_path)
    elif file_ext in _KNOWN_SOURCE_EXT or "text/" in content_type:
        loader = TextLoader(file_path)
    else:
        # Unknown type: still try plain text, but tell the caller it was a guess.
        loader = TextLoader(file_path)
        known_type = False

    return loader, known_type
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under ``UPLOAD_DIR`` and store its chunks.

    Args:
        collection_name: Optional target collection. When empty or ``None``
            the first 63 characters of the saved file's SHA-256 are used
            (presumably to fit the vector store's collection-name length
            limit — verify).
        file: The uploaded file.
        user: Resolved by the ``get_current_user`` dependency; not otherwise
            used in this handler.

    Returns:
        A dict with ``status``, the ``collection_name`` actually used, the
        stored ``filename`` and the ``known_type`` flag from ``get_loader``.

    Raises:
        HTTPException: 500 when ``store_data_in_vector_db`` reports failure;
            400 with ``ERROR_MESSAGES.PANDOC_NOT_INSTALLED`` when the error
            text contains "No pandoc was found"; 400 with
            ``ERROR_MESSAGES.DEFAULT(error)`` for any other error, including
            an unusable filename.
    """
    print(file.content_type)
    try:
        # The filename is client-supplied: keep only its last path component
        # so the upload cannot be written outside UPLOAD_DIR.
        filename = os.path.basename((file.filename or "").replace("\\", "/"))
        if filename in ("", ".", ".."):
            raise ValueError("Invalid filename")
        file_path = f"{UPLOAD_DIR}/{filename}"

        # Stream the upload to disk instead of reading it fully into memory.
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)

        if not collection_name:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(file, file_path)
        data = loader.load()

        if not store_data_in_vector_db(data, collection_name):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": filename,
            "known_type": known_type,
        }
    except HTTPException:
        # Without this clause the 500 above would be caught by the generic
        # handler below and turned into a 400.
        raise
    except Exception as e:
        print(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_current_user)):
    """Call ``CHROMA_CLIENT.reset()`` on behalf of an admin user.

    Raises:
        HTTPException: 403 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` when
            ``user.role`` is not ``"admin"``.
    """
    if user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    CHROMA_CLIENT.reset()
@app.get("/reset")
def reset(user=Depends(get_current_user)) -> bool:
    """Delete everything under ``UPLOAD_DIR`` and reset the Chroma client.

    Both steps are best-effort: a failure to delete an entry, or a failure of
    ``CHROMA_CLIENT.reset()``, is printed and otherwise ignored.

    Returns:
        ``True`` once both steps have been attempted.

    Raises:
        HTTPException: 403 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` when
            ``user.role`` is not ``"admin"``.
    """
    if user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    upload_root = f"{UPLOAD_DIR}"
    for entry in os.listdir(upload_root):
        target = os.path.join(upload_root, entry)
        try:
            # Files and symlinks are unlinked; real directories are removed
            # recursively.
            if os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print("Failed to delete %s. Reason: %s" % (target, err))

    try:
        CHROMA_CLIENT.reset()
    except Exception as err:
        print(err)

    return True
|