channels.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. import json
  2. import logging
  3. from typing import Optional
  4. from fastapi import APIRouter, Depends, HTTPException, Request, status
  5. from pydantic import BaseModel
  6. from open_webui.socket.main import sio
  7. from open_webui.models.users import Users, UserNameResponse
  8. from open_webui.models.channels import Channels, ChannelModel, ChannelForm
  9. from open_webui.models.messages import Messages, MessageModel, MessageForm
  10. from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
  11. from open_webui.constants import ERROR_MESSAGES
  12. from open_webui.env import SRC_LOG_LEVELS
  13. from open_webui.utils.auth import get_admin_user, get_verified_user
  14. from open_webui.utils.access_control import has_access
  15. log = logging.getLogger(__name__)
  16. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  17. router = APIRouter()
  18. ############################
  19. # GetChatList
  20. ############################
  21. @router.get("/", response_model=list[ChannelModel])
  22. async def get_channels(user=Depends(get_verified_user)):
  23. if user.role == "admin":
  24. return Channels.get_channels()
  25. else:
  26. return Channels.get_channels_by_user_id(user.id)
  27. ############################
  28. # CreateNewChannel
  29. ############################
  30. @router.post("/create", response_model=Optional[ChannelModel])
  31. async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)):
  32. try:
  33. channel = Channels.insert_new_channel(form_data, user.id)
  34. return ChannelModel(**channel.model_dump())
  35. except Exception as e:
  36. log.exception(e)
  37. raise HTTPException(
  38. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  39. )
  40. ############################
  41. # GetChannelById
  42. ############################
  43. @router.get("/{id}", response_model=Optional[ChannelModel])
  44. async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
  45. channel = Channels.get_channel_by_id(id)
  46. if not channel:
  47. raise HTTPException(
  48. status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
  49. )
  50. if not has_access(user.id, type="read", access_control=channel.access_control):
  51. raise HTTPException(
  52. status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
  53. )
  54. return ChannelModel(**channel.model_dump())
  55. ############################
  56. # UpdateChannelById
  57. ############################
  58. @router.post("/{id}/update", response_model=Optional[ChannelModel])
  59. async def update_channel_by_id(
  60. id: str, form_data: ChannelForm, user=Depends(get_admin_user)
  61. ):
  62. channel = Channels.get_channel_by_id(id)
  63. if not channel:
  64. raise HTTPException(
  65. status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
  66. )
  67. try:
  68. channel = Channels.update_channel_by_id(id, form_data)
  69. return ChannelModel(**channel.model_dump())
  70. except Exception as e:
  71. log.exception(e)
  72. raise HTTPException(
  73. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  74. )
  75. ############################
  76. # DeleteChannelById
  77. ############################
  78. @router.delete("/{id}/delete", response_model=bool)
  79. async def delete_channel_by_id(id: str, user=Depends(get_admin_user)):
  80. channel = Channels.get_channel_by_id(id)
  81. if not channel:
  82. raise HTTPException(
  83. status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
  84. )
  85. try:
  86. Channels.delete_channel_by_id(id)
  87. return True
  88. except Exception as e:
  89. log.exception(e)
  90. raise HTTPException(
  91. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  92. )
  93. ############################
  94. # GetChannelMessages
  95. ############################
  96. class MessageUserModel(MessageModel):
  97. user: UserNameResponse
  98. @router.get("/{id}/messages", response_model=list[MessageUserModel])
  99. async def get_channel_messages(
  100. id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user)
  101. ):
  102. channel = Channels.get_channel_by_id(id)
  103. if not channel:
  104. raise HTTPException(
  105. status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
  106. )
  107. if not has_access(user.id, type="read", access_control=channel.access_control):
  108. raise HTTPException(
  109. status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
  110. )
  111. message_list = Messages.get_messages_by_channel_id(id, skip, limit)
  112. users = {}
  113. messages = []
  114. for message in message_list:
  115. if message.user_id not in users:
  116. user = Users.get_user_by_id(message.user_id)
  117. users[message.user_id] = user
  118. messages.append(
  119. MessageUserModel(
  120. **{
  121. **message.model_dump(),
  122. "user": UserNameResponse(**users[message.user_id].model_dump()),
  123. }
  124. )
  125. )
  126. return messages
  127. ############################
  128. # PostNewMessage
  129. ############################
  130. @router.post("/{id}/messages/post", response_model=Optional[MessageModel])
  131. async def post_new_message(
  132. id: str, form_data: MessageForm, user=Depends(get_verified_user)
  133. ):
  134. channel = Channels.get_channel_by_id(id)
  135. if not channel:
  136. raise HTTPException(
  137. status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
  138. )
  139. if not has_access(user.id, type="read", access_control=channel.access_control):
  140. raise HTTPException(
  141. status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
  142. )
  143. try:
  144. message = Messages.insert_new_message(form_data, channel.id, user.id)
  145. if message:
  146. await sio.emit(
  147. "channel-events",
  148. {
  149. "channel_id": channel.id,
  150. "message_id": message.id,
  151. "data": {
  152. "type": "message",
  153. "data": {
  154. **message.model_dump(),
  155. "user": UserNameResponse(**user.model_dump()).model_dump(),
  156. },
  157. },
  158. },
  159. to=f"channel:{channel.id}",
  160. )
  161. return MessageModel(**message.model_dump())
  162. except Exception as e:
  163. log.exception(e)
  164. raise HTTPException(
  165. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  166. )
  167. ############################
  168. # UpdateMessageById
  169. ############################
  170. @router.post(
  171. "/{id}/messages/{message_id}/update", response_model=Optional[MessageModel]
  172. )
  173. async def update_message_by_id(
  174. id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
  175. ):
  176. channel = Channels.get_channel_by_id(id)
  177. if not channel:
  178. raise HTTPException(
  179. status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
  180. )
  181. if not has_access(user.id, type="read", access_control=channel.access_control):
  182. raise HTTPException(
  183. status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
  184. )
  185. message = Messages.get_message_by_id(message_id)
  186. if not message:
  187. raise HTTPException(
  188. status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
  189. )
  190. if message.channel_id != id:
  191. raise HTTPException(
  192. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  193. )
  194. try:
  195. message = Messages.update_message_by_id(message_id, form_data)
  196. if message:
  197. await sio.emit(
  198. "channel-events",
  199. {
  200. "channel_id": channel.id,
  201. "message_id": message.id,
  202. "data": {
  203. "type": "message:update",
  204. "data": {
  205. **message.model_dump(),
  206. "user": UserNameResponse(**user.model_dump()).model_dump(),
  207. },
  208. },
  209. },
  210. to=f"channel:{channel.id}",
  211. )
  212. return MessageModel(**message.model_dump())
  213. except Exception as e:
  214. log.exception(e)
  215. raise HTTPException(
  216. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  217. )
  218. ############################
  219. # DeleteMessageById
  220. ############################
  221. @router.delete("/{id}/messages/{message_id}/delete", response_model=bool)
  222. async def delete_message_by_id(
  223. id: str, message_id: str, user=Depends(get_verified_user)
  224. ):
  225. channel = Channels.get_channel_by_id(id)
  226. if not channel:
  227. raise HTTPException(
  228. status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
  229. )
  230. if not has_access(user.id, type="read", access_control=channel.access_control):
  231. raise HTTPException(
  232. status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
  233. )
  234. message = Messages.get_message_by_id(message_id)
  235. if not message:
  236. raise HTTPException(
  237. status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
  238. )
  239. if message.channel_id != id:
  240. raise HTTPException(
  241. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  242. )
  243. try:
  244. Messages.delete_message_by_id(message_id)
  245. await sio.emit(
  246. "channel-events",
  247. {
  248. "channel_id": channel.id,
  249. "message_id": message.id,
  250. "data": {
  251. "type": "message:delete",
  252. "data": {
  253. **message.model_dump(),
  254. "user": UserNameResponse(**user.model_dump()).model_dump(),
  255. },
  256. },
  257. },
  258. to=f"channel:{channel.id}",
  259. )
  260. return True
  261. except Exception as e:
  262. log.exception(e)
  263. raise HTTPException(
  264. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  265. )