@@ -6,9 +6,12 @@ import requests
 
 
 from huggingface_hub import snapshot_download
+from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
+
 
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
+
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
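
For reference, the new import wires the Ollama embedding helpers into the RAG utilities. A minimal sketch of how the two presumably compose, where the model name and prompt are placeholders and only the form fields (`model`, `prompt`) come from this diff:

```python
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm

# Sketch only: "nomic-embed-text" is an illustrative model name.
query_embeddings = generate_ollama_embeddings(
    GenerateEmbeddingsForm(model="nomic-embed-text", prompt="What is RAG?")
)
```
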
@@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
 def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
-        log.info("query_embeddings_doc", query_embeddings)
+        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(
             name=collection_name,
         )
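
The logging change above fixes a real bug, not just style: stdlib `logging` treats extra positional arguments as %-format args, and since the message has no placeholder the call fails at emit time with a "--- Logging error ---" traceback instead of logging the embeddings. A self-contained repro, including the lazy %-style alternative to the f-string used in the diff:

```python
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

emb = [0.1, 0.2, 0.3]

log.info("query_embeddings_doc", emb)     # broken: no %s placeholder for the extra arg
log.info(f"query_embeddings_doc {emb}")   # the fix applied in this diff
log.info("query_embeddings_doc %s", emb)  # lazy alternative; defers formatting
```

The same one-line fix is applied to `query_embeddings_collection` in the next hunk.
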
@@ -118,7 +121,7 @@ def query_collection(
 def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
 
     results = []
-    log.info("query_embeddings_collection", query_embeddings)
+    log.info(f"query_embeddings_collection {query_embeddings}")
 
     for collection_name in collection_names:
         try:
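
Both embedding-based query helpers presumably bottom out in the same Chroma call, passing the precomputed vector instead of letting Chroma embed the query text. A hedged sketch of that step, with the collection name illustrative:

```python
# Sketch: query_embeddings takes a list of query vectors; n_results caps the hits.
collection = CHROMA_CLIENT.get_collection(name="docs-1")
result = collection.query(
    query_embeddings=[query_embeddings],  # the single vector computed upstream
    n_results=k,
)
```
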
@@ -141,7 +144,17 @@ def rag_template(template: str, context: str, query: str):
     return template
 
 
-def rag_messages(docs, messages, template, k, embedding_function):
+def rag_messages(
+    docs,
+    messages,
+    template,
+    k,
+    embedding_engine,
+    embedding_model,
+    embedding_function,
+    openai_key,
+    openai_url,
+):
     log.debug(f"docs: {docs}")
 
     last_user_message_idx = None
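
The widened `rag_messages` signature threads the embedding configuration from app settings down to query time. A hedged sketch of a call site; the parameter names come from the diff, every value is a placeholder:

```python
# All values illustrative. embedding_engine selects the path in the next hunk:
# "" uses the in-process embedding_function, while "ollama"/"openai" compute
# the query embedding remotely and hit the query_embeddings_* helpers.
messages = rag_messages(
    docs=[{"type": "collection", "collection_names": ["docs-1"]}],
    messages=chat_messages,
    template=rag_template_str,
    k=4,
    embedding_engine="ollama",
    embedding_model="nomic-embed-text",
    embedding_function=sentence_transformer_ef,  # only used when engine == ""
    openai_key="",
    openai_url="",
)
```
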
@@ -175,22 +188,57 @@ def rag_messages(docs, messages, template, k, embedding_function):
         context = None
 
         try:
-            if doc["type"] == "collection":
-                context = query_collection(
-                    collection_names=doc["collection_names"],
-                    query=query,
-                    k=k,
-                    embedding_function=embedding_function,
-                )
-            elif doc["type"] == "text":
+
+            if doc["type"] == "text":
                 context = doc["content"]
             else:
-                context = query_doc(
-                    collection_name=doc["collection_name"],
-                    query=query,
-                    k=k,
-                    embedding_function=embedding_function,
-                )
+                if embedding_engine == "":
+                    if doc["type"] == "collection":
+                        context = query_collection(
+                            collection_names=doc["collection_names"],
+                            query=query,
+                            k=k,
+                            embedding_function=embedding_function,
+                        )
+                    else:
+                        context = query_doc(
+                            collection_name=doc["collection_name"],
+                            query=query,
+                            k=k,
+                            embedding_function=embedding_function,
+                        )
+
+                else:
+                    if embedding_engine == "ollama":
+                        query_embeddings = generate_ollama_embeddings(
+                            GenerateEmbeddingsForm(
+                                **{
+                                    "model": embedding_model,
+                                    "prompt": query,
+                                }
+                            )
+                        )
+                    elif embedding_engine == "openai":
+                        query_embeddings = generate_openai_embeddings(
+                            model=embedding_model,
+                            text=query,
+                            key=openai_key,
+                            url=openai_url,
+                        )
+
+                    if doc["type"] == "collection":
+                        context = query_embeddings_collection(
+                            collection_names=doc["collection_names"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
+                    else:
+                        context = query_embeddings_doc(
+                            collection_name=doc["collection_name"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
+
         except Exception as e:
             log.exception(e)
             context = None
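
One thing to flag: `generate_openai_embeddings` is called above but not imported in the first hunk, so it is presumably defined elsewhere in this file, outside the hunks shown. A rough sketch of a helper matching that call site, assuming the standard OpenAI embeddings endpoint and reusing the module's existing `requests` and `log`; everything beyond the signature is an assumption:

```python
def generate_openai_embeddings(model: str, text: str, key: str, url: str):
    # Assumed shape: POST {url}/embeddings with a bearer token, returning
    # the first embedding vector, or None on failure.
    try:
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": text, "model": model},
        )
        r.raise_for_status()
        return r.json()["data"][0]["embedding"]
    except Exception as e:
        log.exception(e)
        return None
```

Note also that if `embedding_engine` held any value other than `""`, `"ollama"`, or `"openai"`, `query_embeddings` would be unbound when the query helpers run; the surrounding `except` would then swallow the resulting `NameError` and leave `context` as `None`.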