Jannik Streidl преди 1 година
родител
ревизия
bc3dd34d8b
променени са 1 файла, в които са добавени 10 реда и са изтрити 8 реда
  1. 10 8
      backend/apps/rag/main.py

+ 10 - 8
backend/apps/rag/main.py

@@ -29,11 +29,13 @@ from langchain_community.document_loaders import (
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 
+
 from pydantic import BaseModel
 from typing import Optional
 
 import uuid
 
+
 from utils.misc import calculate_sha256, calculate_sha256_string
 from utils.utils import get_current_user, get_admin_user
 from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
@@ -113,12 +115,12 @@ def query_doc(
         # if you use docker use the model from the environment variable
             collection = CHROMA_CLIENT.get_collection(
                 name=form_data.collection_name,
-                embedding_function=sentence_transformer_ef
+                embedding_function=sentence_transformer_ef,
             )
         else:
         # for local development use the default model
-                collection = CHROMA_CLIENT.get_collection(
-                name=form_data.collection_name,
+            collection = CHROMA_CLIENT.get_collection(
+            name=form_data.collection_name,
             )
         result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
         return result
@@ -191,16 +193,16 @@ def query_collection(
     for collection_name in form_data.collection_names:
         try:
             if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
-             # if you use docker use the model from the environment variable
+            # if you use docker use the model from the environment variable
                 collection = CHROMA_CLIENT.get_collection(
-                    name=form_data.collection_name,
-                    embedding_function=sentence_transformer_ef
+                    name=collection_name,
+                    embedding_function=sentence_transformer_ef,
                 )
             else:
             # for local development use the default model
                 collection = CHROMA_CLIENT.get_collection(
-                name=form_data.collection_name,
-            )
+                name=collection_name,
+                )
                 
             result = collection.query(
                 query_texts=[form_data.query], n_results=form_data.k