Timothy Jaeryang Baek 4 months ago
parent
commit
fe5519e0a2

+ 57 - 179
backend/open_webui/main.py

@@ -75,6 +75,11 @@ from open_webui.routers.retrieval import (
     get_ef,
     get_rf,
 )
+from open_webui.routers.pipelines import (
+    process_pipeline_inlet_filter,
+    process_pipeline_outlet_filter,
+)
+
 from open_webui.retrieval.utils import get_sources_from_files
 
 
@@ -290,6 +295,7 @@ from open_webui.utils.response import (
 )
 
 from open_webui.utils.task import (
+    get_task_model_id,
     rag_template,
     tools_function_calling_generation_template,
 )
@@ -662,34 +668,35 @@ app.state.MODELS = {}
 ##################################
 
 
-def get_filter_function_ids(model):
-    def get_priority(function_id):
-        function = Functions.get_function_by_id(function_id)
-        if function is not None and hasattr(function, "valves"):
-            # TODO: Fix FunctionModel
-            return (function.valves if function.valves else {}).get("priority", 0)
-        return 0
-
-    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
-    if "info" in model and "meta" in model["info"]:
-        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-        filter_ids = list(set(filter_ids))
+async def chat_completion_filter_functions_handler(body, model, extra_params):
+    skip_files = None
 
-    enabled_filter_ids = [
-        function.id
-        for function in Functions.get_functions_by_type("filter", active_only=True)
-    ]
+    def get_filter_function_ids(model):
+        def get_priority(function_id):
+            function = Functions.get_function_by_id(function_id)
+            if function is not None and hasattr(function, "valves"):
+                # TODO: Fix FunctionModel
+                return (function.valves if function.valves else {}).get("priority", 0)
+            return 0
 
-    filter_ids = [
-        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-    ]
+        filter_ids = [
+            function.id for function in Functions.get_global_filter_functions()
+        ]
+        if "info" in model and "meta" in model["info"]:
+            filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+            filter_ids = list(set(filter_ids))
 
-    filter_ids.sort(key=get_priority)
-    return filter_ids
+        enabled_filter_ids = [
+            function.id
+            for function in Functions.get_functions_by_type("filter", active_only=True)
+        ]
 
+        filter_ids = [
+            filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+        ]
 
-async def chat_completion_filter_functions_handler(body, model, extra_params):
-    skip_files = None
+        filter_ids.sort(key=get_priority)
+        return filter_ids
 
     filter_ids = get_filter_function_ids(model)
     for filter_id in filter_ids:
@@ -791,22 +798,6 @@ async def get_content_from_response(response) -> Optional[str]:
     return content
 
 
-def get_task_model_id(
-    default_model_id: str, task_model: str, task_model_external: str, models
-) -> str:
-    # Set the task model
-    task_model_id = default_model_id
-    # Check if the user has a custom task model and use that model
-    if models[task_model_id]["owned_by"] == "ollama":
-        if task_model and task_model in models:
-            task_model_id = task_model
-    else:
-        if task_model_external and task_model_external in models:
-            task_model_id = task_model_external
-
-    return task_model_id
-
-
 async def chat_completion_tools_handler(
     body: dict, user: UserModel, models, extra_params: dict
 ) -> tuple[dict, dict]:
@@ -857,7 +848,7 @@ async def chat_completion_tools_handler(
     )
 
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         raise e
 
