Timothy J. Baek 10 місяців тому
батько
коміт
591cd993c2
5 змінених файлів з 223 додано та 29 видалено
  1. 25 0
      backend/config.py
  2. 113 3
      backend/main.py
  3. 42 0
      backend/utils/task.py
  4. 40 0
      src/lib/apis/index.ts
  5. 3 26
      src/lib/components/chat/Chat.svelte

+ 25 - 0
backend/config.py

@@ -618,6 +618,18 @@ ADMIN_EMAIL = PersistentConfig(
 )
 
 
+# Model id used for task requests (title / search-query generation) when the
+# chat model is owned by Ollama; see the task endpoints in backend/main.py.
+# The default comes from the TASK_MODEL environment variable and falls back
+# to an empty string, which those endpoints treat as "no override".
+TASK_MODEL = PersistentConfig(
+    "TASK_MODEL",
+    "task.model.default",
+    os.environ.get("TASK_MODEL", ""),
+)
+
+# Same as TASK_MODEL, but consulted when the chat model is not owned by
+# Ollama. The default comes from the TASK_MODEL_EXTERNAL environment variable.
+TASK_MODEL_EXTERNAL = PersistentConfig(
+    "TASK_MODEL_EXTERNAL",
+    "task.model.external",
+    os.environ.get("TASK_MODEL_EXTERNAL", ""),
+)
+
 TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
     "TITLE_GENERATION_PROMPT_TEMPLATE",
     "task.title.prompt_template",
@@ -639,6 +651,19 @@ Artificial Intelligence in Healthcare
 )
 
 
+# Prompt template used to turn a user's question into a web search query.
+# The default can be replaced through the
+# SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE environment variable.
+# `{{prompt:end:4000}}` is expanded by search_query_generation_template
+# (backend/utils/task.py) to the last 4000 characters of the prompt;
+# `{{CURRENT_DATE}}` is presumably filled in by prompt_template - TODO confirm.
+# NOTE(review): the second line of the default template contains only
+# indentation spaces, which are part of the string sent to the model.
+SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
+    "task.search.prompt_template",
+    os.environ.get(
+        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
+        """You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
+        
+Question:
+{{prompt:end:4000}}""",
+    ),
+)
+
+
 ####################################
 # WEBUI_SECRET_KEY
 ####################################

+ 113 - 3
backend/main.py

@@ -53,7 +53,7 @@ from utils.utils import (
     get_current_user,
     get_http_authorization_cred,
 )
-from utils.task import title_generation_template
+from utils.task import title_generation_template, search_query_generation_template
 
 from apps.rag.utils import rag_messages
 
@@ -77,7 +77,10 @@ from config import (
     WEBHOOK_URL,
     ENABLE_ADMIN_EXPORT,
     WEBUI_BUILD_HASH,
+    TASK_MODEL,
+    TASK_MODEL_EXTERNAL,
     TITLE_GENERATION_PROMPT_TEMPLATE,
+    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
     AppConfig,
 )
 from constants import ERROR_MESSAGES
@@ -132,9 +135,15 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
-
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
+
+
+app.state.config.TASK_MODEL = TASK_MODEL
+app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
 app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
+app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
+    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+)
 
 app.state.MODELS = {}
 
@@ -494,9 +503,46 @@ async def get_models(user=Depends(get_verified_user)):
     return {"data": models}
 
 
+@app.get("/api/task/config")
+async def get_task_config(user=Depends(get_verified_user)):
+    """Return the current task settings.
+
+    The response contains the two task-model ids and the title and
+    search-query prompt templates as stored on ``app.state.config``.
+    Access is gated by the ``get_verified_user`` dependency.
+    """
+    return {
+        "TASK_MODEL": app.state.config.TASK_MODEL,
+        "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
+        "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
+    }
+
+
+class TaskConfigForm(BaseModel):
+    """Request body accepted by ``POST /api/task/config/update``."""
+
+    # Task-model ids; the annotation allows None for either one.
+    TASK_MODEL: Optional[str]
+    TASK_MODEL_EXTERNAL: Optional[str]
+    # Prompt templates for title and search-query generation.
+    TITLE_GENERATION_PROMPT_TEMPLATE: str
+    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
+
+
+@app.post("/api/task/config/update")
+async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
+    """Overwrite the task settings with the submitted values.
+
+    All four fields of ``form_data`` are copied onto ``app.state.config``
+    and the resulting settings are returned in the same shape as
+    ``GET /api/task/config``. Access is gated by ``get_admin_user``.
+    """
+    app.state.config.TASK_MODEL = form_data.TASK_MODEL
+    app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
+    app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
+        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
+    )
+    app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
+        form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+    )
+
+    # Echo back what is now stored rather than the raw request body.
+    return {
+        "TASK_MODEL": app.state.config.TASK_MODEL,
+        "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
+        "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
+    }
+
+
 @app.post("/api/task/title/completions")
 async def generate_title(form_data: dict, user=Depends(get_verified_user)):
     print("generate_title")
