Bläddra i källkod

refac: prompt variables

Timothy Jaeryang Baek 3 månader sedan
förälder
incheckning
cc99673906

+ 1 - 0
backend/open_webui/main.py

@@ -875,6 +875,7 @@ async def chat_completion(
             "tool_ids": form_data.get("tool_ids", None),
             "tool_ids": form_data.get("tool_ids", None),
             "files": form_data.get("files", None),
             "files": form_data.get("files", None),
             "features": form_data.get("features", None),
             "features": form_data.get("features", None),
+            "variables": form_data.get("variables", None),
         }
         }
         form_data["metadata"] = metadata
         form_data["metadata"] = metadata
 
 

+ 2 - 3
backend/open_webui/routers/ollama.py

@@ -977,6 +977,7 @@ async def generate_chat_completion(
     if BYPASS_MODEL_ACCESS_CONTROL:
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True
         bypass_filter = True
 
 
+    metadata = form_data.pop("metadata", None)
     try:
     try:
         form_data = GenerateChatCompletionForm(**form_data)
         form_data = GenerateChatCompletionForm(**form_data)
     except Exception as e:
     except Exception as e:
@@ -987,8 +988,6 @@ async def generate_chat_completion(
         )
         )
 
 
     payload = {**form_data.model_dump(exclude_none=True)}
     payload = {**form_data.model_dump(exclude_none=True)}
-    if "metadata" in payload:
-        del payload["metadata"]
 
 
     model_id = payload["model"]
     model_id = payload["model"]
     model_info = Models.get_model_by_id(model_id)
     model_info = Models.get_model_by_id(model_id)
@@ -1006,7 +1005,7 @@ async def generate_chat_completion(
             payload["options"] = apply_model_params_to_body_ollama(
             payload["options"] = apply_model_params_to_body_ollama(
                 params, payload["options"]
                 params, payload["options"]
             )
             )
-            payload = apply_model_system_prompt_to_body(params, payload, user)
+            payload = apply_model_system_prompt_to_body(params, payload, metadata)
 
 
         # Check if user has access to the model
         # Check if user has access to the model
         if not bypass_filter and user.role == "user":
         if not bypass_filter and user.role == "user":

+ 3 - 3
backend/open_webui/routers/openai.py

@@ -551,9 +551,9 @@ async def generate_chat_completion(
         bypass_filter = True
         bypass_filter = True
 
 
     idx = 0
     idx = 0
+
     payload = {**form_data}
     payload = {**form_data}
-    if "metadata" in payload:
-        del payload["metadata"]
+    metadata = payload.pop("metadata", None)
 
 
     model_id = form_data.get("model")
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
     model_info = Models.get_model_by_id(model_id)
@@ -566,7 +566,7 @@ async def generate_chat_completion(
 
 
         params = model_info.params.model_dump()
         params = model_info.params.model_dump()
         payload = apply_model_params_to_body_openai(params, payload)
         payload = apply_model_params_to_body_openai(params, payload)
-        payload = apply_model_system_prompt_to_body(params, payload, user)
+        payload = apply_model_system_prompt_to_body(params, payload, metadata)
 
 
         # Check if user has access to the model
         # Check if user has access to the model
         if not bypass_filter and user.role == "user":
         if not bypass_filter and user.role == "user":

+ 2 - 0
backend/open_webui/utils/middleware.py

@@ -749,6 +749,8 @@ async def process_chat_payload(request, form_data, metadata, user, model):
         files.extend(knowledge_files)
         files.extend(knowledge_files)
         form_data["files"] = files
         form_data["files"] = files
 
 
+    variables = form_data.pop("variables", None)
+
     features = form_data.pop("features", None)
     features = form_data.pop("features", None)
     if features:
     if features:
         if "web_search" in features and features["web_search"]:
         if "web_search" in features and features["web_search"]:

+ 12 - 10
backend/open_webui/utils/payload.py

@@ -1,4 +1,4 @@
-from open_webui.utils.task import prompt_template
+from open_webui.utils.task import prompt_variables_template
 from open_webui.utils.misc import (
 from open_webui.utils.misc import (
     add_or_update_system_message,
     add_or_update_system_message,
 )
 )
@@ -7,19 +7,18 @@ from typing import Callable, Optional
 
 
 
 
 # inplace function: form_data is modified
 # inplace function: form_data is modified
-def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
+def apply_model_system_prompt_to_body(
+    params: dict, form_data: dict, metadata: Optional[dict] = None
+) -> dict:
     system = params.get("system", None)
     system = params.get("system", None)
     if not system:
     if not system:
         return form_data
         return form_data
 
 
-    if user:
-        template_params = {
-            "user_name": user.name,
-            "user_location": user.info.get("location") if user.info else None,
-        }
-    else:
-        template_params = {}
-    system = prompt_template(system, **template_params)
+    if metadata:
+        print("apply_model_system_prompt_to_body: metadata", metadata)
+        variables = metadata.get("variables", {})
+        system = prompt_variables_template(system, variables)
+
     form_data["messages"] = add_or_update_system_message(
     form_data["messages"] = add_or_update_system_message(
         system, form_data.get("messages", [])
         system, form_data.get("messages", [])
     )
     )
@@ -188,4 +187,7 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
     if ollama_options:
     if ollama_options:
         ollama_payload["options"] = ollama_options
         ollama_payload["options"] = ollama_options
 
 
+    if "metadata" in openai_payload:
+        ollama_payload["metadata"] = openai_payload["metadata"]
+
     return ollama_payload
     return ollama_payload

+ 6 - 0
backend/open_webui/utils/task.py

@@ -32,6 +32,12 @@ def get_task_model_id(
     return task_model_id
     return task_model_id
 
 
 
 
+def prompt_variables_template(template: str, variables: dict[str, str]) -> str:
+    for variable, value in variables.items():
+        template = template.replace(variable, value)
+    return template
+
+
 def prompt_template(
 def prompt_template(
     template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
     template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
 ) -> str:
 ) -> str:

+ 10 - 2
src/lib/components/chat/Chat.svelte

@@ -45,7 +45,8 @@
 		promptTemplate,
 		promptTemplate,
 		splitStream,
 		splitStream,
 		sleep,
 		sleep,
-		removeDetailsWithReasoning
+		removeDetailsWithReasoning,
+		getPromptVariables
 	} from '$lib/utils';
 	} from '$lib/utils';
 
 
 	import { generateChatCompletion } from '$lib/apis/ollama';
 	import { generateChatCompletion } from '$lib/apis/ollama';
@@ -628,7 +629,7 @@
 		} catch (e) {
 		} catch (e) {
 			// Remove the failed doc from the files array
 			// Remove the failed doc from the files array
 			files = files.filter((f) => f.name !== url);
 			files = files.filter((f) => f.name !== url);
-			toast.error(e);
+			toast.error(`${e}`);
 		}
 		}
 	};
 	};
 
 
