123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642 |
- from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
- from fastapi.responses import JSONResponse, RedirectResponse
- from pydantic import BaseModel
- from typing import Optional
- import logging
- import re
- from open_webui.utils.chat import generate_chat_completion
- from open_webui.utils.task import (
- title_generation_template,
- query_generation_template,
- image_prompt_generation_template,
- autocomplete_generation_template,
- tags_generation_template,
- emoji_generation_template,
- moa_response_generation_template,
- )
- from open_webui.utils.auth import get_admin_user, get_verified_user
- from open_webui.constants import TASKS
- from open_webui.routers.pipelines import process_pipeline_inlet_filter
- from open_webui.utils.task import get_task_model_id
- from open_webui.config import (
- DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
- DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
- DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
- DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
- DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
- DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
- DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
- )
- from open_webui.env import SRC_LOG_LEVELS
# Logger for this module; its level is taken from the "MODELS" entry of
# SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

# Router on which every task endpoint in this module is registered.
router = APIRouter()


##################################
#
# Task Endpoints
#
##################################
@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    """Return the current task-generation settings.

    Any verified user may call this. The response is a flat mapping from each
    setting name to the value currently stored on ``request.app.state.config``.
    """
    config = request.app.state.config
    # The tuple order determines the key order of the JSON response.
    exposed_settings = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )
    return {name: getattr(config, name) for name in exposed_settings}
class TaskConfigForm(BaseModel):
    """Request body accepted by ``POST /config/update``.

    ``update_task_config`` copies each field onto ``request.app.state.config``
    under the attribute of the same name.
    """

    # Task model selection; both fields accept ``None``.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]

    # Title and image-prompt generation templates.
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str

    # Autocompletion switch and its input-length cap.
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int

    # Tag generation.
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_TAGS_GENERATION: bool

    # Search / retrieval query generation.
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str

    # Tool function-calling prompt.
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    """Replace the task-generation settings; restricted to admin users.

    Each field of ``form_data`` is assigned, in the order listed below, to the
    attribute of the same name on ``request.app.state.config``. The stored
    values are then read back and returned as a flat mapping.
    """
    config = request.app.state.config
    editable_settings = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )

    for name in editable_settings:
        setattr(config, name, getattr(form_data, name))

    # Read back from the config object rather than echoing form_data.
    return {name: getattr(config, name) for name in editable_settings}
# Collapsible reasoning blocks (``<details type="reasoning" ...>...</details>``)
# that are removed from message text before the title prompt is built.
_REASONING_DETAILS_RE = re.compile(
    r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>", flags=re.S
)