+
     model_id = form_data["model"]
     if model_id not in app.state.MODELS:
         raise HTTPException(
@@ -504,6 +550,20 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
             detail="Model not found",
         )
 
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    if app.state.MODELS[model_id]["owned_by"] == "ollama":
+        if app.state.config.TASK_MODEL:
+            task_model_id = app.state.config.TASK_MODEL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+    else:
+        if app.state.config.TASK_MODEL_EXTERNAL:
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+
+    print(model_id)
     model = app.state.MODELS[model_id]
 
     template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
@@ -532,6 +592,57 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         return await generate_openai_chat_completion(payload, user=user)
 
 
+@app.post("/api/task/query/completions")
+async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
+    """Generate a web search query for the user's prompt.
+
+    ``form_data`` must contain ``model`` (id of the chat model in use) and
+    ``prompt`` (the user's question), both strings. If a task model is
+    configured for the chat model's backend and is a known model, it is used
+    instead of the chat model. The prompt is rendered through the configured
+    search-query template and sent as a single non-streaming completion
+    capped at 30 tokens.
+
+    Raises:
+        HTTPException: 400 when ``model`` or ``prompt`` is missing or not a
+            string; 404 when ``model`` is not a known model id.
+    """
+    print("generate_search_query")
+
+    # Reject malformed bodies up front; indexing form_data directly would
+    # surface a missing key as an unhandled KeyError (HTTP 500).
+    model_id = form_data.get("model")
+    if not isinstance(model_id, str):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Missing or invalid 'model'",
+        )
+
+    prompt = form_data.get("prompt")
+    if not isinstance(prompt, str):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Missing or invalid 'prompt'",
+        )
+
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Pick the task-model setting that matches the chat model's backend and
+    # use it only if it is set and refers to a known model.
+    if app.state.MODELS[model_id]["owned_by"] == "ollama":
+        task_model_id = app.state.config.TASK_MODEL
+    else:
+        task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+
+    if task_model_id and task_model_id in app.state.MODELS:
+        model_id = task_model_id
+
+    print(model_id)
+    model = app.state.MODELS[model_id]
+
+    template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+
+    content = search_query_generation_template(template, prompt, user.model_dump())
+
+    payload = {
+        "model": model_id,
+        "messages": [{"role": "user", "content": content}],
+        "stream": False,
+        "max_tokens": 30,
+    }
+
+    print(payload)
+    payload = filter_pipeline(payload, user)
+
+    # Dispatch on the backend of the model that will actually be queried.
+    if model["owned_by"] == "ollama":
+        return await generate_ollama_chat_completion(
+            OpenAIChatCompletionForm(**payload), user=user
+        )
+    else:
+        return await generate_openai_chat_completion(payload, user=user)
+
+
 @app.post("/api/chat/completions")
 async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
     model_id = form_data["model"]
@@ -542,7 +653,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         )
 
     model = app.state.MODELS[model_id]
-
     print(model)
 
     if model["owned_by"] == "ollama":

+ 42 - 0
backend/utils/task.py

@@ -68,3 +68,45 @@ def title_generation_template(
     )
 
     return template
