|
@@ -29,11 +29,13 @@ from langchain_community.document_loaders import (
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
|
|
|
|
|
+
|
|
|
from pydantic import BaseModel
|
|
|
from typing import Optional
|
|
|
|
|
|
import uuid
|
|
|
|
|
|
+
|
|
|
from utils.misc import calculate_sha256, calculate_sha256_string
|
|
|
from utils.utils import get_current_user, get_admin_user
|
|
|
from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
|
@@ -113,12 +115,12 @@ def query_doc(
|
|
|
# if you use docker use the model from the environment variable
|
|
|
collection = CHROMA_CLIENT.get_collection(
|
|
|
name=form_data.collection_name,
|
|
|
- embedding_function=sentence_transformer_ef
|
|
|
+ embedding_function=sentence_transformer_ef,
|
|
|
)
|
|
|
else:
|
|
|
# for local development use the default model
|
|
|
- collection = CHROMA_CLIENT.get_collection(
|
|
|
- name=form_data.collection_name,
|
|
|
+ collection = CHROMA_CLIENT.get_collection(
|
|
|
+ name=form_data.collection_name,
|
|
|
)
|
|
|
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
|
|
|
return result
|
|
@@ -191,16 +193,16 @@ def query_collection(
|
|
|
for collection_name in form_data.collection_names:
|
|
|
try:
|
|
|
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
|
|
|
- # if you use docker use the model from the environment variable
|
|
|
+ # if you use docker use the model from the environment variable
|
|
|
collection = CHROMA_CLIENT.get_collection(
|
|
|
- name=form_data.collection_name,
|
|
|
- embedding_function=sentence_transformer_ef
|
|
|
+ name=collection_name,
|
|
|
+ embedding_function=sentence_transformer_ef,
|
|
|
)
|
|
|
else:
|
|
|
# for local development use the default model
|
|
|
collection = CHROMA_CLIENT.get_collection(
|
|
|
- name=form_data.collection_name,
|
|
|
- )
|
|
|
+ name=collection_name,
|
|
|
+ )
|
|
|
|
|
|
result = collection.query(
|
|
|
query_texts=[form_data.query], n_results=form_data.k
|