|
@@ -39,8 +39,6 @@ import json
|
|
|
|
|
|
import sentence_transformers
|
|
import sentence_transformers
|
|
|
|
|
|
-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
|
|
|
|
-
|
|
|
|
from apps.web.models.documents import (
|
|
from apps.web.models.documents import (
|
|
Documents,
|
|
Documents,
|
|
DocumentForm,
|
|
DocumentForm,
|
|
@@ -48,9 +46,10 @@ from apps.web.models.documents import (
|
|
)
|
|
)
|
|
|
|
|
|
from apps.rag.utils import (
|
|
from apps.rag.utils import (
|
|
|
|
+ get_model_path,
|
|
query_embeddings_doc,
|
|
query_embeddings_doc,
|
|
|
|
+ query_embeddings_function,
|
|
query_embeddings_collection,
|
|
query_embeddings_collection,
|
|
- generate_openai_embeddings,
|
|
|
|
)
|
|
)
|
|
|
|
|
|
from utils.misc import (
|
|
from utils.misc import (
|
|
@@ -60,13 +59,20 @@ from utils.misc import (
|
|
extract_folders_after_data_docs,
|
|
extract_folders_after_data_docs,
|
|
)
|
|
)
|
|
from utils.utils import get_current_user, get_admin_user
|
|
from utils.utils import get_current_user, get_admin_user
|
|
|
|
+
|
|
from config import (
|
|
from config import (
|
|
SRC_LOG_LEVELS,
|
|
SRC_LOG_LEVELS,
|
|
UPLOAD_DIR,
|
|
UPLOAD_DIR,
|
|
DOCS_DIR,
|
|
DOCS_DIR,
|
|
|
|
+ RAG_TOP_K,
|
|
|
|
+ RAG_RELEVANCE_THRESHOLD,
|
|
RAG_EMBEDDING_ENGINE,
|
|
RAG_EMBEDDING_ENGINE,
|
|
RAG_EMBEDDING_MODEL,
|
|
RAG_EMBEDDING_MODEL,
|
|
|
|
+ RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
|
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
|
|
+ RAG_RERANKING_MODEL,
|
|
|
|
+ RAG_RERANKING_MODEL_AUTO_UPDATE,
|
|
|
|
+ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
RAG_OPENAI_API_BASE_URL,
|
|
RAG_OPENAI_API_BASE_URL,
|
|
RAG_OPENAI_API_KEY,
|
|
RAG_OPENAI_API_KEY,
|
|
DEVICE_TYPE,
|
|
DEVICE_TYPE,
|
|
@@ -83,14 +89,14 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
|
|
|
app = FastAPI()
|
|
app = FastAPI()
|
|
|
|
|
|
-
|
|
|
|
-app.state.TOP_K = 4
|
|
|
|
|
|
+app.state.TOP_K = RAG_TOP_K
|
|
|
|
+app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
|
app.state.CHUNK_SIZE = CHUNK_SIZE
|
|
app.state.CHUNK_SIZE = CHUNK_SIZE
|
|
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
|
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
|
|
|
|
|
-
|
|
|
|
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
|
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
|
|
|
+app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
|
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
|
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
|
|
|
|
|
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
|
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
|
@@ -98,16 +104,48 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
|
|
|
|
|
app.state.PDF_EXTRACT_IMAGES = False
|
|
app.state.PDF_EXTRACT_IMAGES = False
|
|
|
|
|
|
-if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
|
- app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
|
|
|
- app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- device=DEVICE_TYPE,
|
|
|
|
- trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
|
|
- )
|
|
|
|
|
|
|
|
|
|
+def update_embedding_model(
|
|
|
|
+ embedding_model: str,
|
|
|
|
+ update_model: bool = False,
|
|
|
|
+):
|
|
|
|
+ if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
|
+ app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
|
|
|
+ get_model_path(embedding_model, update_model),
|
|
|
|
+ device=DEVICE_TYPE,
|
|
|
|
+ trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ app.state.sentence_transformer_ef = None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def update_reranking_model(
|
|
|
|
+ reranking_model: str,
|
|
|
|
+ update_model: bool = False,
|
|
|
|
+):
|
|
|
|
+ if reranking_model:
|
|
|
|
+ app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
|
|
|
+ get_model_path(reranking_model, update_model),
|
|
|
|
+ device=DEVICE_TYPE,
|
|
|
|
+ trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ app.state.sentence_transformer_rf = None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+update_embedding_model(
|
|
|
|
+ app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+update_reranking_model(
|
|
|
|
+ app.state.RAG_RERANKING_MODEL,
|
|
|
|
+ RAG_RERANKING_MODEL_AUTO_UPDATE,
|
|
|
|
+)
|
|
|
|
|
|
origins = ["*"]
|
|
origins = ["*"]
|
|
|
|
|
|
|
|
+
|
|
app.add_middleware(
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
CORSMiddleware,
|
|
allow_origins=origins,
|
|
allow_origins=origins,
|
|
@@ -134,6 +172,7 @@ async def get_status():
|
|
"template": app.state.RAG_TEMPLATE,
|
|
"template": app.state.RAG_TEMPLATE,
|
|
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
|
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
|
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
|
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ "reranking_model": app.state.RAG_RERANKING_MODEL,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -150,6 +189,11 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
+@app.get("/reranking")
|
|
|
|
+async def get_reraanking_config(user=Depends(get_admin_user)):
|
|
|
|
+ return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
|
|
|
|
+
|
|
|
|
+
|
|
class OpenAIConfigForm(BaseModel):
|
|
class OpenAIConfigForm(BaseModel):
|
|
url: str
|
|
url: str
|
|
key: str
|
|
key: str
|
|
@@ -170,22 +214,14 @@ async def update_embedding_config(
|
|
)
|
|
)
|
|
try:
|
|
try:
|
|
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
|
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
|
|
|
+ app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
|
|
|
|
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
|
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
|
- app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
|
|
- app.state.sentence_transformer_ef = None
|
|
|
|
-
|
|
|
|
if form_data.openai_config != None:
|
|
if form_data.openai_config != None:
|
|
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
|
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
|
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
|
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
|
- else:
|
|
|
|
- sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
|
|
|
- app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- device=DEVICE_TYPE,
|
|
|
|
- trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
|
|
- )
|
|
|
|
- app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
|
|
- app.state.sentence_transformer_ef = sentence_transformer_ef
|
|
|
|
|
|
+
|
|
|
|
+ update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
|
|
|
|
|
|
return {
|
|
return {
|
|
"status": True,
|
|
"status": True,
|
|
@@ -196,7 +232,6 @@ async def update_embedding_config(
|
|
"key": app.state.OPENAI_API_KEY,
|
|
"key": app.state.OPENAI_API_KEY,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
-
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
log.exception(f"Problem updating embedding model: {e}")
|
|
log.exception(f"Problem updating embedding model: {e}")
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
@@ -205,6 +240,34 @@ async def update_embedding_config(
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
+class RerankingModelUpdateForm(BaseModel):
|
|
|
|
+ reranking_model: str
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.post("/reranking/update")
|
|
|
|
+async def update_reranking_config(
|
|
|
|
+ form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
|
|
|
|
+):
|
|
|
|
+ log.info(
|
|
|
|
+ f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
|
|
|
|
+ )
|
|
|
|
+ try:
|
|
|
|
+ app.state.RAG_RERANKING_MODEL = form_data.reranking_model
|
|
|
|
+
|
|
|
|
+ update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
|
|
|
|
+
|
|
|
|
+ return {
|
|
|
|
+ "status": True,
|
|
|
|
+ "reranking_model": app.state.RAG_RERANKING_MODEL,
|
|
|
|
+ }
|
|
|
|
+ except Exception as e:
|
|
|
|
+ log.exception(f"Problem updating reranking model: {e}")
|
|
|
|
+ raise HTTPException(
|
|
|
|
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
|
+ detail=ERROR_MESSAGES.DEFAULT(e),
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+
|
|
@app.get("/config")
|
|
@app.get("/config")
|
|
async def get_rag_config(user=Depends(get_admin_user)):
|
|
async def get_rag_config(user=Depends(get_admin_user)):
|
|
return {
|
|
return {
|
|
@@ -257,11 +320,13 @@ async def get_query_settings(user=Depends(get_admin_user)):
|
|
"status": True,
|
|
"status": True,
|
|
"template": app.state.RAG_TEMPLATE,
|
|
"template": app.state.RAG_TEMPLATE,
|
|
"k": app.state.TOP_K,
|
|
"k": app.state.TOP_K,
|
|
|
|
+ "r": app.state.RELEVANCE_THRESHOLD,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
class QuerySettingsForm(BaseModel):
|
|
class QuerySettingsForm(BaseModel):
|
|
k: Optional[int] = None
|
|
k: Optional[int] = None
|
|
|
|
+ r: Optional[float] = None
|
|
template: Optional[str] = None
|
|
template: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
@@ -271,6 +336,7 @@ async def update_query_settings(
|
|
):
|
|
):
|
|
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
|
|
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
|
|
app.state.TOP_K = form_data.k if form_data.k else 4
|
|
app.state.TOP_K = form_data.k if form_data.k else 4
|
|
|
|
+ app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
|
|
return {"status": True, "template": app.state.RAG_TEMPLATE}
|
|
return {"status": True, "template": app.state.RAG_TEMPLATE}
|
|
|
|
|
|
|
|
|
|
@@ -278,6 +344,7 @@ class QueryDocForm(BaseModel):
|
|
collection_name: str
|
|
collection_name: str
|
|
query: str
|
|
query: str
|
|
k: Optional[int] = None
|
|
k: Optional[int] = None
|
|
|
|
+ r: Optional[float] = None
|
|
|
|
|
|
|
|
|
|
@app.post("/query/doc")
|
|
@app.post("/query/doc")
|
|
@@ -286,34 +353,22 @@ def query_doc_handler(
|
|
user=Depends(get_current_user),
|
|
user=Depends(get_current_user),
|
|
):
|
|
):
|
|
try:
|
|
try:
|
|
- if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
|
- query_embeddings = app.state.sentence_transformer_ef.encode(
|
|
|
|
- form_data.query
|
|
|
|
- ).tolist()
|
|
|
|
- elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
|
- query_embeddings = generate_ollama_embeddings(
|
|
|
|
- GenerateEmbeddingsForm(
|
|
|
|
- **{
|
|
|
|
- "model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- "prompt": form_data.query,
|
|
|
|
- }
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
|
- query_embeddings = generate_openai_embeddings(
|
|
|
|
- model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- text=form_data.query,
|
|
|
|
- key=app.state.OPENAI_API_KEY,
|
|
|
|
- url=app.state.OPENAI_API_BASE_URL,
|
|
|
|
- )
|
|
|
|
|
|
+ embeddings_function = query_embeddings_function(
|
|
|
|
+ app.state.RAG_EMBEDDING_ENGINE,
|
|
|
|
+ app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ app.state.sentence_transformer_ef,
|
|
|
|
+ app.state.OPENAI_API_KEY,
|
|
|
|
+ app.state.OPENAI_API_BASE_URL,
|
|
|
|
+ )
|
|
|
|
|
|
return query_embeddings_doc(
|
|
return query_embeddings_doc(
|
|
collection_name=form_data.collection_name,
|
|
collection_name=form_data.collection_name,
|
|
query=form_data.query,
|
|
query=form_data.query,
|
|
- query_embeddings=query_embeddings,
|
|
|
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
|
+ r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
|
|
|
+ embeddings_function=embeddings_function,
|
|
|
|
+ reranking_function=app.state.sentence_transformer_rf,
|
|
)
|
|
)
|
|
-
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
log.exception(e)
|
|
log.exception(e)
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
@@ -326,6 +381,7 @@ class QueryCollectionsForm(BaseModel):
|
|
collection_names: List[str]
|
|
collection_names: List[str]
|
|
query: str
|
|
query: str
|
|
k: Optional[int] = None
|
|
k: Optional[int] = None
|
|
|
|
+ r: Optional[float] = None
|
|
|
|
|
|
|
|
|
|
@app.post("/query/collection")
|
|
@app.post("/query/collection")
|
|
@@ -334,33 +390,22 @@ def query_collection_handler(
|
|
user=Depends(get_current_user),
|
|
user=Depends(get_current_user),
|
|
):
|
|
):
|
|
try:
|
|
try:
|
|
- if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
|
- query_embeddings = app.state.sentence_transformer_ef.encode(
|
|
|
|
- form_data.query
|
|
|
|
- ).tolist()
|
|
|
|
- elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
|
- query_embeddings = generate_ollama_embeddings(
|
|
|
|
- GenerateEmbeddingsForm(
|
|
|
|
- **{
|
|
|
|
- "model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- "prompt": form_data.query,
|
|
|
|
- }
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
|
- query_embeddings = generate_openai_embeddings(
|
|
|
|
- model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- text=form_data.query,
|
|
|
|
- key=app.state.OPENAI_API_KEY,
|
|
|
|
- url=app.state.OPENAI_API_BASE_URL,
|
|
|
|
- )
|
|
|
|
|
|
+ embeddings_function = query_embeddings_function(
|
|
|
|
+ app.state.RAG_EMBEDDING_ENGINE,
|
|
|
|
+ app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ app.state.sentence_transformer_ef,
|
|
|
|
+ app.state.OPENAI_API_KEY,
|
|
|
|
+ app.state.OPENAI_API_BASE_URL,
|
|
|
|
+ )
|
|
|
|
|
|
return query_embeddings_collection(
|
|
return query_embeddings_collection(
|
|
collection_names=form_data.collection_names,
|
|
collection_names=form_data.collection_names,
|
|
- query_embeddings=query_embeddings,
|
|
|
|
|
|
+ query=form_data.query,
|
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
|
+ r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
|
|
|
+ embeddings_function=embeddings_function,
|
|
|
|
+ reranking_function=app.state.sentence_transformer_rf,
|
|
)
|
|
)
|
|
-
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
log.exception(e)
|
|
log.exception(e)
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
@@ -427,8 +472,6 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
|
|
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
|
|
|
|
|
|
texts = [doc.page_content for doc in docs]
|
|
texts = [doc.page_content for doc in docs]
|
|
- texts = list(map(lambda x: x.replace("\n", " "), texts))
|
|
|
|
-
|
|
|
|
metadatas = [doc.metadata for doc in docs]
|
|
metadatas = [doc.metadata for doc in docs]
|
|
|
|
|
|
try:
|
|
try:
|
|
@@ -440,27 +483,16 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|
|
|
|
|
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
|
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
|
|
|
|
|
- if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
|
- embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
|
|
|
|
- elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
|
- embeddings = [
|
|
|
|
- generate_ollama_embeddings(
|
|
|
|
- GenerateEmbeddingsForm(
|
|
|
|
- **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- for text in texts
|
|
|
|
- ]
|
|
|
|
- elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
|
- embeddings = [
|
|
|
|
- generate_openai_embeddings(
|
|
|
|
- model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- text=text,
|
|
|
|
- key=app.state.OPENAI_API_KEY,
|
|
|
|
- url=app.state.OPENAI_API_BASE_URL,
|
|
|
|
- )
|
|
|
|
- for text in texts
|
|
|
|
- ]
|
|
|
|
|
|
+ embedding_func = query_embeddings_function(
|
|
|
|
+ app.state.RAG_EMBEDDING_ENGINE,
|
|
|
|
+ app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ app.state.sentence_transformer_ef,
|
|
|
|
+ app.state.OPENAI_API_KEY,
|
|
|
|
+ app.state.OPENAI_API_BASE_URL,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
|
|
|
|
+ embeddings = embedding_func(embedding_texts)
|
|
|
|
|
|
for batch in create_batches(
|
|
for batch in create_batches(
|
|
api=CHROMA_CLIENT,
|
|
api=CHROMA_CLIENT,
|