@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate a title for a chat conversation using the task model.

    Reasoning ``<details type="reasoning">`` blocks are stripped, in place,
    from the string contents of ``form_data["messages"]`` before the messages
    are rendered into the title prompt template.

    Args:
        request: Incoming request. ``request.state`` may carry a ``direct``
            model and request ``metadata`` that are honoured here.
        form_data: Request body; reads ``model``, ``messages`` and the
            optional ``chat_id``.
        user: The verified user making the request.

    Returns:
        The result of ``generate_chat_completion``, or a 400 ``JSONResponse``
        with a generic message if that call raises.

    Raises:
        HTTPException: 404 when ``form_data`` names no available model.
    """
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    # ``.get`` so a body without "model" yields 404 instead of a KeyError (500).
    model_id = form_data.get("model")
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Resolve the model that actually runs the task; get_task_model_id may
    # substitute the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email} "
    )

    if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    messages = form_data["messages"]

    # Remove reasoning details from the messages. This mutates form_data, so
    # ``task_body`` below carries the cleaned messages too. Non-string
    # contents (e.g. a list of content parts) are left untouched; passing
    # them to the regex would raise TypeError.
    for message in messages:
        text = message.get("content")
        if isinstance(text, str):
            message["content"] = _REASONING_DETAILS_RE.sub("", text).strip()

    content = title_generation_template(
        template,
        messages,
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    # Ollama models take ``max_tokens``; all others take
    # ``max_completion_tokens``. ``.get`` tolerates a model entry without an
    # "owned_by" key instead of raising KeyError.
    if models[task_model_id].get("owned_by") == "ollama":
        token_limit = {"max_tokens": 1000}
    else:
        token_limit = {"max_completion_tokens": 1000}

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **token_limit,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.exception("Exception occurred")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate tags for a chat conversation using the task model.

    When ENABLE_TAGS_GENERATION is off, a 200 response with an explanatory
    ``detail`` is returned and no model is called.

    Args:
        request: Incoming request. ``request.state`` may carry a ``direct``
            model and request ``metadata`` that are honoured here.
        form_data: Request body; reads ``model``, ``messages`` and the
            optional ``chat_id``.
        user: The verified user making the request.

    Returns:
        The result of ``generate_chat_completion``, or a 500 ``JSONResponse``
        with a generic message if that call raises.

    Raises:
        HTTPException: 404 when ``form_data`` names no available model.
    """
    if not request.app.state.config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    # ``.get`` so a body without "model" yields 404 instead of a KeyError (500).
    model_id = form_data.get("model")
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Resolve the model that actually runs the task; get_task_model_id may
    # substitute the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email} "
    )

    if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE

    content = tags_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TAGS_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # log.exception records the traceback, which log.error alone dropped.
        log.exception(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/image_prompt/completions")
async def generate_image_prompt(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Turn a conversation into an image-generation prompt via the task model.

    Args:
        request: Incoming request. A ``direct`` model on ``request.state``
            replaces the global model table, and any ``metadata`` there is
            merged into the completion metadata.
        form_data: Request body; reads ``model``, ``messages`` and the
            optional ``chat_id``.
        user: The verified user making the request.

    Returns:
        The result of ``generate_chat_completion``, or a 400 ``JSONResponse``
        with a generic message if that call raises.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not available.
    """
    state = request.state
    config = request.app.state.config

    use_direct_model = getattr(state, "direct", False) and hasattr(state, "model")
    models = (
        {state.model["id"]: state.model}
        if use_direct_model
        else request.app.state.MODELS
    )

    requested_model = form_data["model"]
    if requested_model not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # The configured task model (if any) may stand in for the requested one.
    task_model_id = get_task_model_id(
        requested_model,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating image prompt using model {task_model_id} for user {user.email} "
    )

    # An empty configured template falls back to the built-in default.
    configured_template = config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    template = (
        configured_template
        if configured_template != ""
        else DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    )

    prompt_text = image_prompt_generation_template(
        template,
        form_data["messages"],
        user={"name": user.name},
    )

    inherited_metadata = state.metadata if hasattr(state, "metadata") else {}
    completion_request = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": prompt_text}],
        "stream": False,
        "metadata": {
            **inherited_metadata,
            "task": str(TASKS.IMAGE_PROMPT_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(
            request, form_data=completion_request, user=user
        )
    except Exception:
        log.exception("Exception occurred")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate search or retrieval queries for a conversation.

    ``form_data["type"]`` selects which feature flag gates the request:
    ``"web_search"`` checks ENABLE_SEARCH_QUERY_GENERATION and
    ``"retrieval"`` checks ENABLE_RETRIEVAL_QUERY_GENERATION. Any other value
    is not gated.

    Args:
        request: Incoming request. ``request.state`` may carry a ``direct``
            model and request ``metadata`` that are honoured here.
        form_data: Request body; reads ``type``, ``model``, ``messages`` and
            the optional ``chat_id``.
        user: The verified user making the request.

    Returns:
        The result of ``generate_chat_completion``, or a 400 ``JSONResponse``
        with a generic message if that call raises.

    Raises:
        HTTPException: 400 when the selected query generation is disabled;
            404 when ``form_data`` names no available model.
    """
    # Named ``query_type`` to avoid shadowing the ``type`` builtin.
    query_type = form_data.get("type")
    if query_type == "web_search":
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif query_type == "retrieval":
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    # ``.get`` so a body without "model" yields 404 instead of a KeyError (500).
    model_id = form_data.get("model")
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Resolve the model that actually runs the task; get_task_model_id may
    # substitute the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {query_type} queries using model {task_model_id} for user {user.email}"
    )

    # A blank (whitespace-only) configured template falls back to the default.
    if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Keep the details in the server log. Do not echo the exception text
        # to the client, since it may contain internal details. This matches
        # the generic error body of the title/tags/image-prompt endpoints.
        log.exception(f"Error generating queries: {e}")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate an autocompletion for ``form_data["prompt"]`` via the task model.

    Args:
        request: Incoming request. ``request.state`` may carry a ``direct``
            model and request ``metadata`` that are honoured here.
        form_data: Request body; reads ``type``, ``prompt``, ``messages``,
            ``model`` and the optional ``chat_id``. A missing or null
            ``prompt`` is treated as an empty string.
        user: The verified user making the request.

    Returns:
        The result of ``generate_chat_completion``, or a 500 ``JSONResponse``
        with a generic message if that call raises.

    Raises:
        HTTPException: 400 when autocompletion is disabled or the prompt is
            longer than AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH (only
            enforced when that limit is > 0); 404 when ``form_data`` names no
            available model.
    """
    if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    # Named ``completion_type`` to avoid shadowing the ``type`` builtin.
    completion_type = form_data.get("type")
    # Normalise a missing/null prompt so ``len()`` below cannot raise TypeError.
    prompt = form_data.get("prompt") or ""
    messages = form_data.get("messages")

    max_length = request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    if max_length > 0 and len(prompt) > max_length:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Input prompt exceeds maximum length of {max_length}",
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    # ``.get`` so a body without "model" yields 404 instead of a KeyError (500).
    model_id = form_data.get("model")
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Resolve the model that actually runs the task; get_task_model_id may
    # substitute the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    # A blank (whitespace-only) configured template falls back to the default.
    if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(
        template, prompt, messages, completion_type, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # log.exception records the traceback, which log.error alone dropped.
        log.exception(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Pick an emoji for ``form_data["prompt"]`` using the task model.

    Always uses DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE (no configurable
    override is read here) and caps the completion at 4 tokens.

    Args:
        request: Incoming request. ``request.state`` may carry a ``direct``
            model and request ``metadata`` that are honoured here.
        form_data: Request body; reads ``model``, ``prompt`` and the optional
            ``chat_id``.
        user: The verified user making the request.

    Returns:
        The result of ``generate_chat_completion``, or a 400 ``JSONResponse``
        with a generic message if that call raises.

    Raises:
        HTTPException: 404 when ``form_data`` names no available model.
    """
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    # ``.get`` so a body without "model" yields 404 instead of a KeyError (500).
    model_id = form_data.get("model")
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Resolve the model that actually runs the task; get_task_model_id may
    # substitute the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    # Ollama models take ``max_tokens``; all others take
    # ``max_completion_tokens``. ``.get`` tolerates a model entry without an
    # "owned_by" key instead of raising KeyError.
    if models[task_model_id].get("owned_by") == "ollama":
        token_limit = {"max_tokens": 4}
    else:
        token_limit = {"max_completion_tokens": 4}

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **token_limit,
        # chat_id goes in metadata, as in every other task endpoint here.
        # At the top level it became part of the completion body passed to
        # generate_chat_completion, where it is presumably not an expected
        # request parameter.
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Keep the details in the server log. Do not echo the exception text
        # to the client, since it may contain internal details.
        log.exception(f"Error generating emoji: {e}")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate a MOA response that combines several model responses.

    ``form_data["prompt"]`` and ``form_data["responses"]`` are rendered into
    DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE. The completion streams when
    ``form_data["stream"]`` is truthy.

    Args:
        request: Incoming request. ``request.state`` may carry a ``direct``
            model and request ``metadata`` that are honoured here.
        form_data: Request body; reads ``model``, ``prompt``, ``responses``
            and the optional ``stream`` and ``chat_id``.
        user: The verified user making the request.

    Returns:
        The result of ``generate_chat_completion``, or a 400 ``JSONResponse``
        with a generic message if that call raises.

    Raises:
        HTTPException: 404 when ``form_data`` names no available model.
    """
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    # ``.get`` so a body without "model" yields 404 instead of a KeyError (500).
    model_id = form_data.get("model")
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Resolve the model that actually runs the task; get_task_model_id may
    # substitute the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating MOA model {task_model_id} for user {user.email} ")

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "chat_id": form_data.get("chat_id", None),
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Keep the details in the server log. Do not echo the exception text
        # to the client, since it may contain internal details.
        log.exception(f"Error generating MOA response: {e}")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
|