Browse Source

refac: title generation

Timothy J. Baek 10 months ago
parent
commit
5e7237b9cb

+ 0 - 2
backend/apps/ollama/main.py

@@ -41,8 +41,6 @@ from utils.utils import (
     get_admin_user,
 )
 
-from utils.models import get_model_id_from_custom_model_id
-
 
 from config import (
     SRC_LOG_LEVELS,

+ 21 - 0
backend/config.py

@@ -618,6 +618,27 @@ ADMIN_EMAIL = PersistentConfig(
 )
 
 
+TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+    "TITLE_GENERATION_PROMPT_TEMPLATE",
+    "task.title.prompt_template",
+    os.environ.get(
+        "TITLE_GENERATION_PROMPT_TEMPLATE",
+        """Here is the query:
+{{prompt:middletruncate:8000}}
+
+Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
+
+Examples of titles:
+📉 Stock Market Trends
+🍪 Perfect Chocolate Chip Recipe
+Evolution of Music Streaming
+Remote Work Productivity Tips
+Artificial Intelligence in Healthcare
+🎮 Video Game Development Insights""",
+    ),
+)
+
+
 ####################################
 # WEBUI_SECRET_KEY
 ####################################

+ 119 - 81
backend/main.py

@@ -53,6 +53,8 @@ from utils.utils import (
     get_current_user,
     get_http_authorization_cred,
 )
