# tasks.py

from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel
from typing import Optional
import logging
import re

from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
    title_generation_template,
    query_generation_template,
    image_prompt_generation_template,
    autocomplete_generation_template,
    tags_generation_template,
    emoji_generation_template,
    moa_response_generation_template,
    get_task_model_id,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS

from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.filter import (
    get_sorted_filter_ids,
    process_filter_functions,
)

from open_webui.config import (
    DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
)
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

router = APIRouter()


##################################
#
# Task Endpoints
#
##################################

@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    return {
        "TASK_MODEL": request.app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
        "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
        "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
        "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }

class TaskConfigForm(BaseModel):
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    ENABLE_TITLE_GENERATION: bool
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_TAGS_GENERATION: bool
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str

@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
    request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
    request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
    request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
        form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
        form_data.ENABLE_AUTOCOMPLETE_GENERATION
    )
    request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
        form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    )
    request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
        form_data.TAGS_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
    request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
        form_data.ENABLE_SEARCH_QUERY_GENERATION
    )
    request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
        form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
    )
    request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
        form_data.QUERY_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
        form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    )

    return {
        "TASK_MODEL": request.app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
        "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
        "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
        "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }

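
# Usage sketch (illustrative only, never called from this module): how an admin
# client might round-trip the task config. Assumptions not defined here: the
# router is mounted under a base URL such as /api/v1/tasks, auth is a bearer
# token, and the `requests` package is available on the client side. Note that
# TaskConfigForm declares every field without a default, so /config/update
# expects the full config object, not a partial patch.
def _example_update_task_config(base_url: str, token: str) -> dict:
    import requests  # client-side dependency, used only in this example

    headers = {"Authorization": f"Bearer {token}"}
    # GET /config returns the same field names TaskConfigForm expects,
    # so the response can be edited and posted back as-is.
    config = requests.get(f"{base_url}/config", headers=headers).json()
    config["ENABLE_TAGS_GENERATION"] = False
    return requests.post(
        f"{base_url}/config/update", json=config, headers=headers
    ).json()
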
@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_TITLE_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Title generation is disabled"},
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email}"
    )

    if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    messages = form_data["messages"]

    # Remove reasoning details from the messages
    for message in messages:
        message["content"] = re.sub(
            r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
            "",
            message["content"],
            flags=re.S,
        ).strip()

    content = title_generation_template(
        template,
        messages,
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **(
            {"max_tokens": 1000}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": 1000,
            }
        ),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )

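
# Illustrative request body for POST /title/completions (the same shape also
# works for /tags/completions and /image_prompt/completions). The endpoint
# takes an untyped dict, so the expected keys are spelled out here; the model
# id, chat_id, and message contents are made-up placeholders.
_EXAMPLE_TITLE_COMPLETIONS_BODY = {
    "model": "llama3.1:8b",  # hypothetical id; must exist in the loaded models
    "chat_id": "1f2e3d4c",  # optional, only propagated into task metadata
    "messages": [
        {"role": "user", "content": "Explain how TLS certificate pinning works."},
        {"role": "assistant", "content": "Certificate pinning means the client ..."},
    ],
}
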
@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email}"
    )

    if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE

    content = tags_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TAGS_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )

@router.post("/image_prompt/completions")
async def generate_image_prompt(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating image prompt using model {task_model_id} for user {user.email}"
    )

    if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE

    content = image_prompt_generation_template(
        template,
        form_data["messages"],
        user={
            "name": user.name,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.IMAGE_PROMPT_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )

@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    type = form_data.get("type")
    if type == "web_search":
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif type == "retrieval":
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {type} queries using model {task_model_id} for user {user.email}"
    )

    if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

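
# Illustrative request body for POST /queries/completions. "type" selects which
# feature flag is checked: "web_search" or "retrieval"; any other value skips
# the flag check entirely. Model id and message content are made-up placeholders.
_EXAMPLE_QUERIES_COMPLETIONS_BODY = {
    "model": "llama3.1:8b",  # hypothetical id
    "type": "web_search",  # or "retrieval"
    "messages": [
        {"role": "user", "content": "What changed in the latest FastAPI release?"},
    ],
}
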
@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
        if (
            len(prompt)
            > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
        ):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
            )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(
        template, prompt, messages, type, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )

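
# Illustrative request body for POST /auto/completions. "prompt" is the text
# typed so far and is the field checked against
# AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH; "type" and "messages" are passed
# through to the autocomplete template unchanged. The model id and the "type"
# label below are made-up placeholders.
_EXAMPLE_AUTO_COMPLETIONS_BODY = {
    "model": "llama3.1:8b",  # hypothetical id
    "type": "text",  # hypothetical type label
    "prompt": "Write a short paragraph about unit test",
    "messages": [],
}
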
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email}")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **(
            {"max_tokens": 4}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": 4,
            }
        ),
        "chat_id": form_data.get("chat_id", None),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

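
# Illustrative request body for POST /emoji/completions. The endpoint only
# reads "model", "prompt", and optionally "chat_id"; the completion is capped
# at a few tokens so the task model returns a single emoji. Values are made up.
_EXAMPLE_EMOJI_COMPLETIONS_BODY = {
    "model": "llama3.1:8b",  # hypothetical id
    "prompt": "I finally fixed the flaky CI pipeline!",
}
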
@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating MOA response using model {task_model_id} for user {user.email}"
    )

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "chat_id": form_data.get("chat_id", None),
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

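
# Illustrative request body for POST /moa/completions (mixture-of-agents merge).
# "prompt" is the original user prompt and "responses" are the candidate answers
# from other models that the task model is asked to aggregate. Values are made up.
_EXAMPLE_MOA_COMPLETIONS_BODY = {
    "model": "llama3.1:8b",  # hypothetical id
    "stream": False,
    "prompt": "Summarize the trade-offs between SQLite and PostgreSQL.",
    "responses": [
        "SQLite is embedded and zero-config ...",
        "PostgreSQL offers stronger concurrency ...",
    ],
}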