tasks.py 19 KB


  1. from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
  2. from fastapi.responses import JSONResponse, RedirectResponse
  3. from pydantic import BaseModel
  4. from typing import Optional
  5. import logging
  6. from open_webui.utils.chat import generate_chat_completion
  7. from open_webui.utils.task import (
  8. title_generation_template,
  9. query_generation_template,
  10. image_prompt_generation_template,
  11. autocomplete_generation_template,
  12. tags_generation_template,
  13. emoji_generation_template,
  14. moa_response_generation_template,
  15. )
  16. from open_webui.utils.auth import get_admin_user, get_verified_user
  17. from open_webui.constants import TASKS
  18. from open_webui.routers.pipelines import process_pipeline_inlet_filter
  19. from open_webui.utils.task import get_task_model_id
  20. from open_webui.config import (
  21. DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
  22. DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
  23. DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
  24. DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
  25. DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
  26. DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
  27. DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
  28. )
  29. from open_webui.env import SRC_LOG_LEVELS
  # Router on which every task endpoint in this module is registered.
  router = APIRouter()

  # Module-level logger; its level is taken from the "MODELS" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(__name__)
  log.setLevel(level=SRC_LOG_LEVELS["MODELS"])


  ##################################
  #
  # Task Endpoints
  #
  ##################################
  @router.get("/config")
  async def get_task_config(request: Request, user=Depends(get_verified_user)):
      """Return the task-generation settings currently held in app config.

      Any verified user may call this. The response has one key per field of
      ``TaskConfigForm``, each value read from ``request.app.state.config``.
      """
      config = request.app.state.config
      setting_names = (
          "TASK_MODEL",
          "TASK_MODEL_EXTERNAL",
          "TITLE_GENERATION_PROMPT_TEMPLATE",
          "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
          "ENABLE_AUTOCOMPLETE_GENERATION",
          "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
          "TAGS_GENERATION_PROMPT_TEMPLATE",
          "ENABLE_TAGS_GENERATION",
          "ENABLE_SEARCH_QUERY_GENERATION",
          "ENABLE_RETRIEVAL_QUERY_GENERATION",
          "QUERY_GENERATION_PROMPT_TEMPLATE",
          "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
      )
      # getattr performs the same attribute lookup as writing ``config.NAME``.
      return {name: getattr(config, name) for name in setting_names}
  class TaskConfigForm(BaseModel):
      """Request body for ``POST /config/update``.

      Field names match the keys returned by the task config endpoints.
      """

      # Task model settings handed to get_task_model_id; Optional, so None is accepted.
      TASK_MODEL: Optional[str]
      TASK_MODEL_EXTERNAL: Optional[str]

      # Title and image-prompt generation templates.
      TITLE_GENERATION_PROMPT_TEMPLATE: str
      IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str

      # Autocomplete: on/off switch and prompt length limit (a non-positive limit disables the check).
      ENABLE_AUTOCOMPLETE_GENERATION: bool
      AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int

      # Tags generation.
      TAGS_GENERATION_PROMPT_TEMPLATE: str
      ENABLE_TAGS_GENERATION: bool

      # Web-search / retrieval query generation.
      ENABLE_SEARCH_QUERY_GENERATION: bool
      ENABLE_RETRIEVAL_QUERY_GENERATION: bool
      QUERY_GENERATION_PROMPT_TEMPLATE: str

      # Tool function-calling prompt.
      TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
  @router.post("/config/update")
  async def update_task_config(
      request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
  ):
      """Store new task-generation settings and echo the stored values back.

      Admin only. Each field of ``form_data`` is assigned to the attribute of the
      same name on ``request.app.state.config``; the response is then read back
      from that config object.
      """
      config = request.app.state.config

      config.TASK_MODEL = form_data.TASK_MODEL
      config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
      config.TITLE_GENERATION_PROMPT_TEMPLATE = (
          form_data.TITLE_GENERATION_PROMPT_TEMPLATE
      )
      # The form requires this field and the response reports it, so it must be
      # stored as well; without this assignment the setting could never be saved.
      config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
          form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
      )
      config.ENABLE_AUTOCOMPLETE_GENERATION = form_data.ENABLE_AUTOCOMPLETE_GENERATION
      config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
          form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
      )
      config.TAGS_GENERATION_PROMPT_TEMPLATE = form_data.TAGS_GENERATION_PROMPT_TEMPLATE
      config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
      config.ENABLE_SEARCH_QUERY_GENERATION = form_data.ENABLE_SEARCH_QUERY_GENERATION
      config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
          form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
      )
      config.QUERY_GENERATION_PROMPT_TEMPLATE = (
          form_data.QUERY_GENERATION_PROMPT_TEMPLATE
      )
      config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
          form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
      )

      return {
          "TASK_MODEL": config.TASK_MODEL,
          "TASK_MODEL_EXTERNAL": config.TASK_MODEL_EXTERNAL,
          "TITLE_GENERATION_PROMPT_TEMPLATE": config.TITLE_GENERATION_PROMPT_TEMPLATE,
          "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
          "ENABLE_AUTOCOMPLETE_GENERATION": config.ENABLE_AUTOCOMPLETE_GENERATION,
          "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
          "TAGS_GENERATION_PROMPT_TEMPLATE": config.TAGS_GENERATION_PROMPT_TEMPLATE,
          "ENABLE_TAGS_GENERATION": config.ENABLE_TAGS_GENERATION,
          "ENABLE_SEARCH_QUERY_GENERATION": config.ENABLE_SEARCH_QUERY_GENERATION,
          "ENABLE_RETRIEVAL_QUERY_GENERATION": config.ENABLE_RETRIEVAL_QUERY_GENERATION,
          "QUERY_GENERATION_PROMPT_TEMPLATE": config.QUERY_GENERATION_PROMPT_TEMPLATE,
          "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
      }
  @router.post("/title/completions")
  async def generate_title(
      request: Request, form_data: dict, user=Depends(get_verified_user)
  ):
      """Ask the task model for a title for the conversation in ``form_data``.

      Reads ``form_data["model"]`` and ``form_data["messages"]``; an optional
      ``form_data["chat_id"]`` is passed through in the request metadata.

      Raises:
          HTTPException: 404 when ``form_data["model"]`` is not a key of
              ``request.app.state.MODELS``.

      Returns the result of ``generate_chat_completion``, or a 400
      ``JSONResponse`` with a generic message if that call raises.
      """
      models = request.app.state.MODELS
      model_id = form_data["model"]
      if model_id not in models:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
              detail="Model not found",
          )

      # Swap in the configured task model (internal or external) when one applies.
      task_model_id = get_task_model_id(
          model_id,
          request.app.state.config.TASK_MODEL,
          request.app.state.config.TASK_MODEL_EXTERNAL,
          models,
      )

      log.debug(
          f"generating chat title using model {task_model_id} for user {user.email} "
      )

      # Fall back to the default when the configured template is unset, empty,
      # or whitespace-only (same rule as the query/autocomplete endpoints).
      template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
      if not (template or "").strip():
          template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

      content = title_generation_template(
          template,
          form_data["messages"],
          {
              "name": user.name,
              "location": user.info.get("location") if user.info else None,
          },
      )

      payload = {
          "model": task_model_id,
          "messages": [{"role": "user", "content": content}],
          "stream": False,
          # Ollama models get ``max_tokens``; all others get
          # ``max_completion_tokens``. Both request a 50-token limit.
          **(
              {"max_tokens": 50}
              if models[task_model_id]["owned_by"] == "ollama"
              else {
                  "max_completion_tokens": 50,
              }
          ),
          "metadata": {
              "task": str(TASKS.TITLE_GENERATION),
              "task_body": form_data,
              "chat_id": form_data.get("chat_id", None),
          },
      }

      try:
          return await generate_chat_completion(request, form_data=payload, user=user)
      except Exception:
          # Keep the traceback in the server log; only a generic message reaches the client.
          log.exception("Exception occurred")
          return JSONResponse(
              status_code=status.HTTP_400_BAD_REQUEST,
              content={"detail": "An internal error has occurred."},
          )
  @router.post("/tags/completions")
  async def generate_chat_tags(
      request: Request, form_data: dict, user=Depends(get_verified_user)
  ):
      """Ask the task model for tags describing the conversation in ``form_data``.

      When tags generation is disabled in config, returns a 200 ``JSONResponse``
      carrying a "disabled" detail and makes no model call. Otherwise reads
      ``form_data["model"]`` and ``form_data["messages"]``; an optional
      ``form_data["chat_id"]`` is passed through in the request metadata.

      Raises:
          HTTPException: 404 when ``form_data["model"]`` is not a key of
              ``request.app.state.MODELS``.

      Returns the result of ``generate_chat_completion``, or a 500
      ``JSONResponse`` with a generic message if that call raises.
      """
      if not request.app.state.config.ENABLE_TAGS_GENERATION:
          return JSONResponse(
              status_code=status.HTTP_200_OK,
              content={"detail": "Tags generation is disabled"},
          )

      models = request.app.state.MODELS
      model_id = form_data["model"]
      if model_id not in models:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
              detail="Model not found",
          )

      # Swap in the configured task model (internal or external) when one applies.
      task_model_id = get_task_model_id(
          model_id,
          request.app.state.config.TASK_MODEL,
          request.app.state.config.TASK_MODEL_EXTERNAL,
          models,
      )

      log.debug(
          f"generating chat tags using model {task_model_id} for user {user.email} "
      )

      # Fall back to the default when the configured template is unset, empty,
      # or whitespace-only (same rule as the query/autocomplete endpoints).
      template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
      if not (template or "").strip():
          template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE

      content = tags_generation_template(
          template, form_data["messages"], {"name": user.name}
      )

      payload = {
          "model": task_model_id,
          "messages": [{"role": "user", "content": content}],
          "stream": False,
          "metadata": {
              "task": str(TASKS.TAGS_GENERATION),
              "task_body": form_data,
              "chat_id": form_data.get("chat_id", None),
          },
      }

      try:
          return await generate_chat_completion(request, form_data=payload, user=user)
      except Exception as e:
          # log.exception keeps the traceback that log.error would drop.
          log.exception(f"Error generating chat completion: {e}")
          return JSONResponse(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              content={"detail": "An internal error has occurred."},
          )
  @router.post("/image_prompt/completions")
  async def generate_image_prompt(
      request: Request, form_data: dict, user=Depends(get_verified_user)
  ):
      """Ask the task model for an image-generation prompt from ``form_data``.

      Reads ``form_data["model"]`` and ``form_data["messages"]``; an optional
      ``form_data["chat_id"]`` is passed through in the request metadata.

      Raises:
          HTTPException: 404 when ``form_data["model"]`` is not a key of
              ``request.app.state.MODELS``.

      Returns the result of ``generate_chat_completion``, or a 400
      ``JSONResponse`` with a generic message if that call raises.
      """
      models = request.app.state.MODELS
      model_id = form_data["model"]
      if model_id not in models:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
              detail="Model not found",
          )

      # Swap in the configured task model (internal or external) when one applies.
      task_model_id = get_task_model_id(
          model_id,
          request.app.state.config.TASK_MODEL,
          request.app.state.config.TASK_MODEL_EXTERNAL,
          models,
      )

      log.debug(
          f"generating image prompt using model {task_model_id} for user {user.email} "
      )

      # Fall back to the default when the configured template is unset, empty,
      # or whitespace-only (same rule as the query/autocomplete endpoints).
      template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
      if not (template or "").strip():
          template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE

      content = image_prompt_generation_template(
          template,
          form_data["messages"],
          user={
              "name": user.name,
          },
      )

      payload = {
          "model": task_model_id,
          "messages": [{"role": "user", "content": content}],
          "stream": False,
          "metadata": {
              "task": str(TASKS.IMAGE_PROMPT_GENERATION),
              "task_body": form_data,
              "chat_id": form_data.get("chat_id", None),
          },
      }

      try:
          return await generate_chat_completion(request, form_data=payload, user=user)
      except Exception:
          # Keep the traceback in the server log; only a generic message reaches the client.
          log.exception("Exception occurred")
          return JSONResponse(
              status_code=status.HTTP_400_BAD_REQUEST,
              content={"detail": "An internal error has occurred."},
          )
  @router.post("/queries/completions")
  async def generate_queries(
      request: Request, form_data: dict, user=Depends(get_verified_user)
  ):
      """Generate search or retrieval queries for the conversation in ``form_data``.

      ``form_data["type"]`` selects which feature flag gates the request:
      ``"web_search"`` checks ENABLE_SEARCH_QUERY_GENERATION and ``"retrieval"``
      checks ENABLE_RETRIEVAL_QUERY_GENERATION. Any other or missing type is
      not gated by either flag. Also reads ``form_data["model"]`` and
      ``form_data["messages"]``; an optional ``form_data["chat_id"]`` is passed
      through in the request metadata.

      Raises:
          HTTPException: 400 when the flag for the requested type is off;
              404 when ``form_data["model"]`` is not a key of
              ``request.app.state.MODELS``.

      Returns the result of ``generate_chat_completion``, or a 400
      ``JSONResponse`` with a generic message if that call raises.
      """
      query_type = form_data.get("type")
      if query_type == "web_search":
          if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
              raise HTTPException(
                  status_code=status.HTTP_400_BAD_REQUEST,
                  detail="Search query generation is disabled",
              )
      elif query_type == "retrieval":
          if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
              raise HTTPException(
                  status_code=status.HTTP_400_BAD_REQUEST,
                  detail="Query generation is disabled",
              )

      models = request.app.state.MODELS
      model_id = form_data["model"]
      if model_id not in models:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
              detail="Model not found",
          )

      # Swap in the configured task model (internal or external) when one applies.
      task_model_id = get_task_model_id(
          model_id,
          request.app.state.config.TASK_MODEL,
          request.app.state.config.TASK_MODEL_EXTERNAL,
          models,
      )

      log.debug(
          f"generating {query_type} queries using model {task_model_id} for user {user.email}"
      )

      # Fall back to the default when the configured template is empty or whitespace-only.
      if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
          template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
      else:
          template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

      content = query_generation_template(
          template, form_data["messages"], {"name": user.name}
      )

      payload = {
          "model": task_model_id,
          "messages": [{"role": "user", "content": content}],
          "stream": False,
          "metadata": {
              "task": str(TASKS.QUERY_GENERATION),
              "task_body": form_data,
              "chat_id": form_data.get("chat_id", None),
          },
      }

      try:
          return await generate_chat_completion(request, form_data=payload, user=user)
      except Exception:
          # Log the traceback server-side and return a generic message, as the
          # title/image-prompt endpoints do. Raw exception text may carry
          # internal details that should not reach the client.
          log.exception("Error generating queries")
          return JSONResponse(
              status_code=status.HTTP_400_BAD_REQUEST,
              content={"detail": "An internal error has occurred."},
          )
  @router.post("/auto/completions")
  async def generate_autocompletion(
      request: Request, form_data: dict, user=Depends(get_verified_user)
  ):
      """Generate an autocompletion for ``form_data["prompt"]``.

      Optional inputs: ``form_data["type"]``, ``form_data["messages"]`` and
      ``form_data["chat_id"]``; the last is passed through in the request metadata.

      Raises:
          HTTPException: 400 when autocompletion is disabled or the prompt is
              longer than AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH (when that
              limit is positive); 404 when ``form_data["model"]`` is not a key
              of ``request.app.state.MODELS``.

      Returns the result of ``generate_chat_completion``, or a 500
      ``JSONResponse`` with a generic message if that call raises.
      """
      if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Autocompletion generation is disabled",
          )

      autocomplete_type = form_data.get("type")
      # Treat a missing or null prompt as empty so the length check below cannot
      # crash with TypeError on len(None).
      prompt = form_data.get("prompt") or ""
      messages = form_data.get("messages")

      max_length = request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
      # A non-positive limit disables the length check.
      if max_length > 0 and len(prompt) > max_length:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=f"Input prompt exceeds maximum length of {max_length}",
          )

      models = request.app.state.MODELS
      model_id = form_data["model"]
      if model_id not in models:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
              detail="Model not found",
          )

      # Swap in the configured task model (internal or external) when one applies.
      task_model_id = get_task_model_id(
          model_id,
          request.app.state.config.TASK_MODEL,
          request.app.state.config.TASK_MODEL_EXTERNAL,
          models,
      )

      log.debug(
          f"generating autocompletion using model {task_model_id} for user {user.email}"
      )

      # Fall back to the default when the configured template is empty or whitespace-only.
      if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
          template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
      else:
          template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

      content = autocomplete_generation_template(
          template, prompt, messages, autocomplete_type, {"name": user.name}
      )

      payload = {
          "model": task_model_id,
          "messages": [{"role": "user", "content": content}],
          "stream": False,
          "metadata": {
              "task": str(TASKS.AUTOCOMPLETE_GENERATION),
              "task_body": form_data,
              "chat_id": form_data.get("chat_id", None),
          },
      }

      try:
          return await generate_chat_completion(request, form_data=payload, user=user)
      except Exception as e:
          # log.exception keeps the traceback that log.error would drop.
          log.exception(f"Error generating chat completion: {e}")
          return JSONResponse(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              content={"detail": "An internal error has occurred."},
          )
  @router.post("/emoji/completions")
  async def generate_emoji(
      request: Request, form_data: dict, user=Depends(get_verified_user)
  ):
      """Ask the task model for an emoji that fits ``form_data["prompt"]``.

      Always uses DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE (there is no
      configurable override). An optional ``form_data["chat_id"]`` is passed
      through in the request metadata.

      Raises:
          HTTPException: 404 when ``form_data["model"]`` is not a key of
              ``request.app.state.MODELS``.

      Returns the result of ``generate_chat_completion``, or a 400
      ``JSONResponse`` with a generic message if that call raises.
      """
      models = request.app.state.MODELS
      model_id = form_data["model"]
      if model_id not in models:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
              detail="Model not found",
          )

      # Swap in the configured task model (internal or external) when one applies.
      task_model_id = get_task_model_id(
          model_id,
          request.app.state.config.TASK_MODEL,
          request.app.state.config.TASK_MODEL_EXTERNAL,
          models,
      )

      log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

      template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

      content = emoji_generation_template(
          template,
          form_data["prompt"],
          {
              "name": user.name,
              "location": user.info.get("location") if user.info else None,
          },
      )

      payload = {
          "model": task_model_id,
          "messages": [{"role": "user", "content": content}],
          "stream": False,
          # Ollama models get ``max_tokens``; all others get
          # ``max_completion_tokens``. Both request a 4-token limit.
          **(
              {"max_tokens": 4}
              if models[task_model_id]["owned_by"] == "ollama"
              else {
                  "max_completion_tokens": 4,
              }
          ),
          "metadata": {
              "task": str(TASKS.EMOJI_GENERATION),
              "task_body": form_data,
              # chat_id belongs in metadata, as in every other task endpoint.
              # As a top-level key it would presumably be sent to the model
              # request as an unknown parameter.
              "chat_id": form_data.get("chat_id", None),
          },
      }

      try:
          return await generate_chat_completion(request, form_data=payload, user=user)
      except Exception:
          # Log the traceback server-side and return a generic message, as the
          # title/image-prompt endpoints do.
          log.exception("Error generating emoji")
          return JSONResponse(
              status_code=status.HTTP_400_BAD_REQUEST,
              content={"detail": "An internal error has occurred."},
          )
  @router.post("/moa/completions")
  async def generate_moa_response(
      request: Request, form_data: dict, user=Depends(get_verified_user)
  ):
      """Combine several model responses into one answer (mixture of agents).

      Reads ``form_data["model"]``, ``form_data["prompt"]`` and
      ``form_data["responses"]``. Optional ``form_data["stream"]`` (default
      False) controls streaming, and optional ``form_data["chat_id"]`` is
      passed through in the request metadata. Always uses
      DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE.

      Raises:
          HTTPException: 404 when ``form_data["model"]`` is not a key of
              ``request.app.state.MODELS``.

      Returns the result of ``generate_chat_completion``, or a 400
      ``JSONResponse`` with a generic message if that call raises.
      """
      models = request.app.state.MODELS
      model_id = form_data["model"]
      if model_id not in models:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
              detail="Model not found",
          )

      # Swap in the configured task model (internal or external) when one applies.
      task_model_id = get_task_model_id(
          model_id,
          request.app.state.config.TASK_MODEL,
          request.app.state.config.TASK_MODEL_EXTERNAL,
          models,
      )

      log.debug(f"generating MOA model {task_model_id} for user {user.email} ")

      template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

      content = moa_response_generation_template(
          template,
          form_data["prompt"],
          form_data["responses"],
      )

      payload = {
          "model": task_model_id,
          "messages": [{"role": "user", "content": content}],
          "stream": form_data.get("stream", False),
          "metadata": {
              "chat_id": form_data.get("chat_id", None),
              "task": str(TASKS.MOA_RESPONSE_GENERATION),
              "task_body": form_data,
          },
      }

      try:
          return await generate_chat_completion(request, form_data=payload, user=user)
      except Exception:
          # Log the traceback server-side and return a generic message, as the
          # title/image-prompt endpoints do.
          log.exception("Error generating MOA response")
          return JSONResponse(
              status_code=status.HTTP_400_BAD_REQUEST,
              content={"detail": "An internal error has occurred."},
          )