+from utils.task import title_generation_template
+
 from apps.rag.utils import rag_messages
 
 from config import (
@@ -74,8 +76,9 @@ from config import (
     SRC_LOG_LEVELS,
     WEBHOOK_URL,
     ENABLE_ADMIN_EXPORT,
-    AppConfig,
     WEBUI_BUILD_HASH,
+    TITLE_GENERATION_PROMPT_TEMPLATE,
+    AppConfig,
 )
 from constants import ERROR_MESSAGES
 
@@ -131,7 +134,7 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
-
+app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
 
 app.state.MODELS = {}
 
@@ -240,6 +243,78 @@ class RAGMiddleware(BaseHTTPMiddleware):
 app.add_middleware(RAGMiddleware)
 
 
+def filter_pipeline(payload, user):
+    user = {"id": user.id, "name": user.name, "role": user.role}
+    model_id = payload["model"]
+    filters = [
+        model
+        for model in app.state.MODELS.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+
+    model = app.state.MODELS[model_id]
+
+    if "pipeline" in model:
+        sorted_filters.append(model)
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                headers = {"Authorization": f"Bearer {key}"}
+                r = requests.post(
+                    f"{url}/{filter['id']}/filter/inlet",
+                    headers=headers,
+                    json={
+                        "user": user,
+                        "body": payload,
+                    },
+                )
+
+                r.raise_for_status()
+                payload = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "detail" in res:
+                        return JSONResponse(
+                            status_code=r.status_code,
+                            content=res,
+                        )
+                except:
+                    pass
+
+            else:
+                pass
+
+    if "pipeline" not in app.state.MODELS[model_id]:
+        if "chat_id" in payload:
+            del payload["chat_id"]
+
+        if "title" in payload:
+            del payload["title"]
+    return payload
+
+
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         if request.method == "POST" and (
@@ -255,85 +330,10 @@ class PipelineMiddleware(BaseHTTPMiddleware):
             # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
 
-            model_id = data["model"]
-            filters = [
-                model
-                for model in app.state.MODELS.values()
-                if "pipeline" in model
-                and "type" in model["pipeline"]
-                and model["pipeline"]["type"] == "filter"
-                and (
-                    model["pipeline"]["pipelines"] == ["*"]
-                    or any(
-                        model_id == target_model_id
-                        for target_model_id in model["pipeline"]["pipelines"]
-                    )
-                )
-            ]
-            sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
-
-            user = None
-            if len(sorted_filters) > 0:
-                try:
-                    user = get_current_user(
-                        get_http_authorization_cred(
-                            request.headers.get("Authorization")
-                        )
-                    )
-                    user = {"id": user.id, "name": user.name, "role": user.role}
-                except:
-                    pass
-
-            model = app.state.MODELS[model_id]
-
-            if "pipeline" in model:
-                sorted_filters.append(model)
-
-            for filter in sorted_filters:
-                r = None
-                try:
-                    urlIdx = filter["urlIdx"]
-
-                    url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-                    key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
-
-                    if key != "":
-                        headers = {"Authorization": f"Bearer {key}"}
-                        r = requests.post(
-                            f"{url}/{filter['id']}/filter/inlet",
-                            headers=headers,
-                            json={
-                                "user": user,
-                                "body": data,
-                            },
-                        )
-
-                        r.raise_for_status()
-                        data = r.json()
-                except Exception as e:
-                    # Handle connection error here
-                    print(f"Connection error: {e}")
-
-                    if r is not None:
-                        try:
-                            res = r.json()
-                            if "detail" in res:
-                                return JSONResponse(
-                                    status_code=r.status_code,
-                                    content=res,
-                                )
-                        except:
-                            pass
-
-                    else:
-                        pass
-
-            if "pipeline" not in app.state.MODELS[model_id]:
-                if "chat_id" in data:
-                    del data["chat_id"]
-
-                if "title" in data:
-                    del data["title"]
+            user = get_current_user(
+                get_http_authorization_cred(request.headers.get("Authorization"))
+            )
+            data = filter_pipeline(data, user)
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
             # Replace the request body with the modified one
@@ -494,6 +494,44 @@ async def get_models(user=Depends(get_verified_user)):
     return {"data": models}
 
 
+@app.post("/api/title/completions")
+async def generate_title(form_data: dict, user=Depends(get_verified_user)):
+    print("generate_title")
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    model = app.state.MODELS[model_id]
+
+    template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
+
+    content = title_generation_template(
+        template, form_data["prompt"], user.model_dump()
+    )
+
+    payload = {
+        "model": model_id,
+        "messages": [{"role": "user", "content": content}],
+        "stream": False,
+        "max_tokens": 50,
+        "chat_id": form_data.get("chat_id", None),
+        "title": True,
+    }
+
+    print(payload)
+    payload = filter_pipeline(payload, user)
+
+    if model["owned_by"] == "ollama":
+        return await generate_ollama_chat_completion(
+            OpenAIChatCompletionForm(**payload), user=user
+        )
+    else:
+        return await generate_openai_chat_completion(payload, user=user)
+
+
 @app.post("/api/chat/completions")
 async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
     model_id = form_data["model"]

+ 0 - 10
backend/utils/models.py

@@ -1,10 +0,0 @@
-from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
-
-
-def get_model_id_from_custom_model_id(id: str):
-    model = Models.get_model_by_id(id)
-
-    if model:
-        return model.id
-    else:
-        return id

+ 70 - 0
backend/utils/task.py

@@ -0,0 +1,70 @@
+import re
+import math
+
+from datetime import datetime
+from typing import Optional
+
+
+def prompt_template(
+    template: str, user_name: str = None, current_location: str = None
+) -> str:
+    # Get the current date
+    current_date = datetime.now()
+
+    # Format the date to YYYY-MM-DD
+    formatted_date = current_date.strftime("%Y-%m-%d")
+
+    # Replace {{CURRENT_DATE}} in the template with the formatted date
+    template = template.replace("{{CURRENT_DATE}}", formatted_date)
+
+    if user_name:
+        # Replace {{USER_NAME}} in the template with the user's name
+        template = template.replace("{{USER_NAME}}", user_name)
+
+    if current_location:
+        # Replace {{CURRENT_LOCATION}} in the template with the current location
+        template = template.replace("{{CURRENT_LOCATION}}", current_location)
+
+    return template
+
+
+def title_generation_template(
+    template: str, prompt: str, user: Optional[dict] = None
+) -> str:
+    def replacement_function(match):
+        full_match = match.group(0)
+        start_length = match.group(1)
+        end_length = match.group(2)
+        middle_length = match.group(3)
+
+        if full_match == "{{prompt}}":
+            return prompt
+        elif start_length is not None:
+            return prompt[: int(start_length)]
+        elif end_length is not None:
+            return prompt[-int(end_length) :]
+        elif middle_length is not None:
+            middle_length = int(middle_length)
+            if len(prompt) <= middle_length:
+                return prompt
+            start = prompt[: math.ceil(middle_length / 2)]
+            end = prompt[-math.floor(middle_length / 2) :]
+            return f"{start}...{end}"
+        return ""
+
+    template = re.sub(
+        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
+        replacement_function,
+        template,
+    )
+
+    template = prompt_template(
+        template,
+        **(
+            {"user_name": user.get("name"), "current_location": user.get("location")}
+            if user
+            else {}
+        ),
+    )
+
+    return template

+ 40 - 0
src/lib/apis/index.ts

@@ -104,6 +104,46 @@ export const chatCompleted = async (token: string, body: ChatCompletedForm) => {
 	return res;
 };
 
+export const generateTitle = async (
+	token: string = '',
+	model: string,
+	prompt: string,
+	chat_id?: string
+) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/title/completions`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			prompt: prompt,
+			...(chat_id && { chat_id: chat_id })
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
+};
+
 export const getPipelinesList = async (token: string = '') => {
 	let error = null;
 

+ 17 - 31
src/lib/components/chat/Chat.svelte

@@ -7,6 +7,10 @@
 	import { goto } from '$app/navigation';
 	import { page } from '$app/stores';
 
+	import type { Writable } from 'svelte/store';
+	import type { i18n as i18nType } from 'i18next';
+	import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
+
 	import {
 		chatId,
 		chats,
@@ -40,24 +44,17 @@
 		getTagsById,
 		updateChatById
 	} from '$lib/apis/chats';
-	import {
-		generateOpenAIChatCompletion,
-		generateSearchQuery,
-		generateTitle
-	} from '$lib/apis/openai';
+	import { generateOpenAIChatCompletion, generateSearchQuery } from '$lib/apis/openai';
+	import { runWebSearch } from '$lib/apis/rag';
+	import { createOpenAITextStream } from '$lib/apis/streaming';
+	import { queryMemory } from '$lib/apis/memories';
+	import { getUserSettings } from '$lib/apis/users';
+	import { chatCompleted, generateTitle } from '$lib/apis';
 
+	import Banner from '../common/Banner.svelte';
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
 	import Messages from '$lib/components/chat/Messages.svelte';
 	import Navbar from '$lib/components/layout/Navbar.svelte';
-	import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
-	import { createOpenAITextStream } from '$lib/apis/streaming';
-	import { queryMemory } from '$lib/apis/memories';
-	import type { Writable } from 'svelte/store';
-	import type { i18n as i18nType } from 'i18next';
-	import { runWebSearch } from '$lib/apis/rag';
-	import Banner from '../common/Banner.svelte';
-	import { getUserSettings } from '$lib/apis/users';
-	import { chatCompleted } from '$lib/apis';
 	import CallOverlay from './MessageInput/CallOverlay.svelte';
 
 	const i18n: Writable<i18nType> = getContext('i18n');
@@ -1116,26 +1113,15 @@
 
 	const generateChatTitle = async (userPrompt) => {
 		if ($settings?.title?.auto ?? true) {
-			const model = $models.find((model) => model.id === selectedModels[0]);
-
-			const titleModelId =
-				model?.owned_by === 'openai' ?? false
-					? $settings?.title?.modelExternal ?? selectedModels[0]
-					: $settings?.title?.model ?? selectedModels[0];
-			const titleModel = $models.find((model) => model.id === titleModelId);
-
-			console.log(titleModel);
 			const title = await generateTitle(
 				localStorage.token,
-				$settings?.title?.prompt ??
-					$i18n.t(
-						"Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':"
-					) + ' {{prompt}}',
-				titleModelId,
+				selectedModels[0],
 				userPrompt,
-				$chatId,
-				`${WEBUI_BASE_URL}/api`
-			);
+				$chatId
+			).catch((error) => {
+				console.error(error);
+				return 'New Chat';
+			});
 
 			return title;
 		} else {

+ 0 - 0
src/lib/utils/index.test.ts → src/lib/utils/_template_old.ts