|
@@ -272,26 +272,28 @@ def get_embedding_function(
|
|
|
return lambda query: embedding_function.encode(query).tolist()
|
|
|
elif embedding_engine in ["ollama", "openai"]:
|
|
|
if embedding_engine == "ollama":
|
|
|
- func = lambda query: generate_ollama_embeddings(
|
|
|
+ func = lambda query: generate_embeddings(
+ engine="ollama",
|
|
|
model=embedding_model,
|
|
|
- input=query,
|
|
|
+ text=query,
|
|
|
)
|
|
|
elif embedding_engine == "openai":
|
|
|
- func = lambda query: generate_openai_embeddings(
|
|
|
+ func = lambda query: generate_embeddings(
+ engine="openai",
|
|
|
model=embedding_model,
|
|
|
text=query,
|
|
|
key=openai_key,
|
|
|
url=openai_url,
|
|
|
)
|
|
|
|
|
|
- def generate_multiple(query, f):
|
|
|
+ def generate_multiple(query, func):
|
|
|
if isinstance(query, list):
|
|
|
embeddings = []
|
|
|
for i in range(0, len(query), embedding_batch_size):
|
|
|
- embeddings.extend(f(query[i : i + embedding_batch_size]))
|
|
|
+ embeddings.extend(func(query[i : i + embedding_batch_size]))
|
|
|
return embeddings
|
|
|
else:
|
|
|
- return f(query)
|
|
|
+ return func(query)
|
|
|
|
|
|
return lambda query: generate_multiple(query, func)
|
|
|
|
|
@@ -438,20 +438,6 @@ def get_model_path(model: str, update_model: bool = False):
|
|
|
return model
|
|
|
|
|
|
|
|
|
-def generate_openai_embeddings(
|
|
|
- model: str,
|
|
|
- text: Union[str, list[str]],
|
|
|
- key: str,
|
|
|
- url: str = "https://api.openai.com/v1",
|
|
|
-):
|
|
|
- if isinstance(text, list):
|
|
|
- embeddings = generate_openai_batch_embeddings(model, text, key, url)
|
|
|
- else:
|
|
|
- embeddings = generate_openai_batch_embeddings(model, [text], key, url)
|
|
|
-
|
|
|
- return embeddings[0] if isinstance(text, str) else embeddings
|
|
|
-
|
|
|
-
|
|
|
def generate_openai_batch_embeddings(
|
|
|
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
|
|
|
) -> Optional[list[list[float]]]:
|
|
@@ -475,19 +461,31 @@ def generate_openai_batch_embeddings(
|
|
|
return None
|
|
|
|
|
|
|
|
|
-def generate_ollama_embeddings(
|
|
|
- model: str, input: list[str]
|
|
|
-) -> Optional[list[list[float]]]:
|
|
|
- if isinstance(input, list):
|
|
|
- embeddings = generate_ollama_batch_embeddings(
|
|
|
- GenerateEmbedForm(**{"model": model, "input": input})
|
|
|
- )
|
|
|
- else:
|
|
|
- embeddings = generate_ollama_batch_embeddings(
|
|
|
- GenerateEmbedForm(**{"model": model, "input": [input]})
|
|
|
+def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|
|
|
+ if engine == "ollama":
|
|
|
+ if isinstance(text, list):
|
|
|
+ embeddings = generate_ollama_batch_embeddings(
|
|
|
+ GenerateEmbedForm(**{"model": model, "input": text})
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ embeddings = generate_ollama_batch_embeddings(
|
|
|
+ GenerateEmbedForm(**{"model": model, "input": [text]})
|
|
|
+ )
|
|
|
+ return (
|
|
|
+ embeddings["embeddings"][0]
|
|
|
+ if isinstance(text, str)
|
|
|
+ else embeddings["embeddings"]
|
|
|
)
|
|
|
+ elif engine == "openai":
|
|
|
+ key = kwargs.get("key", "")
|
|
|
+ url = kwargs.get("url", "https://api.openai.com/v1")
|
|
|
+
|
|
|
+ if isinstance(text, list):
|
|
|
+ embeddings = generate_openai_batch_embeddings(model, text, key, url)
|
|
|
+ else:
|
|
|
+ embeddings = generate_openai_batch_embeddings(model, [text], key, url)
|
|
|
|
|
|
- return embeddings["embeddings"]
|
|
|
+ return embeddings[0] if isinstance(text, str) else embeddings
|
|
|
|
|
|
|
|
|
import operator
|