@@ -271,18 +271,13 @@ def get_embedding_function(
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        if embedding_engine == "ollama":
-            func = lambda query: generate_embeddings(
-                model=embedding_model,
-                text=query,
-            )
-        elif embedding_engine == "openai":
-            func = lambda query: generate_embeddings(
-                model=embedding_model,
-                text=query,
-                key=openai_key,
-                url=openai_url,
-            )
+        func = lambda query: generate_embeddings(
+            engine=embedding_engine,
+            model=embedding_model,
+            text=query,
+            key=openai_key if embedding_engine == "openai" else "",
+            url=openai_url if embedding_engine == "openai" else "",
+        )
 
         def generate_multiple(query, func):
             if isinstance(query, list):
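
For context on the consolidated branch: it assumes a single `generate_embeddings` helper that dispatches on `engine`, with `key`/`url` passed as empty strings for Ollama. A minimal sketch of that shape (the endpoints, response handling, and defaults below are illustrative assumptions, not the project's actual implementation):

```python
# Hypothetical sketch of a generate_embeddings helper that dispatches on `engine`.
# Endpoints and response shapes are assumptions for illustration only.
import requests


def generate_embeddings(engine: str, model: str, text: str, key: str = "", url: str = ""):
    if engine == "ollama":
        # Assumption: a local Ollama server exposing POST /api/embeddings.
        resp = requests.post(
            "http://localhost:11434/api/embeddings",
            json={"model": model, "prompt": text},
        )
        return resp.json().get("embedding")
    elif engine == "openai":
        # Assumption: an OpenAI-compatible embeddings endpoint rooted at `url`.
        resp = requests.post(
            f"{url}/embeddings",
            headers={"Authorization": f"Bearer {key}"},
            json={"model": model, "input": text},
        )
        return resp.json()["data"][0]["embedding"]
```

Folding the two engine branches into one call site keeps the engine-specific details inside `generate_embeddings` and avoids duplicating the lambda per engine.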