|
@@ -79,6 +79,8 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
|
|
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
|
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
|
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
|
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
|
|
|
+app.state.TOP_K = 4
|
|
|
|
+
|
|
app.state.sentence_transformer_ef = (
|
|
app.state.sentence_transformer_ef = (
|
|
embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
model_name=app.state.RAG_EMBEDDING_MODEL,
|
|
model_name=app.state.RAG_EMBEDDING_MODEL,
|
|
@@ -210,23 +212,33 @@ async def get_rag_template(user=Depends(get_current_user)):
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
-class RAGTemplateForm(BaseModel):
|
|
|
|
- template: str
|
|
|
|
|
|
+@app.get("/query/settings")
|
|
|
|
+async def get_query_settings(user=Depends(get_admin_user)):
|
|
|
|
+ return {
|
|
|
|
+ "status": True,
|
|
|
|
+ "template": app.state.RAG_TEMPLATE,
|
|
|
|
+ "k": app.state.TOP_K,
|
|
|
|
+ }
|
|
|
|
|
|
|
|
|
|
-@app.post("/template/update")
|
|
|
|
-async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
|
|
|
|
- # TODO: check template requirements
|
|
|
|
- app.state.RAG_TEMPLATE = (
|
|
|
|
- form_data.template if form_data.template != "" else RAG_TEMPLATE
|
|
|
|
- )
|
|
|
|
|
|
+class QuerySettingsForm(BaseModel):
|
|
|
|
+ k: Optional[int] = None
|
|
|
|
+ template: Optional[str] = None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.post("/query/settings/update")
|
|
|
|
+async def update_query_settings(
|
|
|
|
+ form_data: QuerySettingsForm, user=Depends(get_admin_user)
|
|
|
|
+):
|
|
|
|
+ app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
|
|
|
|
+ app.state.TOP_K = form_data.k if form_data.k else 4
|
|
return {"status": True, "template": app.state.RAG_TEMPLATE}
|
|
return {"status": True, "template": app.state.RAG_TEMPLATE}
|
|
|
|
|
|
|
|
|
|
class QueryDocForm(BaseModel):
|
|
class QueryDocForm(BaseModel):
|
|
collection_name: str
|
|
collection_name: str
|
|
query: str
|
|
query: str
|
|
- k: Optional[int] = 4
|
|
|
|
|
|
+ k: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
@app.post("/query/doc")
|
|
@app.post("/query/doc")
|
|
@@ -240,7 +252,10 @@ def query_doc(
|
|
name=form_data.collection_name,
|
|
name=form_data.collection_name,
|
|
embedding_function=app.state.sentence_transformer_ef,
|
|
embedding_function=app.state.sentence_transformer_ef,
|
|
)
|
|
)
|
|
- result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
|
|
|
|
|
|
+ result = collection.query(
|
|
|
|
+ query_texts=[form_data.query],
|
|
|
|
+ n_results=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
|
+ )
|
|
return result
|
|
return result
|
|
except Exception as e:
|
|
except Exception as e:
|
|
print(e)
|
|
print(e)
|
|
@@ -253,7 +268,7 @@ def query_doc(
|
|
class QueryCollectionsForm(BaseModel):
|
|
class QueryCollectionsForm(BaseModel):
|
|
collection_names: List[str]
|
|
collection_names: List[str]
|
|
query: str
|
|
query: str
|
|
- k: Optional[int] = 4
|
|
|
|
|
|
+ k: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
def merge_and_sort_query_results(query_results, k):
|
|
def merge_and_sort_query_results(query_results, k):
|
|
@@ -317,13 +332,16 @@ def query_collection(
|
|
)
|
|
)
|
|
|
|
|
|
result = collection.query(
|
|
result = collection.query(
|
|
- query_texts=[form_data.query], n_results=form_data.k
|
|
|
|
|
|
+ query_texts=[form_data.query],
|
|
|
|
+ n_results=form_data.k if form_data.k else app.state.TOP_K,
|
|
)
|
|
)
|
|
results.append(result)
|
|
results.append(result)
|
|
except:
|
|
except:
|
|
pass
|
|
pass
|
|
|
|
|
|
- return merge_and_sort_query_results(results, form_data.k)
|
|
|
|
|
|
+ return merge_and_sort_query_results(
|
|
|
|
+ results, form_data.k if form_data.k else app.state.TOP_K
|
|
|
|
+ )
|
|
|
|
|
|
|
|
|
|
@app.post("/web")
|
|
@app.post("/web")
|
|
@@ -423,7 +441,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
|
|
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
|
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
|
] or file_ext in ["xls", "xlsx"]:
|
|
] or file_ext in ["xls", "xlsx"]:
|
|
loader = UnstructuredExcelLoader(file_path)
|
|
loader = UnstructuredExcelLoader(file_path)
|
|
- elif file_ext in known_source_ext or (file_content_type and file_content_type.find("text/") >= 0):
|
|
|
|
|
|
+ elif file_ext in known_source_ext or (
|
|
|
|
+ file_content_type and file_content_type.find("text/") >= 0
|
|
|
|
+ ):
|
|
loader = TextLoader(file_path)
|
|
loader = TextLoader(file_path)
|
|
else:
|
|
else:
|
|
loader = TextLoader(file_path)
|
|
loader = TextLoader(file_path)
|