|
@@ -1,6 +1,5 @@
|
|
|
from fastapi import (
|
|
|
FastAPI,
|
|
|
- Request,
|
|
|
Depends,
|
|
|
HTTPException,
|
|
|
status,
|
|
@@ -12,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
|
import os, shutil
|
|
|
from typing import List
|
|
|
|
|
|
-# from chromadb.utils import embedding_functions
|
|
|
+from chromadb.utils import embedding_functions
|
|
|
|
|
|
from langchain_community.document_loaders import (
|
|
|
WebBaseLoader,
|
|
@@ -28,24 +27,19 @@ from langchain_community.document_loaders import (
|
|
|
UnstructuredExcelLoader,
|
|
|
)
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
-from langchain_community.vectorstores import Chroma
|
|
|
-from langchain.chains import RetrievalQA
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
from typing import Optional
|
|
|
|
|
|
import uuid
|
|
|
-import time
|
|
|
|
|
|
from utils.misc import calculate_sha256, calculate_sha256_string
|
|
|
from utils.utils import get_current_user, get_admin_user
|
|
|
-from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
|
|
+from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
|
|
from constants import ERROR_MESSAGES
|
|
|
|
|
|
-# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
|
-# model_name=EMBED_MODEL
|
|
|
-# )
|
|
|
+sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
|
|
|
|
|
|
app = FastAPI()
|
|
|
|
|
@@ -78,11 +72,17 @@ def store_data_in_vector_db(data, collection_name) -> bool:
|
|
|
metadatas = [doc.metadata for doc in docs]
|
|
|
|
|
|
try:
|
|
|
- collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
|
|
+ if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
|
|
|
+ # if you use docker use the model from the environment variable
|
|
|
+ collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef)
|
|
|
+
|
|
|
+ else:
|
|
|
+ # for local development use the default model
|
|
|
+ collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
|
|
|
|
|
collection.add(
|
|
|
- documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
|
|
|
- )
|
|
|
+ documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
|
|
|
+ )
|
|
|
return True
|
|
|
except Exception as e:
|
|
|
print(e)
|
|
@@ -109,9 +109,17 @@ def query_doc(
|
|
|
user=Depends(get_current_user),
|
|
|
):
|
|
|
try:
|
|
|
- collection = CHROMA_CLIENT.get_collection(
|
|
|
- name=form_data.collection_name,
|
|
|
- )
|
|
|
+ if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
|
|
|
+ # if you use docker use the model from the environment variable
|
|
|
+ collection = CHROMA_CLIENT.get_collection(
|
|
|
+ name=form_data.collection_name,
|
|
|
+ embedding_function=sentence_transformer_ef
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # for local development use the default model
|
|
|
+ collection = CHROMA_CLIENT.get_collection(
|
|
|
+ name=form_data.collection_name,
|
|
|
+ )
|
|
|
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
|
|
|
return result
|
|
|
except Exception as e:
|
|
@@ -182,9 +190,18 @@ def query_collection(
|
|
|
|
|
|
for collection_name in form_data.collection_names:
|
|
|
try:
|
|
|
- collection = CHROMA_CLIENT.get_collection(
|
|
|
- name=collection_name,
|
|
|
+ if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
|
|
|
+ # if you use docker use the model from the environment variable
|
|
|
+ collection = CHROMA_CLIENT.get_collection(
|
|
|
+ name=form_data.collection_name,
|
|
|
+ embedding_function=sentence_transformer_ef
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # for local development use the default model
|
|
|
+ collection = CHROMA_CLIENT.get_collection(
|
|
|
+ name=form_data.collection_name,
|
|
|
)
|
|
|
+
|
|
|
result = collection.query(
|
|
|
query_texts=[form_data.query], n_results=form_data.k
|
|
|
)
|