@@ -1153,7 +1144,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             if prompt is None:
                 raise Exception("No user message found")
             if (
-                retrieval_app.state.config.RELEVANCE_THRESHOLD == 0
+                app.state.config.RELEVANCE_THRESHOLD == 0
                 and context_string.strip() == ""
             ):
                 log.debug(
@@ -1164,16 +1155,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             # TODO: replace with add_or_update_system_message
             if model["owned_by"] == "ollama":
                 body["messages"] = prepend_to_first_user_message_content(
-                    rag_template(
-                        retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
-                    ),
+                    rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                     body["messages"],
                 )
             else:
                 body["messages"] = add_or_update_system_message(
-                    rag_template(
-                        retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
-                    ),
+                    rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                     body["messages"],
                 )
 
@@ -1225,77 +1212,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 app.add_middleware(ChatCompletionMiddleware)
 
 
-##################################
-#
-# Pipeline Middleware
-#
-##################################
-
-
-def get_sorted_filters(model_id, models):
-    filters = [
-        model
-        for model in models.values()
-        if "pipeline" in model
-        and "type" in model["pipeline"]
-        and model["pipeline"]["type"] == "filter"
-        and (
-            model["pipeline"]["pipelines"] == ["*"]
-            or any(
-                model_id == target_model_id
-                for target_model_id in model["pipeline"]["pipelines"]
-            )
-        )
-    ]
-    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
-    return sorted_filters
-
-
-def filter_pipeline(payload, user, models):
-    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
-    model_id = payload["model"]
-
-    sorted_filters = get_sorted_filters(model_id, models)
-    model = models[model_id]
-
-    if "pipeline" in model:
-        sorted_filters.append(model)
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-
-            url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key == "":
-                continue
-
-            headers = {"Authorization": f"Bearer {key}"}
-            r = requests.post(
-                f"{url}/{filter['id']}/filter/inlet",
-                headers=headers,
-                json={
-                    "user": user,
-                    "body": payload,
-                },
-            )
-
-            r.raise_for_status()
-            payload = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                res = r.json()
-                if "detail" in res:
-                    raise Exception(r.status_code, res["detail"])
-
-    return payload
-
-
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         if not request.method == "POST" and any(
@@ -1335,11 +1251,11 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                 content={"detail": e.detail},
             )
 
-        model_list = await get_all_models()
-        models = {model["id"]: model for model in model_list}
+        await get_all_models()
+        models = app.state.MODELS
 
         try:
-            data = filter_pipeline(data, user, models)
+            data = process_pipeline_inlet_filter(request, data, user, models)
         except Exception as e:
             if len(e.args) > 1:
                 return JSONResponse(
@@ -1447,8 +1363,8 @@ app.include_router(ollama.router, prefix="/ollama", tags=["ollama"])
 app.include_router(openai.router, prefix="/openai", tags=["openai"])
 
 
-app.include_router(pipelines.router, prefix="/pipelines", tags=["pipelines"])
-app.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
+app.include_router(pipelines.router, prefix="/api/pipelines", tags=["pipelines"])
+app.include_router(tasks.router, prefix="/api/tasks", tags=["tasks"])
 
 
 app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
@@ -2105,7 +2021,6 @@ async def generate_chat_completions(
     if model["owned_by"] == "ollama":
         # Using /ollama/api/chat endpoint
         form_data = convert_payload_openai_to_ollama(form_data)
-        form_data = GenerateChatCompletionForm(**form_data)
         response = await generate_ollama_chat_completion(
             form_data=form_data, user=user, bypass_filter=bypass_filter
         )
@@ -2124,7 +2039,9 @@ async def generate_chat_completions(
 
 
 @app.post("/api/chat/completed")
-async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+async def chat_completed(
+    request: Request, form_data: dict, user=Depends(get_verified_user)
+):
     model_list = await get_all_models()
     models = {model["id"]: model for model in model_list}
 
@@ -2137,53 +2054,14 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
         )
 
     model = models[model_id]
-    sorted_filters = get_sorted_filters(model_id, models)
-    if "pipeline" in model:
-        sorted_filters = [model] + sorted_filters
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-
-            url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/outlet",
-                    headers=headers,
-                    json={
-                        "user": {
-                            "id": user.id,
-                            "name": user.name,
-                            "email": user.email,
-                            "role": user.role,
-                        },
-                        "body": data,
-                    },
-                )
 
-                r.raise_for_status()
-                data = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "detail" in res:
-                        return JSONResponse(
-                            status_code=r.status_code,
-                            content=res,
-                        )
-                except Exception:
-                    pass
-
-            else:
-                pass
+    try:
+        data = process_pipeline_outlet_filter(request, data, user, models)
+    except Exception as e:
+        return HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e),
+        )
 
     __event_emitter__ = get_event_emitter(
         {
@@ -2455,8 +2333,8 @@ async def get_app_config(request: Request):
             "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
             **(
                 {
-                    "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
-                    "enable_image_generation": images_app.state.config.ENABLED,
+                    "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
+                    "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
                     "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
                     "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
                     "enable_admin_export": ENABLE_ADMIN_EXPORT,
@@ -2472,17 +2350,17 @@ async def get_app_config(request: Request):
                 "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
                 "audio": {
                     "tts": {
-                        "engine": audio_app.state.config.TTS_ENGINE,
-                        "voice": audio_app.state.config.TTS_VOICE,
-                        "split_on": audio_app.state.config.TTS_SPLIT_ON,
+                        "engine": app.state.config.TTS_ENGINE,
+                        "voice": app.state.config.TTS_VOICE,
+                        "split_on": app.state.config.TTS_SPLIT_ON,
                     },
                     "stt": {
-                        "engine": audio_app.state.config.STT_ENGINE,
+                        "engine": app.state.config.STT_ENGINE,
                     },
                 },
                 "file": {
-                    "max_size": retrieval_app.state.config.FILE_MAX_SIZE,
-                    "max_count": retrieval_app.state.config.FILE_MAX_COUNT,
+                    "max_size": app.state.config.FILE_MAX_SIZE,
+                    "max_count": app.state.config.FILE_MAX_COUNT,
                 },
                 "permissions": {**app.state.config.USER_PERMISSIONS},
             }

+ 10 - 1
backend/open_webui/routers/ollama.py

@@ -941,7 +941,7 @@ async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] =
 @router.post("/api/chat/{url_idx}")
 async def generate_chat_completion(
     request: Request,
-    form_data: GenerateChatCompletionForm,
+    form_data: dict,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
     bypass_filter: Optional[bool] = False,
@@ -949,6 +949,15 @@ async def generate_chat_completion(
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True
 
+    try:
+        form_data = GenerateChatCompletionForm(**form_data)
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=400,
+            detail=str(e),
+        )
+
     payload = {**form_data.model_dump(exclude_none=True)}
     if "metadata" in payload:
         del payload["metadata"]

+ 132 - 8
backend/open_webui/routers/pipelines.py

@@ -30,6 +30,130 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
+##################################
+#
+# Pipeline Middleware
+#
+##################################
+
+
+def get_sorted_filters(model_id, models):
+    filters = [
+        model
+        for model in models.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+    return sorted_filters
+
+
+def process_pipeline_inlet_filter(request, payload, user, models):
+    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
+    model_id = payload["model"]
+
+    sorted_filters = get_sorted_filters(model_id, models)
+    model = models[model_id]
+
+    if "pipeline" in model:
+        sorted_filters.append(model)
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key == "":
+                continue
+
+            headers = {"Authorization": f"Bearer {key}"}
+            r = requests.post(
+                f"{url}/{filter['id']}/filter/inlet",
+                headers=headers,
+                json={
+                    "user": user,
+                    "body": payload,
+                },
+            )
+
+            r.raise_for_status()
+            payload = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                res = r.json()
+                if "detail" in res:
+                    raise Exception(r.status_code, res["detail"])
+
+    return payload
+
+
+def process_pipeline_outlet_filter(request, payload, user, models):
+    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
+    model_id = payload["model"]
+
+    sorted_filters = get_sorted_filters(model_id, models)
+    model = models[model_id]
+
+    if "pipeline" in model:
+        sorted_filters = [model] + sorted_filters
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                r = requests.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers={"Authorization": f"Bearer {key}"},
+                    json={
+                        "user": {
+                            "id": user.id,
+                            "name": user.name,
+                            "email": user.email,
+                            "role": user.role,
+                        },
+                        "body": data,
+                    },
+                )
+
+                r.raise_for_status()
+                data = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "detail" in res:
+                        return Exception(r.status_code, res)
+                except Exception:
+                    pass
+
+            else:
+                pass
+
+    return payload
+
+
 ##################################
 #
 # Pipelines Endpoints
@@ -39,7 +163,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
 router = APIRouter()
 
 
-@router.get("/api/pipelines/list")
+@router.get("/list")
 async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
     responses = await get_all_models_responses(request)
     log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
@@ -61,7 +185,7 @@ async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
     }
 
 
-@router.post("/api/pipelines/upload")
+@router.post("/upload")
 async def upload_pipeline(
     request: Request,
     urlIdx: int = Form(...),
@@ -131,7 +255,7 @@ class AddPipelineForm(BaseModel):
     urlIdx: int
 
 
-@router.post("/api/pipelines/add")
+@router.post("/add")
 async def add_pipeline(
     request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
 ):
@@ -176,7 +300,7 @@ class DeletePipelineForm(BaseModel):
     urlIdx: int
 
 
-@router.delete("/api/pipelines/delete")
+@router.delete("/delete")
 async def delete_pipeline(
     request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
 ):
@@ -216,7 +340,7 @@ async def delete_pipeline(
         )
 
 
-@router.get("/api/pipelines")
+@router.get("/")
 async def get_pipelines(
     request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
 ):
@@ -250,7 +374,7 @@ async def get_pipelines(
         )
 
 
-@router.get("/api/pipelines/{pipeline_id}/valves")
+@router.get("/{pipeline_id}/valves")
 async def get_pipeline_valves(
     request: Request,
     urlIdx: Optional[int],
@@ -289,7 +413,7 @@ async def get_pipeline_valves(
         )
 
 
-@router.get("/api/pipelines/{pipeline_id}/valves/spec")
+@router.get("/{pipeline_id}/valves/spec")
 async def get_pipeline_valves_spec(
     request: Request,
     urlIdx: Optional[int],
@@ -329,7 +453,7 @@ async def get_pipeline_valves_spec(
         )
 
 
-@router.post("/api/pipelines/{pipeline_id}/valves/update")
+@router.post("/{pipeline_id}/valves/update")
 async def update_pipeline_valves(
     request: Request,
     urlIdx: Optional[int],

+ 18 - 21
backend/open_webui/routers/tasks.py

@@ -1,6 +1,7 @@
 from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
+from fastapi.responses import JSONResponse, RedirectResponse
+
 from pydantic import BaseModel
-from starlette.responses import FileResponse
 from typing import Optional
 import logging
 
@@ -16,6 +17,9 @@ from open_webui.utils.task import (
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.constants import TASKS
 
+from open_webui.routers.pipelines import process_pipeline_inlet_filter
+from open_webui.utils.task import get_task_model_id
+
 from open_webui.config import (
     DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
@@ -121,9 +125,7 @@ async def update_task_config(
 async def generate_title(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -191,7 +193,7 @@ Artificial Intelligence in Healthcare
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -220,8 +222,7 @@ async def generate_chat_tags(
             content={"detail": "Tags generation is disabled"},
         )
 
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -281,7 +282,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -318,8 +319,7 @@ async def generate_queries(
                 detail=f"Query generation is disabled",
             )
 
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -363,7 +363,7 @@ async def generate_queries(
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -405,8 +405,7 @@ async def generate_autocompletion(
                 detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
             )
 
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -450,7 +449,7 @@ async def generate_autocompletion(
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -473,8 +472,7 @@ async def generate_emoji(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
 
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -525,7 +523,7 @@ Message: """{{prompt}}"""
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -548,10 +546,9 @@ async def generate_moa_response(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
 
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
-
+    models = request.app.state.MODELS
     model_id = form_data["model"]
+
     if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
@@ -593,7 +590,7 @@ Responses from models: {{responses}}"""
     }
 
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(

+ 16 - 0
backend/open_webui/utils/task.py

@@ -16,6 +16,22 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
+def get_task_model_id(
+    default_model_id: str, task_model: str, task_model_external: str, models
+) -> str:
+    # Set the task model
+    task_model_id = default_model_id
+    # Check if the user has a custom task model and use that model
+    if models[task_model_id]["owned_by"] == "ollama":
+        if task_model and task_model in models:
+            task_model_id = task_model
+    else:
+        if task_model_external and task_model_external in models:
+            task_model_id = task_model_external
+
+    return task_model_id
+
+
 def prompt_template(
     template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
 ) -> str: