|
@@ -191,7 +191,9 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
|
|
|
|
|
return messages
|
|
return messages
|
|
|
|
|
|
-def embedding_model_get_path(embedding_model: str, update_embedding_model: bool = False):
|
|
|
|
|
|
+def embedding_model_get_path(
|
|
|
|
+ embedding_model: str, update_embedding_model: bool = False
|
|
|
|
+):
|
|
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
|
|
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
|
|
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
|
|
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
|
|
local_files_only = not update_embedding_model
|
|
local_files_only = not update_embedding_model
|
|
@@ -206,7 +208,11 @@ def embedding_model_get_path(embedding_model: str, update_embedding_model: bool
|
|
log.debug(f"local_files_only: {local_files_only}")
|
|
log.debug(f"local_files_only: {local_files_only}")
|
|
|
|
|
|
# Inspiration from upstream sentence_transformers
|
|
# Inspiration from upstream sentence_transformers
|
|
- if (os.path.exists(embedding_model) or ("\\" in embedding_model or embedding_model.count("/") > 1) and local_files_only):
|
|
|
|
|
|
+ if (
|
|
|
|
+ os.path.exists(embedding_model)
|
|
|
|
+ or ("\\" in embedding_model or embedding_model.count("/") > 1)
|
|
|
|
+ and local_files_only
|
|
|
|
+ ):
|
|
# If fully qualified path exists, return input, else set repo_id
|
|
# If fully qualified path exists, return input, else set repo_id
|
|
return embedding_model
|
|
return embedding_model
|
|
elif "/" not in embedding_model:
|
|
elif "/" not in embedding_model:
|