
feat: autocomplete backend endpoint

Timothy Jaeryang Baek, 5 months ago
commit 0e8e9820d0

+ 41 - 0
backend/open_webui/config.py

@@ -999,6 +999,47 @@ Strictly return in JSON format:
 """
 
 
+AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+    "AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE",
+    "task.autocomplete.prompt_template",
+    os.environ.get("AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", ""),
+)
+
+DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = """### Task:
+You are an **autocompletion system**. Your sole task is to generate concise, logical continuations for text provided within the `<text>` tag. Additional guidance on the purpose, tone, or format will be included in a `<context>` tag. 
+
+Only output a continuation. If you are unsure how to proceed, output nothing.
+
+### **Instructions**
+1. Analyze the `<text>` to understand its structure, context, and flow.
+2. Refer to the `<context>` for any specific purpose or format (e.g., search queries, general, etc.).
+3. Complete the text concisely and meaningfully without repeating or altering the original.
+4. Do not introduce unrelated ideas or elaborate unnecessarily.
+
+### **Output Rules**
+- Respond *only* with the continuation—no preamble or explanation.
+- Ensure the continuation directly connects to the given text and adheres to the context.
+- If unsure about completing, provide no output.
+
+### **Examples**
+
+**Example 1**  
+<context>General</context>
+<text>The sun was dipping below the horizon, painting the sky in shades of pink and orange as the cool breeze began to set in.</text>
+**Output**: A sense of calm spread through the air, and the first stars started to shimmer faintly above.
+
+**Example 2**  
+<context>Search</context>
+<text>How to prepare for a job interview</text>
+**Output**: effectively, including researching the company and practicing common questions.
+
+**Example 3**  
+<context>Search</context>
+<text>Best destinations for hiking in</text> 
+**Output**: Europe, such as the Alps or the Scottish Highlands.
+"""
+
+
 TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
     "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
     "task.tools.prompt_template",

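Note: the new PersistentConfig entry lets the autocomplete prompt template be overridden through the AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE environment variable (persisted under the task.autocomplete.prompt_template key); when it is left blank, the endpoint added in main.py below falls back to DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE. A minimal sketch of that resolution outside the PersistentConfig machinery (stand-in names, illustrative only):

    import os

    # Hypothetical stand-ins; in the app these come from open_webui.config
    # (PersistentConfig) and are read via app.state.config at request time.
    DEFAULT_TEMPLATE = "### Task: ..."  # the default template shown above
    configured = os.environ.get("AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", "")

    # Mirrors the fallback added in main.py: any non-blank override wins.
    template = configured if configured.strip() != "" else DEFAULT_TEMPLATE
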
+ 1 - 0
backend/open_webui/constants.py

@@ -113,5 +113,6 @@ class TASKS(str, Enum):
     TAGS_GENERATION = "tags_generation"
     EMOJI_GENERATION = "emoji_generation"
     QUERY_GENERATION = "query_generation"
+    AUTOCOMPLETION_GENERATION = "autocompletion_generation"
     FUNCTION_CALLING = "function_calling"
     MOA_RESPONSE_GENERATION = "moa_response_generation"

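The endpoint added below tags autocomplete payloads with str(TASKS.AUTOCOMPLETION_GENERATION) in metadata["task"]. A quick, illustrative check of the new member's string value (import path per the file above; running it assumes open_webui is importable):

    from open_webui.constants import TASKS

    # The value added in this commit for the autocomplete task type.
    assert TASKS.AUTOCOMPLETION_GENERATION.value == "autocompletion_generation"
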
+ 74 - 0
backend/open_webui/main.py

@@ -89,6 +89,8 @@ from open_webui.config import (
     DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
     TITLE_GENERATION_PROMPT_TEMPLATE,
     TAGS_GENERATION_PROMPT_TEMPLATE,
+    AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
+    DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     WEBHOOK_URL,
     WEBUI_AUTH,
@@ -127,6 +129,7 @@ from open_webui.utils.task import (
     rag_template,
     title_generation_template,
     query_generation_template,
+    autocomplete_generation_template,
     tags_generation_template,
     emoji_generation_template,
     moa_response_generation_template,
@@ -215,6 +218,10 @@ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
 app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
 app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
 
+app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = (
+    AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
+)
+
 app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 )
@@ -1982,6 +1989,73 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
     return await generate_chat_completions(form_data=payload, user=user)
 
 
+@app.post("/api/task/auto/completions")
+async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
+    context = form_data.get("context")
+
+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
+    model_id = form_data["model"]
+    if model_id not in models:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    task_model_id = get_task_model_id(
+        model_id,
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
+
+    log.debug(
+        f"generating autocompletion using model {task_model_id} for user {user.email}"
+    )
+
+    if (app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
+        template = app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
+    else:
+        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
+
+    content = autocomplete_generation_template(
+        template, form_data["messages"], context, {"name": user.name}
+    )
+
+    payload = {
+        "model": task_model_id,
+        "messages": [{"role": "user", "content": content}],
+        "stream": False,
+        "metadata": {
+            "task": str(TASKS.AUTOCOMPLETION_GENERATION),
+            "task_body": form_data,
+            "chat_id": form_data.get("chat_id", None),
+        },
+    }
+
+    # Handle pipeline filters
+    try:
+        payload = filter_pipeline(payload, user, models)
+    except Exception as e:
+        if len(e.args) > 1:
+            return JSONResponse(
+                status_code=e.args[0],
+                content={"detail": e.args[1]},
+            )
+        else:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
+    return await generate_chat_completions(form_data=payload, user=user)
+
+
 @app.post("/api/task/emoji/completions")
 async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
 

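For reference, a hypothetical client-side call to the new endpoint. The body fields (model, messages, optional context and chat_id) follow the handler above; the base URL, port, API key, and model id are placeholders, and the response is a regular non-streamed chat completion since the payload sets "stream": False:

    import requests  # client-side sketch only; not part of this commit

    resp = requests.post(
        "http://localhost:8080/api/task/auto/completions",  # assumed local instance
        headers={"Authorization": "Bearer YOUR_API_KEY"},    # get_verified_user requires auth
        json={
            "model": "gpt-4o",    # placeholder; must be an id returned by get_all_models()
            "context": "Search",  # optional; substituted for {{CONTEXT}} in the template
            "messages": [
                {"role": "user", "content": "Best destinations for hiking in"}
            ],
        },
    )
    print(resp.json())
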
+ 23 - 0
backend/open_webui/utils/task.py

@@ -212,6 +212,29 @@ def emoji_generation_template(
     return template
 
 
+def autocomplete_generation_template(
+    template: str,
+    messages: list[dict],
+    context: Optional[str] = None,
+    user: Optional[dict] = None,
+) -> str:
+    prompt = get_last_user_message(messages)
+    template = template.replace("{{CONTEXT}}", context if context else "")
+
+    template = replace_prompt_variable(template, prompt)
+    template = replace_messages_variable(template, messages)
+
+    template = prompt_template(
+        template,
+        **(
+            {"user_name": user.get("name"), "user_location": user.get("location")}
+            if user
+            else {}
+        ),
+    )
+    return template
+
+
 def query_generation_template(
     template: str, messages: list[dict], user: Optional[dict] = None
 ) -> str:
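
To illustrate the substitutions performed by autocomplete_generation_template, a minimal sketch with a hypothetical template string: {{CONTEXT}} is replaced explicitly in the function above, while the {{prompt}} placeholder is an assumption about what replace_prompt_variable handles.

    from open_webui.utils.task import autocomplete_generation_template

    # Hypothetical template; placeholders and expected output are illustrative.
    template = "<context>{{CONTEXT}}</context>\n<text>{{prompt}}</text>"
    messages = [{"role": "user", "content": "Best destinations for hiking in"}]

    content = autocomplete_generation_template(
        template, messages, context="Search", user={"name": "Ada"}
    )
    # Expected (assuming {{prompt}} maps to the last user message):
    # <context>Search</context>
    # <text>Best destinations for hiking in</text>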