@@ -1558,10 +1559,17 @@
 
 
 				files: (files?.length ?? 0) > 0 ? files : undefined,
 				files: (files?.length ?? 0) > 0 ? files : undefined,
 				tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 				tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
+
 				features: {
 				features: {
 					image_generation: imageGenerationEnabled,
 					image_generation: imageGenerationEnabled,
 					web_search: webSearchEnabled
 					web_search: webSearchEnabled
 				},
 				},
+				variables: {
+					...getPromptVariables(
+						$user.name,
+						$settings?.userLocation ? await getAndUpdateUserLocation(localStorage.token) : undefined
+					)
+				},
 
 
 				session_id: $socket?.id,
 				session_id: $socket?.id,
 				chat_id: $chatId,
 				chat_id: $chatId,

+ 13 - 0
src/lib/utils/index.ts

@@ -766,6 +766,19 @@ export const blobToFile = (blob, fileName) => {
 	return file;
 	return file;
 };
 };
 
 
+export const getPromptVariables = (user_name, user_location) => {
+	return {
+		'{{USER_NAME}}': user_name,
+		'{{USER_LOCATION}}': user_location || 'Unknown',
+		'{{CURRENT_DATETIME}}': getCurrentDateTime(),
+		'{{CURRENT_DATE}}': getFormattedDate(),
+		'{{CURRENT_TIME}}': getFormattedTime(),
+		'{{CURRENT_WEEKDAY}}': getWeekday(),
+		'{{CURRENT_TIMEZONE}}': getUserTimezone(),
+		'{{USER_LANGUAGE}}': localStorage.getItem('locale') || 'en-US'
+	};
+};
+
 /**
 /**
  * @param {string} template - The template string containing placeholders.
  * @param {string} template - The template string containing placeholders.
  * @returns {string} The template string with the placeholders replaced by the prompt.
  * @returns {string} The template string with the placeholders replaced by the prompt.