소스 검색

Merge pull request #1858 from buroa/buroa/fixes

fix: various rag api calls and ui cleanup
Timothy Jaeryang Baek 1 년 전
부모
커밋
f3199c6510

+ 6 - 6
backend/apps/rag/main.py

@@ -391,16 +391,16 @@ def query_doc_handler(
             return query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
-                embeddings_function=app.state.EMBEDDING_FUNCTION,
-                reranking_function=app.state.sentence_transformer_rf,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else app.state.TOP_K,
+                reranking_function=app.state.sentence_transformer_rf,
                 r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             )
         else:
             return query_doc(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
-                embeddings_function=app.state.EMBEDDING_FUNCTION,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else app.state.TOP_K,
             )
     except Exception as e:
@@ -429,16 +429,16 @@ def query_collection_handler(
             return query_collection_with_hybrid_search(
                 collection_names=form_data.collection_names,
                 query=form_data.query,
-                embeddings_function=app.state.EMBEDDING_FUNCTION,
-                reranking_function=app.state.sentence_transformer_rf,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else app.state.TOP_K,
+                reranking_function=app.state.sentence_transformer_rf,
                 r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             )
         else:
             return query_collection(
                 collection_names=form_data.collection_names,
                 query=form_data.query,
-                embeddings_function=app.state.EMBEDDING_FUNCTION,
+                embedding_function=app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else app.state.TOP_K,
             )
 

+ 16 - 14
backend/apps/rag/utils.py

@@ -35,6 +35,7 @@ def query_doc(
     try:
         collection = CHROMA_CLIENT.get_collection(name=collection_name)
         query_embeddings = embedding_function(query)
+
         result = collection.query(
             query_embeddings=[query_embeddings],
             n_results=k,
@@ -76,9 +77,9 @@ def query_doc_with_hybrid_search(
 
         compressor = RerankCompressor(
             embedding_function=embedding_function,
+            top_n=k,
             reranking_function=reranking_function,
             r_score=r,
-            top_n=k,
         )
 
         compression_retriever = ContextualCompressionRetriever(
@@ -91,6 +92,7 @@ def query_doc_with_hybrid_search(
             "documents": [[d.page_content for d in result]],
             "metadatas": [[d.metadata for d in result]],
         }
+
         log.info(f"query_doc_with_hybrid_search:result {result}")
         return result
     except Exception as e:
@@ -167,7 +169,6 @@ def query_collection_with_hybrid_search(
     reranking_function,
     r: float,
 ):
-
     results = []
     for collection_name in collection_names:
         try:
@@ -182,7 +183,6 @@ def query_collection_with_hybrid_search(
             results.append(result)
         except:
             pass
-
     return merge_and_sort_query_results(results, k=k, reverse=True)
 
 
@@ -443,13 +443,15 @@ class ChromaRetriever(BaseRetriever):
         metadatas = results["metadatas"][0]
         documents = results["documents"][0]
 
-        return [
-            Document(
-                metadata=metadatas[idx],
-                page_content=documents[idx],
+        results = []
+        for idx in range(len(ids)):
+            results.append(
+                Document(
+                    metadata=metadatas[idx],
+                    page_content=documents[idx],
+                )
             )
-            for idx in range(len(ids))
-        ]
+        return results
 
 
 import operator
@@ -465,9 +467,9 @@ from sentence_transformers import util
 
 class RerankCompressor(BaseDocumentCompressor):
     embedding_function: Any
+    top_n: int
     reranking_function: Any
     r_score: float
-    top_n: int
 
     class Config:
         extra = Extra.forbid
@@ -479,7 +481,9 @@ class RerankCompressor(BaseDocumentCompressor):
         query: str,
         callbacks: Optional[Callbacks] = None,
     ) -> Sequence[Document]:
-        if self.reranking_function:
+        reranking = self.reranking_function is not None
+
+        if reranking:
             scores = self.reranking_function.predict(
                 [(query, doc.page_content) for doc in documents]
             )
@@ -496,9 +500,7 @@ class RerankCompressor(BaseDocumentCompressor):
                 (d, s) for d, s in docs_with_scores if s >= self.r_score
             ]
 
-        reverse = self.reranking_function is not None
-        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse)
-
+        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
         final_results = []
         for doc, doc_score in result[: self.top_n]:
             metadata = doc.metadata

+ 4 - 1
src/lib/components/admin/UserChatsModal.svelte

@@ -133,7 +133,10 @@
 						{/each} -->
 					</div>
 				{:else}
-					<div class="text-left text-sm w-full mb-8">{user.name} {$i18n.t('has no conversations.')}</div>
+					<div class="text-left text-sm w-full mb-8">
+						{user.name}
+						{$i18n.t('has no conversations.')}
+					</div>
 				{/if}
 			</div>
 		</div>

+ 27 - 14
src/lib/components/documents/Settings/General.svelte

@@ -137,9 +137,15 @@
 		if (res) {
 			console.log('rerankingModelUpdateHandler:', res);
 			if (res.status === true) {
-				toast.success($i18n.t('Reranking model set to "{{reranking_model}}"', res), {
-					duration: 1000 * 10
-				});
+				if (rerankingModel === '') {
+					toast.success($i18n.t('Reranking model disabled', res), {
+						duration: 1000 * 10
+					});
+				} else {
+					toast.success($i18n.t('Reranking model set to "{{reranking_model}}"', res), {
+						duration: 1000 * 10
+					});
+				}
 			}
 		}
 	};
@@ -584,12 +590,12 @@
 
 				<hr class=" dark:border-gray-700 my-3" />
 
-				<div>
+				<div class=" ">
 					<div class=" text-sm font-medium">{$i18n.t('Query Params')}</div>
 
 					<div class=" flex">
 						<div class="  flex w-full justify-between">
-							<div class="self-center text-xs font-medium flex-1">{$i18n.t('Top K')}</div>
+							<div class="self-center text-xs font-medium min-w-fit">{$i18n.t('Top K')}</div>
 
 							<div class="self-center p-3">
 								<input
@@ -602,13 +608,11 @@
 								/>
 							</div>
 						</div>
-					</div>
 
-					{#if querySettings.hybrid === true}
-						<div class=" flex">
-							<div class="  flex w-full justify-between">
-								<div class="self-center text-xs font-medium flex-1">
-									{$i18n.t('Relevance Threshold')}
+						{#if querySettings.hybrid === true}
+							<div class="flex w-full">
+								<div class=" self-center text-xs font-medium min-w-fit">
+									{$i18n.t('Minimum Score')}
 								</div>
 
 								<div class="self-center p-3">
@@ -616,14 +620,25 @@
 										class=" w-full rounded-lg py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
 										type="number"
 										step="0.01"
-										placeholder={$i18n.t('Enter Relevance Threshold')}
+										placeholder={$i18n.t('Enter Score')}
 										bind:value={querySettings.r}
 										autocomplete="off"
 										min="0.0"
+										title={$i18n.t('The score should be a value between 0.0 (0%) and 1.0 (100%).')}
 									/>
 								</div>
 							</div>
+						{/if}
+					</div>
+
+					{#if querySettings.hybrid === true}
+						<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
+							{$i18n.t(
+								'Note: If you set a minimum score, the search will only return documents with a score greater than or equal to the minimum score.'
+							)}
 						</div>
+
+						<hr class=" dark:border-gray-700 my-3" />
 					{/if}
 
 					<div>
@@ -636,8 +651,6 @@
 					</div>
 				</div>
 
-				<hr class=" dark:border-gray-700 my-3" />
-
 				{#if showResetConfirm}
 					<div class="flex justify-between rounded-md items-center py-2 px-3.5 w-full transition">
 						<div class="flex items-center space-x-3">

+ 1 - 1
src/lib/components/layout/Sidebar/ChatMenu.svelte

@@ -1,7 +1,7 @@
 <script lang="ts">
 	import { DropdownMenu } from 'bits-ui';
 	import { flyAndScale } from '$lib/utils/transitions';
-	import { getContext } from 'svelte'
+	import { getContext } from 'svelte';
 
 	import Dropdown from '$lib/components/common/Dropdown.svelte';
 	import GarbageBin from '$lib/components/icons/GarbageBin.svelte';

+ 1 - 1
src/lib/i18n/locales/de-DE/translation.json

@@ -291,7 +291,7 @@
 	"Profile Image": "Profilbild",
 	"Prompt (e.g. Tell me a fun fact about the Roman Empire)": "Prompt (z.B. Erzähle mir eine interessante Tatsache über das Römische Reich.",
 	"Playground": "Spielplatz",
-									   
+
 	"Profile": "Profil",
 	"Prompt Content": "Prompt-Inhalt",
 	"Prompt suggestions": "Prompt-Vorschläge",