
enh: retrieval query generation

Timothy Jaeryang Baek 5 months ago
parent
commit
dbb67a12ca

+ 32 - 33
backend/open_webui/apps/retrieval/utils.py

@@ -177,35 +177,34 @@ def merge_and_sort_query_results(
 
 def query_collection(
     collection_names: list[str],
-    query: str,
+    queries: list[str],
     embedding_function,
     k: int,
 ) -> dict:
-
     results = []
-    query_embedding = embedding_function(query)
-
-    for collection_name in collection_names:
-        if collection_name:
-            try:
-                result = query_doc(
-                    collection_name=collection_name,
-                    k=k,
-                    query_embedding=query_embedding,
-                )
-                if result is not None:
-                    results.append(result.model_dump())
-            except Exception as e:
-                log.exception(f"Error when querying the collection: {e}")
-        else:
-            pass
+    for query in queries:
+        query_embedding = embedding_function(query)
+        for collection_name in collection_names:
+            if collection_name:
+                try:
+                    result = query_doc(
+                        collection_name=collection_name,
+                        k=k,
+                        query_embedding=query_embedding,
+                    )
+                    if result is not None:
+                        results.append(result.model_dump())
+                except Exception as e:
+                    log.exception(f"Error when querying the collection: {e}")
+            else:
+                pass
 
     return merge_and_sort_query_results(results, k=k)
 
 
 def query_collection_with_hybrid_search(
     collection_names: list[str],
-    query: str,
+    queries: list[str],
     embedding_function,
     k: int,
     reranking_function,
@@ -215,15 +214,16 @@ def query_collection_with_hybrid_search(
     error = False
     for collection_name in collection_names:
         try:
-            result = query_doc_with_hybrid_search(
-                collection_name=collection_name,
-                query=query,
-                embedding_function=embedding_function,
-                k=k,
-                reranking_function=reranking_function,
-                r=r,
-            )
-            results.append(result)
+            for query in queries:
+                result = query_doc_with_hybrid_search(
+                    collection_name=collection_name,
+                    query=query,
+                    embedding_function=embedding_function,
+                    k=k,
+                    reranking_function=reranking_function,
+                    r=r,
+                )
+                results.append(result)
         except Exception as e:
             log.exception(
                 "Error when querying the collection with " f"hybrid_search: {e}"
@@ -309,15 +309,14 @@ def get_embedding_function(
 
 def get_rag_context(
     files,
-    messages,
+    queries,
     embedding_function,
     k,
     reranking_function,
     r,
     hybrid_search,
 ):
-    log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
-    query = get_last_user_message(messages)
+    log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
 
     extracted_collections = []
     relevant_contexts = []
@@ -359,7 +358,7 @@ def get_rag_context(
                         try:
                             context = query_collection_with_hybrid_search(
                                 collection_names=collection_names,
-                                query=query,
+                                queries=queries,
                                 embedding_function=embedding_function,
                                 k=k,
                                 reranking_function=reranking_function,
@@ -374,7 +373,7 @@ def get_rag_context(
                     if (not hybrid_search) or (context is None):
                         context = query_collection(
                             collection_names=collection_names,
-                            query=query,
+                            queries=queries,
                             embedding_function=embedding_function,
                             k=k,
                         )
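
Taken together, the utils.py changes make every retrieval path fan out over a list of generated queries: query_collection embeds each query separately, query_collection_with_hybrid_search reruns the hybrid search per query, and all hits are funneled into a single merge_and_sort_query_results call. A minimal sketch of calling the new signature, assuming the backend package is importable; the collection id and embedding function below are placeholders, not part of the commit:

from open_webui.apps.retrieval.utils import query_collection

# Placeholder embedding function: any callable mapping a string to a vector
# works; in the app this comes from get_embedding_function().
def embed(text: str) -> list[float]:
    return [float(len(text))]  # stand-in, not a real embedding

merged = query_collection(
    collection_names=["file-abc123"],  # hypothetical collection id
    queries=["asyncio event loop", "python asyncio tutorial"],
    embedding_function=embed,
    k=5,
)
# merged is whatever merge_and_sort_query_results() returns for the combined
# per-query hits; that function's body is outside this hunk.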

+ 38 - 29
backend/open_webui/config.py

@@ -941,19 +941,49 @@ ENABLE_TAGS_GENERATION = PersistentConfig(
     os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
 )
 
-ENABLE_SEARCH_QUERY = PersistentConfig(
-    "ENABLE_SEARCH_QUERY",
-    "task.search.enable",
-    os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true",
+
+ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
+    "ENABLE_SEARCH_QUERY_GENERATION",
+    "task.query.search.enable",
+    os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true",
+)
+
+ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig(
+    "ENABLE_RETRIEVAL_QUERY_GENERATION",
+    "task.query.retrieval.enable",
+    os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true",
 )
 
 
-SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
-    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
-    "task.search.prompt_template",
-    os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""),
+QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+    "QUERY_GENERATION_PROMPT_TEMPLATE",
+    "task.query.prompt_template",
+    os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""),
 )
 
+DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task:
+Based on the chat history, determine whether a search is necessary, and if so, generate a 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list.
+
+### Guidelines:
+- Respond exclusively with a JSON object.
+- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise.
+- If no search query is necessary, output should be: { "queries": [] }
+- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required.
+- Be concise, focusing strictly on composing search queries with no additional commentary or text.
+- When in doubt, prefer to suggest a search for comprehensiveness.
+- Today's date is: {{CURRENT_DATE}}
+
+### Output:
+JSON format: {
+  "queries": ["query1", "query2"]
+}
+
+### Chat History:
+<chat_history>
+{{MESSAGES:END:6}}
+</chat_history>
+"""
+
 
 TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
     "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
@@ -1127,27 +1157,6 @@ RAG_TEXT_SPLITTER = PersistentConfig(
 )
 
 
-ENABLE_RAG_QUERY_GENERATION = PersistentConfig(
-    "ENABLE_RAG_QUERY_GENERATION",
-    "rag.query_generation.enable",
-    os.environ.get("ENABLE_RAG_QUERY_GENERATION", "False").lower() == "true",
-)
-
-DEFAULT_RAG_QUERY_GENERATION_TEMPLATE = """Given the user's message and interaction history, decide if a file search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt.
-User Message:
-{{prompt:end:4000}}
-Interaction History:
-{{MESSAGES:END:6}}
-Search Query:"""
-
-
-RAG_QUERY_GENERATION_TEMPLATE = PersistentConfig(
-    "RAG_QUERY_GENERATION_TEMPLATE",
-    "rag.query_generation.template",
-    os.environ.get("RAG_QUERY_GENERATION_TEMPLATE", ""),
-)
-
-
 TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
 TIKTOKEN_ENCODING_NAME = PersistentConfig(
     "TIKTOKEN_ENCODING_NAME",

+ 75 - 46
backend/open_webui/main.py

@@ -78,11 +78,13 @@ from open_webui.config import (
     ENV,
     FRONTEND_BUILD_DIR,
     OAUTH_PROVIDERS,
-    ENABLE_SEARCH_QUERY,
-    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
     STATIC_DIR,
     TASK_MODEL,
     TASK_MODEL_EXTERNAL,
+    ENABLE_SEARCH_QUERY_GENERATION,
+    ENABLE_RETRIEVAL_QUERY_GENERATION,
+    QUERY_GENERATION_PROMPT_TEMPLATE,
+    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
     TITLE_GENERATION_PROMPT_TEMPLATE,
     TAGS_GENERATION_PROMPT_TEMPLATE,
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -122,7 +124,7 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
 from open_webui.utils.task import (
     moa_response_generation_template,
     tags_generation_template,
-    search_query_generation_template,
+    query_generation_template,
     emoji_generation_template,
     title_generation_template,
     tools_function_calling_generation_template,
@@ -206,10 +208,9 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
 app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
 
 
-app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
-app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
-    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
-)
+app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
+app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
+app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
 
 app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
@@ -492,14 +493,41 @@ async def chat_completion_tools_handler(
     return body, {"contexts": contexts, "citations": citations}
 
 
-async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
+async def chat_completion_files_handler(
+    body: dict, user: UserModel
+) -> tuple[dict, dict[str, list]]:
     contexts = []
     citations = []
 
+    try:
+        queries_response = await generate_queries(
+            {
+                "model": body["model"],
+                "messages": body["messages"],
+                "type": "retrieval",
+            },
+            user,
+        )
+        queries_response = queries_response["choices"][0]["message"]["content"]
+
+        try:
+            queries_response = json.loads(queries_response)
+        except Exception as e:
+            queries_response = {"queries": []}
+
+        queries = queries_response.get("queries", [])
+    except Exception as e:
+        queries = []
+
+    if len(queries) == 0:
+        queries = [get_last_user_message(body["messages"])]
+
+    print(f"{queries=}")
+
     if files := body.get("metadata", {}).get("files", None):
         contexts, citations = get_rag_context(
             files=files,
-            messages=body["messages"],
+            queries=queries,
             embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
             k=retrieval_app.state.config.TOP_K,
             reranking_function=retrieval_app.state.sentence_transformer_rf,
@@ -643,7 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             log.exception(e)
 
         try:
-            body, flags = await chat_completion_files_handler(body)
+            body, flags = await chat_completion_files_handler(body, user)
             contexts.extend(flags.get("contexts", []))
             citations.extend(flags.get("citations", []))
         except Exception as e:
@@ -1579,8 +1607,9 @@ async def get_task_config(user=Depends(get_verified_user)):
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
-        "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
-        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
+        "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
+        "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
+        "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }
 
@@ -1591,8 +1620,9 @@ class TaskConfigForm(BaseModel):
     TITLE_GENERATION_PROMPT_TEMPLATE: str
     TAGS_GENERATION_PROMPT_TEMPLATE: str
     ENABLE_TAGS_GENERATION: bool
-    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
-    ENABLE_SEARCH_QUERY: bool
+    ENABLE_SEARCH_QUERY_GENERATION: bool
+    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
+    QUERY_GENERATION_PROMPT_TEMPLATE: str
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
 
 
@@ -1607,11 +1637,16 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
         form_data.TAGS_GENERATION_PROMPT_TEMPLATE
     )
     app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
+    app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
+        form_data.ENABLE_SEARCH_QUERY_GENERATION
+    )
+    app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
+        form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
+    )
 
-    app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
-        form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+    app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
+        form_data.QUERY_GENERATION_PROMPT_TEMPLATE
     )
-    app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY
     app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
         form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
     )
@@ -1622,8 +1657,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
-        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
-        "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
+        "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
+        "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
+        "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }
 
@@ -1799,14 +1835,22 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
     return await generate_chat_completions(form_data=payload, user=user)
 
 
-@app.post("/api/task/query/completions")
-async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
-    print("generate_search_query")
-    if not app.state.config.ENABLE_SEARCH_QUERY:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=f"Search query generation is disabled",
-        )
+@app.post("/api/task/queries/completions")
+async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
+    print("generate_queries")
+    type = form_data.get("type")
+    if type == "web_search":
+        if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"Search query generation is disabled",
+            )
+    elif type == "retrieval":
+        if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"Query generation is disabled",
+            )
 
     model_list = await get_all_models()
     models = {model["id"]: model for model in model_list}
@@ -1830,20 +1874,12 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
 
     model = models[task_model_id]
 
-    if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
-        template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+    if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "":
+        template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
     else:
-        template = """Given the user's message and interaction history, decide if a web search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. Today's date is {{CURRENT_DATE}}.
-
-User Message:
-{{prompt:end:4000}}
+        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
 
-Interaction History:
-{{MESSAGES:END:6}}
-
-Search Query:"""
-
-    content = search_query_generation_template(
+    content = query_generation_template(
         template, form_data["messages"], {"name": user.name}
     )
 
@@ -1851,13 +1887,6 @@ Search Query:"""
         "model": task_model_id,
         "messages": [{"role": "user", "content": content}],
         "stream": False,
-        **(
-            {"max_tokens": 30}
-            if models[task_model_id]["owned_by"] == "ollama"
-            else {
-                "max_completion_tokens": 30,
-            }
-        ),
         "metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data},
     }
     log.debug(payload)
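
The task endpoint moves from /api/task/query/completions to /api/task/queries/completions, takes a "type" field ("web_search" or "retrieval") that is checked against the matching enable flag, and drops the 30-token cap, since the reply is now a whole JSON object rather than a single query string. A sketch of calling it directly; the host, token, and model id are placeholders, and the response shape mirrors what the handlers above read:

import json
import requests  # assumed available; any HTTP client works

resp = requests.post(
    "http://localhost:8080/api/task/queries/completions",  # hypothetical host
    headers={"Authorization": "Bearer <token>"},
    json={
        "model": "my-task-model",  # placeholder model id
        "type": "retrieval",       # or "web_search"
        "messages": [{"role": "user", "content": "How do I tune pgvector?"}],
    },
    timeout=60,
)
content = resp.json()["choices"][0]["message"]["content"]
queries = json.loads(content).get("queries", [])  # e.g. ["pgvector tuning tips"]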

+ 1 - 1
backend/open_webui/utils/task.py

@@ -163,7 +163,7 @@ def emoji_generation_template(
     return template
 
 
-def search_query_generation_template(
+def query_generation_template(
     template: str, messages: list[dict], user: Optional[dict] = None
 ) -> str:
     prompt = get_last_user_message(messages)
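
Only the function name changes here; the body (starting from get_last_user_message) is untouched. Its call shape, grounded by the generate_queries handler above, looks roughly like this:

from open_webui.config import DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
from open_webui.utils.task import query_generation_template

messages = [{"role": "user", "content": "What changed in Python 3.13?"}]
# Mirrors the call in main.py's generate_queries handler; the third
# argument is an optional user dict.
content = query_generation_template(
    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, messages, {"name": "Ada"}
)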

+ 40 - 5
src/lib/apis/index.ts

@@ -348,15 +348,16 @@ export const generateEmoji = async (
 	return null;
 };
 
-export const generateSearchQuery = async (
+export const generateQueries = async (
 	token: string = '',
 	model: string,
 	messages: object[],
-	prompt: string
+	prompt: string,
+	type?: string = 'web_search'
 ) => {
 	let error = null;
 
-	const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, {
+	const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, {
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
@@ -366,7 +367,8 @@ export const generateSearchQuery = async (
 		body: JSON.stringify({
 			model: model,
 			messages: messages,
-			prompt: prompt
+			prompt: prompt,
+			type: type
 		})
 	})
 		.then(async (res) => {
@@ -385,7 +387,40 @@ export const generateSearchQuery = async (
 		throw error;
 	}
 
-	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
+
+	try {
+		// Step 1: Safely extract the response string
+		const response = res?.choices[0]?.message?.content ?? '';
+
+		// Step 2: Attempt to fix common JSON format issues like single quotes
+		const sanitizedResponse = response.replace(/['‘’`]/g, '"'); // Convert single quotes to double quotes for valid JSON
+
+		// Step 3: Find the relevant JSON block within the response
+		const jsonStartIndex = sanitizedResponse.indexOf('{');
+		const jsonEndIndex = sanitizedResponse.lastIndexOf('}');
+
+		// Step 4: Check if we found a valid JSON block (with both `{` and `}`)
+		if (jsonStartIndex !== -1 && jsonEndIndex !== -1) {
+			const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1);
+
+			// Step 5: Parse the JSON block
+			const parsed = JSON.parse(jsonResponse);
+
+			// Step 6: If there's a "queries" key, return the queries array; otherwise, return an empty array
+			if (parsed && parsed.queries) {
+				return Array.isArray(parsed.queries) ? parsed.queries : [];
+			} else {
+				return [];
+			}
+		}
+
+		// If no valid JSON block found, return an empty array
+		return [];
+	} catch (e) {
+		// Catch and safely return empty array on any parsing errors
+		console.error('Failed to parse response: ', e);
+		return [];
+	}
 };
 
 export const generateMoACompletion = async (
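
On the client, generateQueries now has to recover a "queries" array from whatever the model returns, so it normalizes stray quote characters, slices out the outermost { ... } block, and falls back to an empty array on any parse failure. The same tolerant extraction, expressed in Python purely for illustration (the backend handler in main.py currently uses a plain json.loads with a {"queries": []} fallback):

import json
import re

def extract_queries(response: str) -> list[str]:
    # Normalize curly/single quotes the model may emit, then parse the
    # outermost JSON object; any failure yields an empty list.
    sanitized = re.sub(r"['‘’`]", '"', response)
    start, end = sanitized.find("{"), sanitized.rfind("}")
    if start == -1 or end == -1:
        return []
    try:
        parsed = json.loads(sanitized[start : end + 1])
    except json.JSONDecodeError:
        return []
    queries = parsed.get("queries", [])
    return queries if isinstance(queries, list) else []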

+ 24 - 19
src/lib/components/admin/Settings/Interface.svelte

@@ -26,8 +26,9 @@
 		TITLE_GENERATION_PROMPT_TEMPLATE: '',
 		TAGS_GENERATION_PROMPT_TEMPLATE: '',
 		ENABLE_TAGS_GENERATION: true,
-		ENABLE_SEARCH_QUERY: true,
-		SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: ''
+		ENABLE_SEARCH_QUERY_GENERATION: true,
+		ENABLE_RETRIEVAL_QUERY_GENERATION: true,
+		QUERY_GENERATION_PROMPT_TEMPLATE: ''
 	};
 
 	let promptSuggestions = [];
@@ -164,31 +165,35 @@
 
 			<hr class=" dark:border-gray-850 my-3" />
 
+			<div class="my-3 flex w-full items-center justify-between">
+				<div class=" self-center text-xs font-medium">
+					{$i18n.t('Enable Retrieval Query Generation')}
+				</div>
+
+				<Switch bind:state={taskConfig.ENABLE_RETRIEVAL_QUERY_GENERATION} />
+			</div>
+
 			<div class="my-3 flex w-full items-center justify-between">
 				<div class=" self-center text-xs font-medium">
 					{$i18n.t('Enable Web Search Query Generation')}
 				</div>
 
-				<Switch bind:state={taskConfig.ENABLE_SEARCH_QUERY} />
+				<Switch bind:state={taskConfig.ENABLE_SEARCH_QUERY_GENERATION} />
 			</div>
 
-			{#if taskConfig.ENABLE_SEARCH_QUERY}
-				<div class="">
-					<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Search Query Generation Prompt')}</div>
+			<div class="">
+				<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Query Generation Prompt')}</div>
 
-					<Tooltip
-						content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
-						placement="top-start"
-					>
-						<Textarea
-							bind:value={taskConfig.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE}
-							placeholder={$i18n.t(
-								'Leave empty to use the default prompt, or enter a custom prompt'
-							)}
-						/>
-					</Tooltip>
-				</div>
-			{/if}
+				<Tooltip
+					content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
+					placement="top-start"
+				>
+					<Textarea
+						bind:value={taskConfig.QUERY_GENERATION_PROMPT_TEMPLATE}
+						placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
+					/>
+				</Tooltip>
+			</div>
 		</div>
 
 		<hr class=" dark:border-gray-850 my-3" />

+ 6 - 4
src/lib/components/chat/Chat.svelte

@@ -66,7 +66,7 @@
 	import {
 		chatCompleted,
 		generateTitle,
-		generateSearchQuery,
+		generateQueries,
 		chatAction,
 		generateMoACompletion,
 		generateTags
@@ -2047,17 +2047,17 @@
 		history.messages[responseMessageId] = responseMessage;
 
 		const prompt = userMessage.content;
-		let searchQuery = await generateSearchQuery(
+		let queries = await generateQueries(
 			localStorage.token,
 			model,
 			messages.filter((message) => message?.content?.trim()),
 			prompt
 		).catch((error) => {
 			console.log(error);
-			return prompt;
+			return [];
 		});
 
-		if (!searchQuery || searchQuery == '') {
+		if (queries.length === 0) {
 			responseMessage.statusHistory.push({
 				done: true,
 				error: true,
@@ -2068,6 +2068,8 @@
 			return;
 		}
 
+		const searchQuery = queries[0];
+
 		responseMessage.statusHistory.push({
 			done: false,
 			action: 'web_search',