tasks.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
  2. from fastapi.responses import JSONResponse, RedirectResponse
  3. from pydantic import BaseModel
  4. from typing import Optional
  5. import logging
  6. from open_webui.utils.chat import generate_chat_completion
  7. from open_webui.utils.task import (
  8. title_generation_template,
  9. query_generation_template,
  10. autocomplete_generation_template,
  11. tags_generation_template,
  12. emoji_generation_template,
  13. moa_response_generation_template,
  14. )
  15. from open_webui.utils.auth import get_admin_user, get_verified_user
  16. from open_webui.constants import TASKS
  17. from open_webui.routers.pipelines import process_pipeline_inlet_filter
  18. from open_webui.utils.task import get_task_model_id
  19. from open_webui.config import (
  20. DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
  21. DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
  22. DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
  23. DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
  24. DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
  25. DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
  26. )
  27. from open_webui.env import SRC_LOG_LEVELS
  28. log = logging.getLogger(__name__)
  29. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  30. router = APIRouter()
  31. ##################################
  32. #
  33. # Task Endpoints
  34. #
  35. ##################################
  36. @router.get("/config")
  37. async def get_task_config(request: Request, user=Depends(get_verified_user)):
  38. return {
  39. "TASK_MODEL": request.app.state.config.TASK_MODEL,
  40. "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
  41. "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
  42. "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
  43. "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
  44. "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
  45. "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
  46. "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
  47. "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
  48. "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
  49. "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
  50. }
