فهرست منبع

feat: search query threshold

Timothy J. Baek 10 ماه پیش
والد
کامیت
8debb71197
3فایلهای تغییر یافته به همراه32 افزوده شده و 3 حذف شده
  1. 14 0
      backend/config.py
  2. 10 0
      backend/main.py
  3. 8 3
      src/lib/components/chat/Chat.svelte

+ 14 - 0
backend/config.py

@@ -618,6 +618,11 @@ ADMIN_EMAIL = PersistentConfig(
 )
 )
 
 
 
 
+####################################
+# TASKS
+####################################
+
+
 TASK_MODEL = PersistentConfig(
 TASK_MODEL = PersistentConfig(
     "TASK_MODEL",
     "TASK_MODEL",
     "task.model.default",
     "task.model.default",
@@ -664,6 +669,15 @@ Question:
 )
 )
 
 
 
 
+SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
+    "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
+    "task.search.prompt_length_threshold",
+    os.environ.get(
+        "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
+        100,
+    ),
+)
+
 ####################################
 ####################################
 # WEBUI_SECRET_KEY
 # WEBUI_SECRET_KEY
 ####################################
 ####################################

+ 10 - 0
backend/main.py

@@ -81,6 +81,7 @@ from config import (
     TASK_MODEL_EXTERNAL,
     TASK_MODEL_EXTERNAL,
     TITLE_GENERATION_PROMPT_TEMPLATE,
     TITLE_GENERATION_PROMPT_TEMPLATE,
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
+    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
     AppConfig,
     AppConfig,
 )
 )
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
@@ -144,6 +145,9 @@ app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMP
 app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
 app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
 )
 )
+app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
+    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
+)
 
 
 app.state.MODELS = {}
 app.state.MODELS = {}
 
 
@@ -596,6 +600,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
 async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
 async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
     print("generate_search_query")
     print("generate_search_query")
 
 
+    if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
+        )
+
     model_id = form_data["model"]
     model_id = form_data["model"]
     if model_id not in app.state.MODELS:
     if model_id not in app.state.MODELS:
         raise HTTPException(
         raise HTTPException(

+ 8 - 3
src/lib/components/chat/Chat.svelte

@@ -56,6 +56,7 @@
 	import Messages from '$lib/components/chat/Messages.svelte';
 	import Messages from '$lib/components/chat/Messages.svelte';
 	import Navbar from '$lib/components/layout/Navbar.svelte';
 	import Navbar from '$lib/components/layout/Navbar.svelte';
 	import CallOverlay from './MessageInput/CallOverlay.svelte';
 	import CallOverlay from './MessageInput/CallOverlay.svelte';
+	import { error } from '@sveltejs/kit';
 
 
 	const i18n: Writable<i18nType> = getContext('i18n');
 	const i18n: Writable<i18nType> = getContext('i18n');
 
 
@@ -506,7 +507,13 @@
 		messages = messages;
 		messages = messages;
 
 
 		const prompt = history.messages[parentId].content;
 		const prompt = history.messages[parentId].content;
-		let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt);
+		let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch(
+			(error) => {
+				console.log(error);
+				return prompt;
+			}
+		);
+
 		if (!searchQuery) {
 		if (!searchQuery) {
 			toast.warning($i18n.t('No search query generated'));
 			toast.warning($i18n.t('No search query generated'));
 			responseMessage.status = {
 			responseMessage.status = {
@@ -516,8 +523,6 @@
 				description: 'No search query generated'
 				description: 'No search query generated'
 			};
 			};
 			messages = messages;
 			messages = messages;
-
-			searchQuery = prompt;
 		}
 		}
 
 
 		responseMessage.status = {
 		responseMessage.status = {