|
@@ -62,6 +62,7 @@ from config import (
|
|
|
CHROMA_CLIENT,
|
|
|
CHUNK_SIZE,
|
|
|
CHUNK_OVERLAP,
|
|
|
+ RAG_TEMPLATE,
|
|
|
)
|
|
|
from constants import ERROR_MESSAGES
|
|
|
|
|
@@ -71,6 +72,11 @@ from constants import ERROR_MESSAGES
|
|
|
|
|
|
app = FastAPI()
|
|
|
|
|
|
+app.state.CHUNK_SIZE = CHUNK_SIZE
|
|
|
+app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
|
|
+app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
|
|
+
|
|
|
+
|
|
|
origins = ["*"]
|
|
|
|
|
|
app.add_middleware(
|
|
@@ -92,7 +98,7 @@ class StoreWebForm(CollectionNameForm):
|
|
|
|
|
|
def store_data_in_vector_db(data, collection_name) -> bool:
|
|
|
text_splitter = RecursiveCharacterTextSplitter(
|
|
|
- chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
|
|
|
+ chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
|
|
|
)
|
|
|
docs = text_splitter.split_documents(data)
|
|
|
|
|
@@ -116,7 +122,60 @@ def store_data_in_vector_db(data, collection_name) -> bool:
|
|
|
|
|
|
@app.get("/")
|
|
|
async def get_status():
|
|
|
- return {"status": True}
|
|
|
+ return {
|
|
|
+ "status": True,
|
|
|
+ "chunk_size": app.state.CHUNK_SIZE,
|
|
|
+ "chunk_overlap": app.state.CHUNK_OVERLAP,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+@app.get("/chunk")
|
|
|
+async def get_chunk_params(user=Depends(get_admin_user)):
|
|
|
+ return {
|
|
|
+ "status": True,
|
|
|
+ "chunk_size": app.state.CHUNK_SIZE,
|
|
|
+ "chunk_overlap": app.state.CHUNK_OVERLAP,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+class ChunkParamUpdateForm(BaseModel):
|
|
|
+ chunk_size: int
|
|
|
+ chunk_overlap: int
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/chunk/update")
|
|
|
+async def update_chunk_params(
|
|
|
+ form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
|
|
|
+):
|
|
|
+ app.state.CHUNK_SIZE = form_data.chunk_size
|
|
|
+ app.state.CHUNK_OVERLAP = form_data.chunk_overlap
|
|
|
+
|
|
|
+ return {
|
|
|
+ "status": True,
|
|
|
+ "chunk_size": app.state.CHUNK_SIZE,
|
|
|
+ "chunk_overlap": app.state.CHUNK_OVERLAP,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+@app.get("/template")
|
|
|
+async def get_rag_template(user=Depends(get_current_user)):
|
|
|
+ return {
|
|
|
+ "status": True,
|
|
|
+ "template": app.state.RAG_TEMPLATE,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+class RAGTemplateForm(BaseModel):
|
|
|
+ template: str
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/template/update")
|
|
|
+async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
|
|
|
+ # TODO: check template requirements
|
|
|
+ app.state.RAG_TEMPLATE = (
|
|
|
+ form_data.template if form_data.template != "" else RAG_TEMPLATE
|
|
|
+ )
|
|
|
+ return {"status": True, "template": app.state.RAG_TEMPLATE}
|
|
|
|
|
|
|
|
|
class QueryDocForm(BaseModel):
|