class TaskConfigForm(BaseModel):
    """Request body for ``POST /config/update``.

    Each field is copied verbatim onto ``request.app.state.config`` by
    ``update_task_config``.
    """

    # NOTE(review): generate_autocompletion reads
    # config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, but this form has no
    # field for it, so it cannot be changed through this endpoint — confirm
    # whether that is intentional.

    # Task-model overrides handed to get_task_model_id; None presumably means
    # "no override" — TODO confirm against get_task_model_id.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    # Prompt template for /title/completions; an empty string selects the
    # built-in default.
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    # Gate for /auto/completions (disabled -> 400).
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    # Max prompt length accepted by /auto/completions; values <= 0 skip the check.
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    # Prompt template for /tags/completions; empty selects the default.
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    # Gate for /tags/completions (disabled -> 200 with an explanatory detail).
    ENABLE_TAGS_GENERATION: bool
    # Gates for /queries/completions with type "web_search" / "retrieval".
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    # Prompt template for /queries/completions; blank selects the default.
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    # Stored and returned here; not read by any endpoint in this module.
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
  63. @router.post("/config/update")
  64. async def update_task_config(
  65. request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
  66. ):
  67. request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
  68. request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
  69. request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
  70. form_data.TITLE_GENERATION_PROMPT_TEMPLATE
  71. )
  72. request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
  73. form_data.ENABLE_AUTOCOMPLETE_GENERATION
  74. )
  75. request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
  76. form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
  77. )
  78. request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
  79. form_data.TAGS_GENERATION_PROMPT_TEMPLATE
  80. )
  81. request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
  82. request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
  83. form_data.ENABLE_SEARCH_QUERY_GENERATION
  84. )
  85. request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
  86. form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
  87. )
  88. request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
  89. form_data.QUERY_GENERATION_PROMPT_TEMPLATE
  90. )
  91. request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
  92. form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
  93. )
  94. return {
  95. "TASK_MODEL": request.app.state.config.TASK_MODEL,
  96. "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
  97. "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
  98. "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
  99. "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
  100. "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
  101. "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
  102. "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
  103. "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
  104. "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
  105. "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
  106. }
  107. @router.post("/title/completions")
  108. async def generate_title(
  109. request: Request, form_data: dict, user=Depends(get_verified_user)
  110. ):
  111. models = request.app.state.MODELS
  112. model_id = form_data["model"]
  113. if model_id not in models:
  114. raise HTTPException(
  115. status_code=status.HTTP_404_NOT_FOUND,
  116. detail="Model not found",
  117. )
  118. # Check if the user has a custom task model
  119. # If the user has a custom task model, use that model
  120. task_model_id = get_task_model_id(
  121. model_id,
  122. request.app.state.config.TASK_MODEL,
  123. request.app.state.config.TASK_MODEL_EXTERNAL,
  124. models,
  125. )
  126. log.debug(
  127. f"generating chat title using model {task_model_id} for user {user.email} "
  128. )
  129. if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
  130. template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
  131. else:
  132. template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
  133. content = title_generation_template(
  134. template,
  135. form_data["messages"],
  136. {
  137. "name": user.name,
  138. "location": user.info.get("location") if user.info else None,
  139. },
  140. )
  141. payload = {
  142. "model": task_model_id,
  143. "messages": [{"role": "user", "content": content}],
  144. "stream": False,
  145. **(
  146. {"max_tokens": 50}
  147. if models[task_model_id]["owned_by"] == "ollama"
  148. else {
  149. "max_completion_tokens": 50,
  150. }
  151. ),
  152. "metadata": {
  153. "task": str(TASKS.TITLE_GENERATION),
  154. "task_body": form_data,
  155. "chat_id": form_data.get("chat_id", None),
  156. },
  157. }
  158. try:
  159. return await generate_chat_completion(request, form_data=payload, user=user)
  160. except Exception as e:
  161. log.error("Exception occurred", exc_info=True)
  162. return JSONResponse(
  163. status_code=status.HTTP_400_BAD_REQUEST,
  164. content={"detail": "An internal error has occurred."},
  165. )
  166. @router.post("/tags/completions")
  167. async def generate_chat_tags(
  168. request: Request, form_data: dict, user=Depends(get_verified_user)
  169. ):
  170. if not request.app.state.config.ENABLE_TAGS_GENERATION:
  171. return JSONResponse(
  172. status_code=status.HTTP_200_OK,
  173. content={"detail": "Tags generation is disabled"},
  174. )
  175. models = request.app.state.MODELS
  176. model_id = form_data["model"]
  177. if model_id not in models:
  178. raise HTTPException(
  179. status_code=status.HTTP_404_NOT_FOUND,
  180. detail="Model not found",
  181. )
  182. # Check if the user has a custom task model
  183. # If the user has a custom task model, use that model
  184. task_model_id = get_task_model_id(
  185. model_id,
  186. request.app.state.config.TASK_MODEL,
  187. request.app.state.config.TASK_MODEL_EXTERNAL,
  188. models,
  189. )
  190. log.debug(
  191. f"generating chat tags using model {task_model_id} for user {user.email} "
  192. )
  193. if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
  194. template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
  195. else:
  196. template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE
  197. content = tags_generation_template(
  198. template, form_data["messages"], {"name": user.name}
  199. )
  200. payload = {
  201. "model": task_model_id,
  202. "messages": [{"role": "user", "content": content}],
  203. "stream": False,
  204. "metadata": {
  205. "task": str(TASKS.TAGS_GENERATION),
  206. "task_body": form_data,
  207. "chat_id": form_data.get("chat_id", None),
  208. },
  209. }
  210. try:
  211. return await generate_chat_completion(request, form_data=payload, user=user)
  212. except Exception as e:
  213. log.error(f"Error generating chat completion: {e}")
  214. return JSONResponse(
  215. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  216. content={"detail": "An internal error has occurred."},
  217. )
  218. @router.post("/queries/completions")
  219. async def generate_queries(
  220. request: Request, form_data: dict, user=Depends(get_verified_user)
  221. ):
  222. type = form_data.get("type")
  223. if type == "web_search":
  224. if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
  225. raise HTTPException(
  226. status_code=status.HTTP_400_BAD_REQUEST,
  227. detail=f"Search query generation is disabled",
  228. )
  229. elif type == "retrieval":
  230. if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
  231. raise HTTPException(
  232. status_code=status.HTTP_400_BAD_REQUEST,
  233. detail=f"Query generation is disabled",
  234. )
  235. models = request.app.state.MODELS
  236. model_id = form_data["model"]
  237. if model_id not in models:
  238. raise HTTPException(
  239. status_code=status.HTTP_404_NOT_FOUND,
  240. detail="Model not found",
  241. )
  242. # Check if the user has a custom task model
  243. # If the user has a custom task model, use that model
  244. task_model_id = get_task_model_id(
  245. model_id,
  246. request.app.state.config.TASK_MODEL,
  247. request.app.state.config.TASK_MODEL_EXTERNAL,
  248. models,
  249. )
  250. log.debug(
  251. f"generating {type} queries using model {task_model_id} for user {user.email}"
  252. )
  253. if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
  254. template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
  255. else:
  256. template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
  257. content = query_generation_template(
  258. template, form_data["messages"], {"name": user.name}
  259. )
  260. payload = {
  261. "model": task_model_id,
  262. "messages": [{"role": "user", "content": content}],
  263. "stream": False,
  264. "metadata": {
  265. "task": str(TASKS.QUERY_GENERATION),
  266. "task_body": form_data,
  267. "chat_id": form_data.get("chat_id", None),
  268. },
  269. }
  270. try:
  271. return await generate_chat_completion(request, form_data=payload, user=user)
  272. except Exception as e:
  273. return JSONResponse(
  274. status_code=status.HTTP_400_BAD_REQUEST,
  275. content={"detail": str(e)},
  276. )
  277. @router.post("/auto/completions")
  278. async def generate_autocompletion(
  279. request: Request, form_data: dict, user=Depends(get_verified_user)
  280. ):
  281. if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
  282. raise HTTPException(
  283. status_code=status.HTTP_400_BAD_REQUEST,
  284. detail=f"Autocompletion generation is disabled",
  285. )
  286. type = form_data.get("type")
  287. prompt = form_data.get("prompt")
  288. messages = form_data.get("messages")
  289. if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
  290. if (
  291. len(prompt)
  292. > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
  293. ):
  294. raise HTTPException(
  295. status_code=status.HTTP_400_BAD_REQUEST,
  296. detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
  297. )
  298. models = request.app.state.MODELS
  299. model_id = form_data["model"]
  300. if model_id not in models:
  301. raise HTTPException(
  302. status_code=status.HTTP_404_NOT_FOUND,
  303. detail="Model not found",
  304. )
  305. # Check if the user has a custom task model
  306. # If the user has a custom task model, use that model
  307. task_model_id = get_task_model_id(
  308. model_id,
  309. request.app.state.config.TASK_MODEL,
  310. request.app.state.config.TASK_MODEL_EXTERNAL,
  311. models,
  312. )
  313. log.debug(
  314. f"generating autocompletion using model {task_model_id} for user {user.email}"
  315. )
  316. if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
  317. template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
  318. else:
  319. template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
  320. content = autocomplete_generation_template(
  321. template, prompt, messages, type, {"name": user.name}
  322. )
  323. payload = {
  324. "model": task_model_id,
  325. "messages": [{"role": "user", "content": content}],
  326. "stream": False,
  327. "metadata": {
  328. "task": str(TASKS.AUTOCOMPLETE_GENERATION),
  329. "task_body": form_data,
  330. "chat_id": form_data.get("chat_id", None),
  331. },
  332. }
  333. try:
  334. return await generate_chat_completion(request, form_data=payload, user=user)
  335. except Exception as e:
  336. log.error(f"Error generating chat completion: {e}")
  337. return JSONResponse(
  338. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  339. content={"detail": "An internal error has occurred."},
  340. )
  341. @router.post("/emoji/completions")
  342. async def generate_emoji(
  343. request: Request, form_data: dict, user=Depends(get_verified_user)
  344. ):
  345. models = request.app.state.MODELS
  346. model_id = form_data["model"]
  347. if model_id not in models:
  348. raise HTTPException(
  349. status_code=status.HTTP_404_NOT_FOUND,
  350. detail="Model not found",
  351. )
  352. # Check if the user has a custom task model
  353. # If the user has a custom task model, use that model
  354. task_model_id = get_task_model_id(
  355. model_id,
  356. request.app.state.config.TASK_MODEL,
  357. request.app.state.config.TASK_MODEL_EXTERNAL,
  358. models,
  359. )
  360. log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")
  361. template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE
  362. content = emoji_generation_template(
  363. template,
  364. form_data["prompt"],
  365. {
  366. "name": user.name,
  367. "location": user.info.get("location") if user.info else None,
  368. },
  369. )
  370. payload = {
  371. "model": task_model_id,
  372. "messages": [{"role": "user", "content": content}],
  373. "stream": False,
  374. **(
  375. {"max_tokens": 4}
  376. if models[task_model_id]["owned_by"] == "ollama"
  377. else {
  378. "max_completion_tokens": 4,
  379. }
  380. ),
  381. "chat_id": form_data.get("chat_id", None),
  382. "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
  383. }
  384. try:
  385. return await generate_chat_completion(request, form_data=payload, user=user)
  386. except Exception as e:
  387. return JSONResponse(
  388. status_code=status.HTTP_400_BAD_REQUEST,
  389. content={"detail": str(e)},
  390. )
  391. @router.post("/moa/completions")
  392. async def generate_moa_response(
  393. request: Request, form_data: dict, user=Depends(get_verified_user)
  394. ):
  395. models = request.app.state.MODELS
  396. model_id = form_data["model"]
  397. if model_id not in models:
  398. raise HTTPException(
  399. status_code=status.HTTP_404_NOT_FOUND,
  400. detail="Model not found",
  401. )
  402. # Check if the user has a custom task model
  403. # If the user has a custom task model, use that model
  404. task_model_id = get_task_model_id(
  405. model_id,
  406. request.app.state.config.TASK_MODEL,
  407. request.app.state.config.TASK_MODEL_EXTERNAL,
  408. models,
  409. )
  410. log.debug(f"generating MOA model {task_model_id} for user {user.email} ")
  411. template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE
  412. content = moa_response_generation_template(
  413. template,
  414. form_data["prompt"],
  415. form_data["responses"],
  416. )
  417. payload = {
  418. "model": task_model_id,
  419. "messages": [{"role": "user", "content": content}],
  420. "stream": form_data.get("stream", False),
  421. "metadata": {
  422. "chat_id": form_data.get("chat_id", None),
  423. "task": str(TASKS.MOA_RESPONSE_GENERATION),
  424. "task_body": form_data,
  425. },
  426. }
  427. try:
  428. return await generate_chat_completion(request, form_data=payload, user=user)
  429. except Exception as e:
  430. return JSONResponse(
  431. status_code=status.HTTP_400_BAD_REQUEST,
  432. content={"detail": str(e)},
  433. )