|
@@ -53,6 +53,7 @@ from apps.rag.utils import (
|
|
|
query_collection,
|
|
|
query_embeddings_collection,
|
|
|
get_embedding_model_path,
|
|
|
+ generate_openai_embeddings,
|
|
|
)
|
|
|
|
|
|
from utils.misc import (
|
|
@@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
|
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
|
|
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
|
|
|
|
|
+app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
|
|
|
+app.state.RAG_OPENAI_API_KEY = ""
|
|
|
|
|
|
app.state.PDF_EXTRACT_IMAGES = False
|
|
|
|
|
@@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
|
|
"status": True,
|
|
|
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
|
|
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ "openai_config": {
|
|
|
+ "url": app.state.RAG_OPENAI_API_BASE_URL,
|
|
|
+ "key": app.state.RAG_OPENAI_API_KEY,
|
|
|
+ },
|
|
|
}
|
|
|
|
|
|
|
|
|
+class OpenAIConfigForm(BaseModel):
|
|
|
+ url: str
|
|
|
+ key: str
|
|
|
+
|
|
|
+
|
|
|
class EmbeddingModelUpdateForm(BaseModel):
|
|
|
+ openai_config: Optional[OpenAIConfigForm] = None
|
|
|
embedding_engine: str
|
|
|
embedding_model: str
|
|
|
|
|
@@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
|
|
|
async def update_embedding_config(
|
|
|
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
|
|
):
|
|
|
-
|
|
|
log.info(
|
|
|
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
|
|
)
|
|
|
-
|
|
|
try:
|
|
|
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
|
|
|
|
|
- if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
+ if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
|
|
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
|
app.state.sentence_transformer_ef = None
|
|
|
+
|
|
|
+ if form_data.openai_config != None:
|
|
|
+ app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
|
|
|
+ app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
|
|
|
else:
|
|
|
sentence_transformer_ef = (
|
|
|
embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
@@ -183,6 +198,10 @@ async def update_embedding_config(
|
|
|
"status": True,
|
|
|
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
|
|
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ "openai_config": {
|
|
|
+ "url": app.state.RAG_OPENAI_API_BASE_URL,
|
|
|
+ "key": app.state.RAG_OPENAI_API_KEY,
|
|
|
+ },
|
|
|
}
|
|
|
|
|
|
except Exception as e:
|
|
@@ -275,28 +294,37 @@ def query_doc_handler(
|
|
|
):
|
|
|
|
|
|
try:
|
|
|
- if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
- query_embeddings = generate_ollama_embeddings(
|
|
|
- GenerateEmbeddingsForm(
|
|
|
- **{
|
|
|
- "model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
- "prompt": form_data.query,
|
|
|
- }
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- return query_embeddings_doc(
|
|
|
+ if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
+ return query_doc(
|
|
|
collection_name=form_data.collection_name,
|
|
|
- query_embeddings=query_embeddings,
|
|
|
+ query=form_data.query,
|
|
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
+ embedding_function=app.state.sentence_transformer_ef,
|
|
|
)
|
|
|
else:
|
|
|
- return query_doc(
|
|
|
+ if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
+ query_embeddings = generate_ollama_embeddings(
|
|
|
+ GenerateEmbeddingsForm(
|
|
|
+ **{
|
|
|
+ "model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ "prompt": form_data.query,
|
|
|
+ }
|
|
|
+ )
|
|
|
+ )
|
|
|
+ elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
+ query_embeddings = generate_openai_embeddings(
|
|
|
+ model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ text=form_data.query,
|
|
|
+ key=app.state.RAG_OPENAI_API_KEY,
|
|
|
+ url=app.state.RAG_OPENAI_API_BASE_URL,
|
|
|
+ )
|
|
|
+
|
|
|
+ return query_embeddings_doc(
|
|
|
collection_name=form_data.collection_name,
|
|
|
- query=form_data.query,
|
|
|
+ query_embeddings=query_embeddings,
|
|
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
- embedding_function=app.state.sentence_transformer_ef,
|
|
|
)
|
|
|
+
|
|
|
except Exception as e:
|
|
|
log.exception(e)
|
|
|
raise HTTPException(
|
|
@@ -317,28 +345,38 @@ def query_collection_handler(
|
|
|
user=Depends(get_current_user),
|
|
|
):
|
|
|
try:
|
|
|
- if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
- query_embeddings = generate_ollama_embeddings(
|
|
|
- GenerateEmbeddingsForm(
|
|
|
- **{
|
|
|
- "model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
- "prompt": form_data.query,
|
|
|
- }
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- return query_embeddings_collection(
|
|
|
+ if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
+ return query_collection(
|
|
|
collection_names=form_data.collection_names,
|
|
|
- query_embeddings=query_embeddings,
|
|
|
+ query=form_data.query,
|
|
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
+ embedding_function=app.state.sentence_transformer_ef,
|
|
|
)
|
|
|
else:
|
|
|
- return query_collection(
|
|
|
+
|
|
|
+ if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
+ query_embeddings = generate_ollama_embeddings(
|
|
|
+ GenerateEmbeddingsForm(
|
|
|
+ **{
|
|
|
+ "model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ "prompt": form_data.query,
|
|
|
+ }
|
|
|
+ )
|
|
|
+ )
|
|
|
+ elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
+ query_embeddings = generate_openai_embeddings(
|
|
|
+ model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ text=form_data.query,
|
|
|
+ key=app.state.RAG_OPENAI_API_KEY,
|
|
|
+ url=app.state.RAG_OPENAI_API_BASE_URL,
|
|
|
+ )
|
|
|
+
|
|
|
+ return query_embeddings_collection(
|
|
|
collection_names=form_data.collection_names,
|
|
|
- query=form_data.query,
|
|
|
+ query_embeddings=query_embeddings,
|
|
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
- embedding_function=app.state.sentence_transformer_ef,
|
|
|
)
|
|
|
+
|
|
|
except Exception as e:
|
|
|
log.exception(e)
|
|
|
raise HTTPException(
|
|
@@ -383,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
|
|
|
docs = text_splitter.split_documents(data)
|
|
|
|
|
|
if len(docs) > 0:
|
|
|
- log.info("store_data_in_vector_db", "store_docs_in_vector_db")
|
|
|
+ log.info(f"store_data_in_vector_db {docs}")
|
|
|
return store_docs_in_vector_db(docs, collection_name, overwrite), None
|
|
|
else:
|
|
|
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
|
|
@@ -402,7 +440,7 @@ def store_text_in_vector_db(
|
|
|
|
|
|
|
|
|
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
|
|
|
- log.info("store_docs_in_vector_db", docs, collection_name)
|
|
|
+ log.info(f"store_docs_in_vector_db {docs} {collection_name}")
|
|
|
|
|
|
texts = [doc.page_content for doc in docs]
|
|
|
metadatas = [doc.metadata for doc in docs]
|
|
@@ -414,39 +452,54 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|
|
log.info(f"deleting existing collection {collection_name}")
|
|
|
CHROMA_CLIENT.delete_collection(name=collection_name)
|
|
|
|
|
|
- if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
- collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
|
|
+ if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
+
|
|
|
+ collection = CHROMA_CLIENT.create_collection(
|
|
|
+ name=collection_name,
|
|
|
+ embedding_function=app.state.sentence_transformer_ef,
|
|
|
+ )
|
|
|
|
|
|
for batch in create_batches(
|
|
|
api=CHROMA_CLIENT,
|
|
|
ids=[str(uuid.uuid1()) for _ in texts],
|
|
|
metadatas=metadatas,
|
|
|
- embeddings=[
|
|
|
+ documents=texts,
|
|
|
+ ):
|
|
|
+ collection.add(*batch)
|
|
|
+
|
|
|
+ else:
|
|
|
+ collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
|
|
+
|
|
|
+ if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
+ embeddings = [
|
|
|
generate_ollama_embeddings(
|
|
|
GenerateEmbeddingsForm(
|
|
|
- **{"model": RAG_EMBEDDING_MODEL, "prompt": text}
|
|
|
+ **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
|
|
|
)
|
|
|
)
|
|
|
for text in texts
|
|
|
- ],
|
|
|
- ):
|
|
|
- collection.add(*batch)
|
|
|
- else:
|
|
|
-
|
|
|
- collection = CHROMA_CLIENT.create_collection(
|
|
|
- name=collection_name,
|
|
|
- embedding_function=app.state.sentence_transformer_ef,
|
|
|
- )
|
|
|
+ ]
|
|
|
+ elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
+ embeddings = [
|
|
|
+ generate_openai_embeddings(
|
|
|
+ model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ text=text,
|
|
|
+ key=app.state.RAG_OPENAI_API_KEY,
|
|
|
+ url=app.state.RAG_OPENAI_API_BASE_URL,
|
|
|
+ )
|
|
|
+ for text in texts
|
|
|
+ ]
|
|
|
|
|
|
for batch in create_batches(
|
|
|
api=CHROMA_CLIENT,
|
|
|
ids=[str(uuid.uuid1()) for _ in texts],
|
|
|
metadatas=metadatas,
|
|
|
+ embeddings=embeddings,
|
|
|
documents=texts,
|
|
|
):
|
|
|
collection.add(*batch)
|
|
|
|
|
|
- return True
|
|
|
+ return True
|
|
|
except Exception as e:
|
|
|
log.exception(e)
|
|
|
if e.__class__.__name__ == "UniqueConstraintError":
|