|
@@ -89,6 +89,8 @@ from open_webui.config import (
|
|
|
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
TAGS_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
|
|
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
WEBHOOK_URL,
|
|
|
WEBUI_AUTH,
|
|
@@ -127,6 +129,7 @@ from open_webui.utils.task import (
|
|
|
rag_template,
|
|
|
title_generation_template,
|
|
|
query_generation_template,
|
|
|
+ autocomplete_generation_template,
|
|
|
tags_generation_template,
|
|
|
emoji_generation_template,
|
|
|
moa_response_generation_template,
|
|
@@ -215,6 +218,10 @@ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
|
|
|
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
|
|
|
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
|
|
|
+app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = (
|
|
|
+ AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
|
|
|
+)
|
|
|
+
|
|
|
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
|
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
|
)
|
|
@@ -1982,6 +1989,73 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
|
|
|
return await generate_chat_completions(form_data=payload, user=user)
|
|
|
|
|
|
|
|
|
+@app.post("/api/task/auto/completions")
|
|
|
+async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
|
|
|
+ context = form_data.get("context")
|
|
|
+
|
|
|
+ model_list = await get_all_models()
|
|
|
+ models = {model["id"]: model for model in model_list}
|
|
|
+
|
|
|
+ model_id = form_data["model"]
|
|
|
+ if model_id not in models:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check if the user has a custom task model
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
+ task_model_id = get_task_model_id(
|
|
|
+ model_id,
|
|
|
+ app.state.config.TASK_MODEL,
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
+ models,
|
|
|
+ )
|
|
|
+
|
|
|
+ log.debug(
|
|
|
+ f"generating autocompletion using model {task_model_id} for user {user.email}"
|
|
|
+ )
|
|
|
+
|
|
|
+ if (app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
|
|
|
+ template = app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
|
|
|
+ else:
|
|
|
+ template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
|
|
|
+
|
|
|
+ content = autocomplete_generation_template(
|
|
|
+ template, form_data["messages"], context, {"name": user.name}
|
|
|
+ )
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": task_model_id,
|
|
|
+ "messages": [{"role": "user", "content": content}],
|
|
|
+ "stream": False,
|
|
|
+ "metadata": {
|
|
|
+ "task": str(TASKS.AUTOCOMPLETION_GENERATION),
|
|
|
+ "task_body": form_data,
|
|
|
+ "chat_id": form_data.get("chat_id", None),
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ # Handle pipeline filters
|
|
|
+ try:
|
|
|
+ payload = filter_pipeline(payload, user, models)
|
|
|
+ except Exception as e:
|
|
|
+ if len(e.args) > 1:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=e.args[0],
|
|
|
+ content={"detail": e.args[1]},
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ content={"detail": str(e)},
|
|
|
+ )
|
|
|
+ if "chat_id" in payload:
|
|
|
+ del payload["chat_id"]
|
|
|
+
|
|
|
+ return await generate_chat_completions(form_data=payload, user=user)
|
|
|
+
|
|
|
+
|
|
|
@app.post("/api/task/emoji/completions")
|
|
|
async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
|
|
|
|