|
@@ -134,6 +134,7 @@ from open_webui.utils.misc import (
|
|
)
|
|
)
|
|
from open_webui.utils.task import (
|
|
from open_webui.utils.task import (
|
|
moa_response_generation_template,
|
|
moa_response_generation_template,
|
|
|
|
+ tags_generation_template,
|
|
search_query_generation_template,
|
|
search_query_generation_template,
|
|
title_generation_template,
|
|
title_generation_template,
|
|
tools_function_calling_generation_template,
|
|
tools_function_calling_generation_template,
|
|
@@ -1545,6 +1546,72 @@ Prompt: {{prompt:middletruncate:8000}}"""
|
|
return await generate_chat_completions(form_data=payload, user=user)
|
|
return await generate_chat_completions(form_data=payload, user=user)
|
|
|
|
|
|
|
|
|
|
|
|
+@app.post("/api/task/tags/completions")
|
|
|
|
+async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
|
|
|
|
+ print("generate_chat_tags")
|
|
|
|
+ model_id = form_data["model"]
|
|
|
|
+ if model_id not in app.state.MODELS:
|
|
|
|
+ raise HTTPException(
|
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
|
+ detail="Model not found",
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # Check if the user has a custom task model
|
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
|
+ task_model_id = get_task_model_id(model_id)
|
|
|
|
+ print(task_model_id)
|
|
|
|
+
|
|
|
|
+ template = """### Task:
|
|
|
|
+Generate 1-3 broad tags categorizing the main themes of the chat history.
|
|
|
|
+
|
|
|
|
+### Guidelines:
|
|
|
|
+- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
|
|
|
|
+- Only add more specific subdomains if they are strongly represented throughout the conversation
|
|
|
|
+- If content is too short (less than 3 messages) or too diverse, use only ["General"]
|
|
|
|
+- Use the chat's primary language; default to English if multilingual
|
|
|
|
+- Prioritize accuracy over specificity
|
|
|
|
+
|
|
|
|
+### Output:
|
|
|
|
+JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
|
|
|
+
|
|
|
|
+### Chat History:
|
|
|
|
+<chat_history>
|
|
|
|
+{{MESSAGES:END:6}}
|
|
|
|
+</chat_history>"""
|
|
|
|
+
|
|
|
|
+ content = tags_generation_template(
|
|
|
|
+ template, form_data["messages"], {"name": user.name}
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ print("content", content)
|
|
|
|
+ payload = {
|
|
|
|
+ "model": task_model_id,
|
|
|
|
+ "messages": [{"role": "user", "content": content}],
|
|
|
|
+ "stream": False,
|
|
|
|
+ "metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data},
|
|
|
|
+ }
|
|
|
|
+ log.debug(payload)
|
|
|
|
+
|
|
|
|
+ # Handle pipeline filters
|
|
|
|
+ try:
|
|
|
|
+ payload = filter_pipeline(payload, user)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ if len(e.args) > 1:
|
|
|
|
+ return JSONResponse(
|
|
|
|
+ status_code=e.args[0],
|
|
|
|
+ content={"detail": e.args[1]},
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ return JSONResponse(
|
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
|
+ content={"detail": str(e)},
|
|
|
|
+ )
|
|
|
|
+ if "chat_id" in payload:
|
|
|
|
+ del payload["chat_id"]
|
|
|
|
+
|
|
|
|
+ return await generate_chat_completions(form_data=payload, user=user)
|
|
|
|
+
|
|
|
|
+
|
|
@app.post("/api/task/query/completions")
|
|
@app.post("/api/task/query/completions")
|
|
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
|
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
|
print("generate_search_query")
|
|
print("generate_search_query")
|