|
@@ -1,6 +1,5 @@
|
|
|
from fastapi import (
|
|
|
FastAPI,
|
|
|
- Request,
|
|
|
Depends,
|
|
|
HTTPException,
|
|
|
status,
|
|
@@ -14,7 +13,8 @@ import os, shutil
|
|
|
from pathlib import Path
|
|
|
from typing import List
|
|
|
|
|
|
-# from chromadb.utils import embedding_functions
|
|
|
+from sentence_transformers import SentenceTransformer
|
|
|
+from chromadb.utils import embedding_functions
|
|
|
|
|
|
from langchain_community.document_loaders import (
|
|
|
WebBaseLoader,
|
|
@@ -30,16 +30,12 @@ from langchain_community.document_loaders import (
|
|
|
UnstructuredExcelLoader,
|
|
|
)
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
-from langchain.chains import RetrievalQA
|
|
|
-from langchain_community.vectorstores import Chroma
|
|
|
-
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
from typing import Optional
|
|
|
import mimetypes
|
|
|
import uuid
|
|
|
import json
|
|
|
-import time
|
|
|
|
|
|
|
|
|
from apps.web.models.documents import (
|
|
@@ -58,23 +54,37 @@ from utils.utils import get_current_user, get_admin_user
|
|
|
from config import (
|
|
|
UPLOAD_DIR,
|
|
|
DOCS_DIR,
|
|
|
- EMBED_MODEL,
|
|
|
+ RAG_EMBEDDING_MODEL,
|
|
|
+ RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
|
|
CHROMA_CLIENT,
|
|
|
CHUNK_SIZE,
|
|
|
CHUNK_OVERLAP,
|
|
|
RAG_TEMPLATE,
|
|
|
)
|
|
|
+
|
|
|
from constants import ERROR_MESSAGES
|
|
|
|
|
|
-# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
|
-# model_name=EMBED_MODEL
|
|
|
-# )
|
|
|
+#
|
|
|
+# if RAG_EMBEDDING_MODEL:
|
|
|
+# sentence_transformer_ef = SentenceTransformer(
|
|
|
+# model_name_or_path=RAG_EMBEDDING_MODEL,
|
|
|
+# cache_folder=RAG_EMBEDDING_MODEL_DIR,
|
|
|
+# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
|
|
+# )
|
|
|
+
|
|
|
|
|
|
app = FastAPI()
|
|
|
|
|
|
app.state.CHUNK_SIZE = CHUNK_SIZE
|
|
|
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
|
|
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
|
|
+app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
|
|
+app.state.sentence_transformer_ef = (
|
|
|
+ embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
|
+ model_name=app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
|
|
+ )
|
|
|
+)
|
|
|
|
|
|
|
|
|
origins = ["*"]
|
|
@@ -106,7 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
|
|
|
metadatas = [doc.metadata for doc in docs]
|
|
|
|
|
|
try:
|
|
|
- collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
|
|
+ collection = CHROMA_CLIENT.create_collection(
|
|
|
+ name=collection_name,
|
|
|
+ embedding_function=app.state.sentence_transformer_ef,
|
|
|
+ )
|
|
|
|
|
|
collection.add(
|
|
|
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
|
|
@@ -126,6 +139,38 @@ async def get_status():
|
|
|
"status": True,
|
|
|
"chunk_size": app.state.CHUNK_SIZE,
|
|
|
"chunk_overlap": app.state.CHUNK_OVERLAP,
|
|
|
+ "template": app.state.RAG_TEMPLATE,
|
|
|
+ "embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+@app.get("/embedding/model")
|
|
|
+async def get_embedding_model(user=Depends(get_admin_user)):
|
|
|
+ return {
|
|
|
+ "status": True,
|
|
|
+ "embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+class EmbeddingModelUpdateForm(BaseModel):
|
|
|
+ embedding_model: str
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/embedding/model/update")
|
|
|
+async def update_embedding_model(
|
|
|
+ form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
|
|
+):
|
|
|
+ app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
|
+ app.state.sentence_transformer_ef = (
|
|
|
+ embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
|
+ model_name=app.state.RAG_EMBEDDING_MODEL,
|
|
|
+ device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ return {
|
|
|
+ "status": True,
|
|
|
+ "embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
|
|
}
|
|
|
|
|
|
|
|
@@ -190,8 +235,10 @@ def query_doc(
|
|
|
user=Depends(get_current_user),
|
|
|
):
|
|
|
try:
|
|
|
+ # if you use docker use the model from the environment variable
|
|
|
collection = CHROMA_CLIENT.get_collection(
|
|
|
name=form_data.collection_name,
|
|
|
+ embedding_function=app.state.sentence_transformer_ef,
|
|
|
)
|
|
|
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
|
|
|
return result
|
|
@@ -263,9 +310,12 @@ def query_collection(
|
|
|
|
|
|
for collection_name in form_data.collection_names:
|
|
|
try:
|
|
|
+ # if you use docker use the model from the environment variable
|
|
|
collection = CHROMA_CLIENT.get_collection(
|
|
|
name=collection_name,
|
|
|
+ embedding_function=app.state.sentence_transformer_ef,
|
|
|
)
|
|
|
+
|
|
|
result = collection.query(
|
|
|
query_texts=[form_data.query], n_results=form_data.k
|
|
|
)
|