+
+
+def search_query_generation_template(
+    template: str, prompt: str, user: Optional[dict] = None
+) -> str:
+    """Render a search-query prompt template for ``prompt``.
+
+    Supported prompt placeholders:
+
+    * ``{{prompt}}`` - the full prompt
+    * ``{{prompt:start:N}}`` - the first N characters
+    * ``{{prompt:end:N}}`` - the last N characters
+    * ``{{prompt:middletruncate:N}}`` - the prompt unchanged when it has at
+      most N characters, otherwise its first ceil(N/2) and last floor(N/2)
+      characters joined by ``...``
+
+    The result is then passed through ``prompt_template`` together with the
+    user's ``name`` and ``location`` when ``user`` is given.
+
+    Args:
+        template: Template text containing the placeholders above.
+        prompt: The user's question.
+        user: Optional user record; only ``name`` and ``location`` are read.
+
+    Returns:
+        The rendered template.
+    """
+
+    def replacement_function(match):
+        full_match = match.group(0)
+        start_length = match.group(1)
+        end_length = match.group(2)
+        middle_length = match.group(3)
+
+        if full_match == "{{prompt}}":
+            return prompt
+        elif start_length is not None:
+            return prompt[: int(start_length)]
+        elif end_length is not None:
+            # Slice from an explicit non-negative index: with a count of 0,
+            # prompt[-0:] would return the whole prompt instead of nothing.
+            return prompt[max(len(prompt) - int(end_length), 0) :]
+        elif middle_length is not None:
+            middle_length = int(middle_length)
+            if len(prompt) <= middle_length:
+                return prompt
+            start = prompt[: math.ceil(middle_length / 2)]
+            # len(prompt) > middle_length here, so this index is positive;
+            # it also yields an empty tail when floor(N/2) is 0.
+            end = prompt[len(prompt) - math.floor(middle_length / 2) :]
+            return f"{start}...{end}"
+        return ""
+
+    template = re.sub(
+        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
+        replacement_function,
+        template,
+    )
+
+    template = prompt_template(
+        template,
+        **(
+            {"user_name": user.get("name"), "current_location": user.get("location")}
+            if user
+            else {}
+        ),
+    )
+    return template

+ 40 - 0
src/lib/apis/index.ts

@@ -144,6 +144,46 @@ export const generateTitle = async (
 	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
 };
 
+/**
+ * Ask the backend to turn `prompt` into a web search query.
+ *
+ * Sends `model`, `messages` and `prompt` to `/api/task/query/completions`.
+ * Returns the content of the first choice with single and double quotes
+ * removed, or `prompt` itself when the response carries no such content.
+ * Throws the `detail` value of the error when the request fails with one.
+ */
+export const generateSearchQuery = async (
+	token: string = '',
+	model: string,
+	messages: object[],
+	prompt: string
+) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			messages: messages,
+			prompt: prompt
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			// The thrown value may be any parsed JSON value; `in` throws on
+			// non-objects, so check the type before probing for `detail`.
+			if (err && typeof err === 'object' && 'detail' in err) {
+				error = err.detail;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	// Guard every step of the lookup so a response without `choices` or
+	// `content` falls back to the prompt instead of throwing a TypeError.
+	return res?.choices?.[0]?.message?.content?.replace(/["']/g, '') ?? prompt;
+};
+
 export const getPipelinesList = async (token: string = '') => {
 	let error = null;
 

+ 3 - 26
src/lib/components/chat/Chat.svelte

@@ -44,12 +44,12 @@
 		getTagsById,
 		updateChatById
 	} from '$lib/apis/chats';
-	import { generateOpenAIChatCompletion, generateSearchQuery } from '$lib/apis/openai';
+	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 	import { runWebSearch } from '$lib/apis/rag';
 	import { createOpenAITextStream } from '$lib/apis/streaming';
 	import { queryMemory } from '$lib/apis/memories';
 	import { getUserSettings } from '$lib/apis/users';
-	import { chatCompleted, generateTitle } from '$lib/apis';
+	import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis';
 
 	import Banner from '../common/Banner.svelte';
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
@@ -508,7 +508,7 @@
 		const prompt = history.messages[parentId].content;
 		let searchQuery = prompt;
 		if (prompt.length > 100) {
-			searchQuery = await generateChatSearchQuery(model, prompt);
+			searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt);
 			if (!searchQuery) {
 				toast.warning($i18n.t('No search query generated'));
 				responseMessage.status = {
@@ -1129,29 +1129,6 @@
 		}
 	};
 
-	const generateChatSearchQuery = async (modelId: string, prompt: string) => {
-		const model = $models.find((model) => model.id === modelId);
-		const taskModelId =
-			model?.owned_by === 'openai' ?? false
-				? $settings?.title?.modelExternal ?? modelId
-				: $settings?.title?.model ?? modelId;
-		const taskModel = $models.find((model) => model.id === taskModelId);
-
-		const previousMessages = messages
-			.filter((message) => message.role === 'user')
-			.map((message) => message.content);
-
-		return await generateSearchQuery(
-			localStorage.token,
-			taskModelId,
-			previousMessages,
-			prompt,
-			taskModel?.owned_by === 'openai' ?? false
-				? `${OPENAI_API_BASE_URL}`
-				: `${OLLAMA_API_BASE_URL}/v1`
-		);
-	};
-
 	const setChatTitle = async (_chatId, _title) => {
 		if (_chatId === $chatId) {
 			title = _title;