|
@@ -13,7 +13,6 @@ import os, shutil, logging, re
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
from typing import List
|
|
from typing import List
|
|
|
|
|
|
-from chromadb.utils import embedding_functions
|
|
|
|
from chromadb.utils.batch_utils import create_batches
|
|
from chromadb.utils.batch_utils import create_batches
|
|
|
|
|
|
from langchain_community.document_loaders import (
|
|
from langchain_community.document_loaders import (
|
|
@@ -38,6 +37,7 @@ import mimetypes
|
|
import uuid
|
|
import uuid
|
|
import json
|
|
import json
|
|
|
|
|
|
|
|
+import sentence_transformers
|
|
|
|
|
|
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
|
|
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
|
|
|
|
|
|
@@ -48,11 +48,8 @@ from apps.web.models.documents import (
|
|
)
|
|
)
|
|
|
|
|
|
from apps.rag.utils import (
|
|
from apps.rag.utils import (
|
|
- query_doc,
|
|
|
|
query_embeddings_doc,
|
|
query_embeddings_doc,
|
|
- query_collection,
|
|
|
|
query_embeddings_collection,
|
|
query_embeddings_collection,
|
|
- get_embedding_model_path,
|
|
|
|
generate_openai_embeddings,
|
|
generate_openai_embeddings,
|
|
)
|
|
)
|
|
|
|
|
|
@@ -69,7 +66,7 @@ from config import (
|
|
DOCS_DIR,
|
|
DOCS_DIR,
|
|
RAG_EMBEDDING_ENGINE,
|
|
RAG_EMBEDDING_ENGINE,
|
|
RAG_EMBEDDING_MODEL,
|
|
RAG_EMBEDDING_MODEL,
|
|
- RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
|
|
|
|
|
+ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
RAG_OPENAI_API_BASE_URL,
|
|
RAG_OPENAI_API_BASE_URL,
|
|
RAG_OPENAI_API_KEY,
|
|
RAG_OPENAI_API_KEY,
|
|
DEVICE_TYPE,
|
|
DEVICE_TYPE,
|
|
@@ -101,15 +98,12 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
|
|
|
|
|
app.state.PDF_EXTRACT_IMAGES = False
|
|
app.state.PDF_EXTRACT_IMAGES = False
|
|
|
|
|
|
-
|
|
|
|
-app.state.sentence_transformer_ef = (
|
|
|
|
- embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
|
|
- model_name=get_embedding_model_path(
|
|
|
|
- app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
|
|
|
|
- ),
|
|
|
|
|
|
+if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
|
+ app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
|
|
|
+ app.state.RAG_EMBEDDING_MODEL,
|
|
device=DEVICE_TYPE,
|
|
device=DEVICE_TYPE,
|
|
|
|
+ trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
)
|
|
)
|
|
-)
|
|
|
|
|
|
|
|
|
|
|
|
origins = ["*"]
|
|
origins = ["*"]
|
|
@@ -185,13 +179,10 @@ async def update_embedding_config(
|
|
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
|
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
|
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
|
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
|
else:
|
|
else:
|
|
- sentence_transformer_ef = (
|
|
|
|
- embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
|
|
- model_name=get_embedding_model_path(
|
|
|
|
- form_data.embedding_model, True
|
|
|
|
- ),
|
|
|
|
- device=DEVICE_TYPE,
|
|
|
|
- )
|
|
|
|
|
|
+ sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
|
|
|
+ app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ device=DEVICE_TYPE,
|
|
|
|
+ trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
)
|
|
)
|
|
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
app.state.sentence_transformer_ef = sentence_transformer_ef
|
|
app.state.sentence_transformer_ef = sentence_transformer_ef
|
|
@@ -294,39 +285,35 @@ def query_doc_handler(
|
|
form_data: QueryDocForm,
|
|
form_data: QueryDocForm,
|
|
user=Depends(get_current_user),
|
|
user=Depends(get_current_user),
|
|
):
|
|
):
|
|
-
|
|
|
|
try:
|
|
try:
|
|
if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
- return query_doc(
|
|
|
|
- collection_name=form_data.collection_name,
|
|
|
|
- query=form_data.query,
|
|
|
|
- k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
|
- embedding_function=app.state.sentence_transformer_ef,
|
|
|
|
- )
|
|
|
|
- else:
|
|
|
|
- if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
|
- query_embeddings = generate_ollama_embeddings(
|
|
|
|
- GenerateEmbeddingsForm(
|
|
|
|
- **{
|
|
|
|
- "model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- "prompt": form_data.query,
|
|
|
|
- }
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
|
- query_embeddings = generate_openai_embeddings(
|
|
|
|
- model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- text=form_data.query,
|
|
|
|
- key=app.state.OPENAI_API_KEY,
|
|
|
|
- url=app.state.OPENAI_API_BASE_URL,
|
|
|
|
|
|
+ query_embeddings = app.state.sentence_transformer_ef.encode(
|
|
|
|
+ form_data.query
|
|
|
|
+ ).tolist()
|
|
|
|
+ elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
|
+ query_embeddings = generate_ollama_embeddings(
|
|
|
|
+ GenerateEmbeddingsForm(
|
|
|
|
+ **{
|
|
|
|
+ "model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ "prompt": form_data.query,
|
|
|
|
+ }
|
|
)
|
|
)
|
|
-
|
|
|
|
- return query_embeddings_doc(
|
|
|
|
- collection_name=form_data.collection_name,
|
|
|
|
- query_embeddings=query_embeddings,
|
|
|
|
- k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
|
|
|
+ )
|
|
|
|
+ elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
|
+ query_embeddings = generate_openai_embeddings(
|
|
|
|
+ model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ text=form_data.query,
|
|
|
|
+ key=app.state.OPENAI_API_KEY,
|
|
|
|
+ url=app.state.OPENAI_API_BASE_URL,
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ return query_embeddings_doc(
|
|
|
|
+ collection_name=form_data.collection_name,
|
|
|
|
+ query=form_data.query,
|
|
|
|
+ query_embeddings=query_embeddings,
|
|
|
|
+ k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
|
+ )
|
|
|
|
+
|
|
except Exception as e:
|
|
except Exception as e:
|
|
log.exception(e)
|
|
log.exception(e)
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
@@ -348,37 +335,32 @@ def query_collection_handler(
|
|
):
|
|
):
|
|
try:
|
|
try:
|
|
if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
- return query_collection(
|
|
|
|
- collection_names=form_data.collection_names,
|
|
|
|
- query=form_data.query,
|
|
|
|
- k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
|
- embedding_function=app.state.sentence_transformer_ef,
|
|
|
|
- )
|
|
|
|
- else:
|
|
|
|
-
|
|
|
|
- if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
|
- query_embeddings = generate_ollama_embeddings(
|
|
|
|
- GenerateEmbeddingsForm(
|
|
|
|
- **{
|
|
|
|
- "model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- "prompt": form_data.query,
|
|
|
|
- }
|
|
|
|
- )
|
|
|
|
|
|
+ query_embeddings = app.state.sentence_transformer_ef.encode(
|
|
|
|
+ form_data.query
|
|
|
|
+ ).tolist()
|
|
|
|
+ elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
|
+ query_embeddings = generate_ollama_embeddings(
|
|
|
|
+ GenerateEmbeddingsForm(
|
|
|
|
+ **{
|
|
|
|
+ "model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ "prompt": form_data.query,
|
|
|
|
+ }
|
|
)
|
|
)
|
|
- elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
|
- query_embeddings = generate_openai_embeddings(
|
|
|
|
- model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- text=form_data.query,
|
|
|
|
- key=app.state.OPENAI_API_KEY,
|
|
|
|
- url=app.state.OPENAI_API_BASE_URL,
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- return query_embeddings_collection(
|
|
|
|
- collection_names=form_data.collection_names,
|
|
|
|
- query_embeddings=query_embeddings,
|
|
|
|
- k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
|
|
|
+ )
|
|
|
|
+ elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
|
+ query_embeddings = generate_openai_embeddings(
|
|
|
|
+ model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ text=form_data.query,
|
|
|
|
+ key=app.state.OPENAI_API_KEY,
|
|
|
|
+ url=app.state.OPENAI_API_BASE_URL,
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ return query_embeddings_collection(
|
|
|
|
+ collection_names=form_data.collection_names,
|
|
|
|
+ query_embeddings=query_embeddings,
|
|
|
|
+ k=form_data.k if form_data.k else app.state.TOP_K,
|
|
|
|
+ )
|
|
|
|
+
|
|
except Exception as e:
|
|
except Exception as e:
|
|
log.exception(e)
|
|
log.exception(e)
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
@@ -445,6 +427,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
|
|
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
|
|
|
|
|
|
texts = [doc.page_content for doc in docs]
|
|
texts = [doc.page_content for doc in docs]
|
|
|
|
+ texts = list(map(lambda x: x.replace("\n", " "), texts))
|
|
|
|
+
|
|
metadatas = [doc.metadata for doc in docs]
|
|
metadatas = [doc.metadata for doc in docs]
|
|
|
|
|
|
try:
|
|
try:
|
|
@@ -454,52 +438,38 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|
log.info(f"deleting existing collection {collection_name}")
|
|
log.info(f"deleting existing collection {collection_name}")
|
|
CHROMA_CLIENT.delete_collection(name=collection_name)
|
|
CHROMA_CLIENT.delete_collection(name=collection_name)
|
|
|
|
|
|
- if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
|
-
|
|
|
|
- collection = CHROMA_CLIENT.create_collection(
|
|
|
|
- name=collection_name,
|
|
|
|
- embedding_function=app.state.sentence_transformer_ef,
|
|
|
|
- )
|
|
|
|
|
|
+ collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
|
|
|
|
|
- for batch in create_batches(
|
|
|
|
- api=CHROMA_CLIENT,
|
|
|
|
- ids=[str(uuid.uuid1()) for _ in texts],
|
|
|
|
- metadatas=metadatas,
|
|
|
|
- documents=texts,
|
|
|
|
- ):
|
|
|
|
- collection.add(*batch)
|
|
|
|
-
|
|
|
|
- else:
|
|
|
|
- collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
|
|
|
-
|
|
|
|
- if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
|
- embeddings = [
|
|
|
|
- generate_ollama_embeddings(
|
|
|
|
- GenerateEmbeddingsForm(
|
|
|
|
- **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- for text in texts
|
|
|
|
- ]
|
|
|
|
- elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
|
- embeddings = [
|
|
|
|
- generate_openai_embeddings(
|
|
|
|
- model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
- text=text,
|
|
|
|
- key=app.state.OPENAI_API_KEY,
|
|
|
|
- url=app.state.OPENAI_API_BASE_URL,
|
|
|
|
|
|
+ if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
|
|
+ embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
|
|
|
|
+ elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
|
|
|
+ embeddings = [
|
|
|
|
+ generate_ollama_embeddings(
|
|
|
|
+ GenerateEmbeddingsForm(
|
|
|
|
+ **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
|
|
)
|
|
)
|
|
- for text in texts
|
|
|
|
- ]
|
|
|
|
-
|
|
|
|
- for batch in create_batches(
|
|
|
|
- api=CHROMA_CLIENT,
|
|
|
|
- ids=[str(uuid.uuid1()) for _ in texts],
|
|
|
|
- metadatas=metadatas,
|
|
|
|
- embeddings=embeddings,
|
|
|
|
- documents=texts,
|
|
|
|
- ):
|
|
|
|
- collection.add(*batch)
|
|
|
|
|
|
+ )
|
|
|
|
+ for text in texts
|
|
|
|
+ ]
|
|
|
|
+ elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
|
|
|
+ embeddings = [
|
|
|
|
+ generate_openai_embeddings(
|
|
|
|
+ model=app.state.RAG_EMBEDDING_MODEL,
|
|
|
|
+ text=text,
|
|
|
|
+ key=app.state.OPENAI_API_KEY,
|
|
|
|
+ url=app.state.OPENAI_API_BASE_URL,
|
|
|
|
+ )
|
|
|
|
+ for text in texts
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ for batch in create_batches(
|
|
|
|
+ api=CHROMA_CLIENT,
|
|
|
|
+ ids=[str(uuid.uuid1()) for _ in texts],
|
|
|
|
+ metadatas=metadatas,
|
|
|
|
+ embeddings=embeddings,
|
|
|
|
+ documents=texts,
|
|
|
|
+ ):
|
|
|
|
+ collection.add(*batch)
|
|
|
|
|
|
return True
|
|
return True
|
|
except Exception as e:
|
|
except Exception as e:
|