# main.py (3.1 KB)
  1. from fastapi import FastAPI, Request, Depends, HTTPException, status, UploadFile, File
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from chromadb.utils import embedding_functions
  4. from langchain.document_loaders import WebBaseLoader, TextLoader
  5. from langchain.text_splitter import RecursiveCharacterTextSplitter
  6. from langchain_community.vectorstores import Chroma
  7. from langchain.chains import RetrievalQA
  8. from pydantic import BaseModel
  9. from typing import Optional
  10. import uuid
  11. from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
  12. from constants import ERROR_MESSAGES
  # Shared embedding function built from chromadb's sentence-transformer wrapper,
  # configured with the EMBED_MODEL name imported from config.
  EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=EMBED_MODEL)
  # FastAPI application with a fully permissive CORS policy (any origin, method
  # and header).
  # NOTE(review): wildcard origins together with allow_credentials=True is very
  # permissive — confirm this is intended before exposing the service publicly.
  app = FastAPI()

  origins = ["*"]

  app.add_middleware(
      CORSMiddleware,
      allow_credentials=True,
      allow_origins=origins,
      allow_headers=["*"],
      allow_methods=["*"],
  )
  class StoreWebForm(BaseModel):
      """Request body accepted by ``store_web``.

      Attributes:
          url: Address handed to ``WebBaseLoader``.
          collection_name: Name of the collection to store into; defaults to
              ``"test"`` when omitted.
      """

      url: str
      collection_name: Optional[str] = "test"
  def store_data_in_vector_db(data, collection_name):
      """Split documents into chunks and add them to a Chroma collection.

      Args:
          data: Documents to split; handed to
              ``RecursiveCharacterTextSplitter.split_documents``.
          collection_name: Name of the target collection. The collection is
              created on first use and reused afterwards, so repeated calls
              with the same name append to it.
      """
      splitter = RecursiveCharacterTextSplitter(
          chunk_size=CHUNK_SIZE,
          chunk_overlap=CHUNK_OVERLAP,
      )
      chunks = splitter.split_documents(data)
      texts = [chunk.page_content for chunk in chunks]
      metadatas = [chunk.metadata for chunk in chunks]

      # create_collection() raises when the name already exists, which made every
      # call after the first one for a given name (including the default "test")
      # fail. get_or_create_collection() reuses the existing collection instead.
      collection = CHROMA_CLIENT.get_or_create_collection(
          name=collection_name,
          embedding_function=EMBEDDING_FUNC,
      )
      collection.add(
          documents=texts,
          metadatas=metadatas,
          # One freshly generated id per chunk.
          ids=[str(uuid.uuid1()) for _ in texts],
      )
  @app.get("/")
  async def get_status():
      """Status endpoint: always responds with ``{"status": True}``."""
      payload = {"status": True}
      return payload
  @app.get("/query/{collection_name}")
  def query_collection(collection_name: str, query: str, k: Optional[int] = 4):
      """Run a text query against a stored collection.

      Args:
          collection_name: Path parameter naming the collection to search.
          query: Text to search for.
          k: Number of results requested from Chroma (default 4).

      Returns:
          The result object produced by ``collection.query``.
      """
      # Open the collection with the same embedding function that
      # store_data_in_vector_db() uses when writing. Without it the query text
      # would be embedded with Chroma's default function and compared against
      # vectors that may come from a different model.
      collection = CHROMA_CLIENT.get_collection(
          name=collection_name,
          embedding_function=EMBEDDING_FUNC,
      )
      return collection.query(query_texts=[query], n_results=k)
  @app.post("/web")
  def store_web(form_data: StoreWebForm):
      """Load the page at ``form_data.url`` and store it in the vector DB.

      Returns ``{"status": True}`` on success. Any exception is printed and
      re-raised as an HTTP 400 whose detail comes from ``ERROR_MESSAGES.DEFAULT``.
      """
      # Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
      try:
          documents = WebBaseLoader(form_data.url).load()
          store_data_in_vector_db(documents, form_data.collection_name)
      except Exception as err:
          print(err)
          raise HTTPException(
              detail=ERROR_MESSAGES.DEFAULT(err),
              status_code=status.HTTP_400_BAD_REQUEST,
          )
      return {"status": True}
  @app.post("/doc")
  def store_doc(file: UploadFile = File(...)):
      """Save an uploaded file under ``./data`` with a unique name.

      The upload is only written to disk; it is not yet loaded into the vector
      database.

      Args:
          file: The uploaded file.

      Returns:
          ``{"status": True}`` once the file has been written.

      Raises:
          HTTPException: 400 with ``ERROR_MESSAGES.DEFAULT(e)`` as detail when
              anything goes wrong while saving.
      """
      # Local imports keep this block self-contained (neither module is imported
      # at the top of the file).
      import shutil
      from pathlib import Path

      try:
          # The filename is client-supplied: keep only its final path component
          # so it cannot contain directory separators or point outside ./data.
          original_name = Path(file.filename or "upload").name
          stored_name = f"{uuid.uuid4()}-{original_name}"

          data_dir = Path("./data")
          data_dir.mkdir(parents=True, exist_ok=True)

          # Stream the upload to disk instead of reading it fully into memory;
          # the context manager closes the destination file.
          with open(data_dir / stored_name, "wb") as f:
              shutil.copyfileobj(file.file, f)

          # TODO: load the saved file and pass it to store_data_in_vector_db();
          # ingestion of uploaded documents is not implemented yet.
          return {"status": True}
      except Exception as e:
          print(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(e),
          )
  def reset_vector_db():
      """Call ``CHROMA_CLIENT.reset()`` and return ``{"status": True}``.

      NOTE(review): no route decorator is attached to this function in the
      visible source, so it is only callable from Python — confirm whether an
      endpoint was intended.
      """
      CHROMA_CLIENT.reset()
      outcome = {"status": True}
      return outcome