|
@@ -78,6 +78,7 @@ from utils.misc import (
|
|
|
from utils.utils import get_current_user, get_admin_user
|
|
|
|
|
|
from config import (
|
|
|
+ AppConfig,
|
|
|
ENV,
|
|
|
SRC_LOG_LEVELS,
|
|
|
UPLOAD_DIR,
|
|
@@ -114,7 +115,7 @@ from config import (
|
|
|
SERPER_API_KEY,
|
|
|
RAG_WEB_SEARCH_RESULT_COUNT,
|
|
|
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
|
|
- AppConfig,
|
|
|
+ RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
|
|
)
|
|
|
|
|
|
from constants import ERROR_MESSAGES
|
|
@@ -139,6 +140,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
|
|
|
|
|
|
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
|
|
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
|
|
+app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
|
|
|
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
|
|
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
|
|
|
|
@@ -212,6 +214,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
|
|
app.state.sentence_transformer_ef,
|
|
|
app.state.config.OPENAI_API_KEY,
|
|
|
app.state.config.OPENAI_API_BASE_URL,
|
|
|
+ app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
|
|
)
|
|
|
|
|
|
origins = ["*"]
|
|
@@ -248,6 +251,7 @@ async def get_status():
|
|
|
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
|
|
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
|
|
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
|
|
|
+ "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
|
|
}
|
|
|
|
|
|
|
|
@@ -260,6 +264,7 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
|
|
"openai_config": {
|
|
|
"url": app.state.config.OPENAI_API_BASE_URL,
|
|
|
"key": app.state.config.OPENAI_API_KEY,
|
|
|
+ "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
|
|
},
|
|
|
}
|
|
|
|
|
@@ -275,6 +280,7 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
|
|
|
class OpenAIConfigForm(BaseModel):
|
|
|
url: str
|
|
|
key: str
|
|
|
+ batch_size: Optional[int] = None
|
|
|
|
|
|
|
|
|
class EmbeddingModelUpdateForm(BaseModel):
|
|
@@ -295,9 +301,14 @@ async def update_embedding_config(
|
|
|
app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
|
|
|
|
if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
|
|
- if form_data.openai_config != None:
|
|
|
+ if form_data.openai_config is not None:
|
|
|
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
|
|
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
|
|
+ app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
|
|
|
+ form_data.openai_config.batch_size
|
|
|
+ if form_data.openai_config.batch_size
|
|
|
+ else 1
|
|
|
+ )
|
|
|
|
|
|
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
|
|
|
|
@@ -307,6 +318,7 @@ async def update_embedding_config(
|
|
|
app.state.sentence_transformer_ef,
|
|
|
app.state.config.OPENAI_API_KEY,
|
|
|
app.state.config.OPENAI_API_BASE_URL,
|
|
|
+ app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
|
|
)
|
|
|
|
|
|
return {
|
|
@@ -316,6 +328,7 @@ async def update_embedding_config(
|
|
|
"openai_config": {
|
|
|
"url": app.state.config.OPENAI_API_BASE_URL,
|
|
|
"key": app.state.config.OPENAI_API_KEY,
|
|
|
+ "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
|
|
},
|
|
|
}
|
|
|
except Exception as e:
|
|
@@ -881,6 +894,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|
|
app.state.sentence_transformer_ef,
|
|
|
app.state.config.OPENAI_API_KEY,
|
|
|
app.state.config.OPENAI_API_BASE_URL,
|
|
|
+ app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
|
|
)
|
|
|
|
|
|
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
|