channels.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import json
  2. import logging
  3. from typing import Optional
  4. from fastapi import APIRouter, Depends, HTTPException, Request, status
  5. from pydantic import BaseModel
  6. from open_webui.socket.main import sio
  7. from open_webui.models.users import Users, UserNameResponse
  8. from open_webui.models.channels import Channels, ChannelModel, ChannelForm
  9. from open_webui.models.messages import Messages, MessageModel, MessageForm
  10. from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
  11. from open_webui.constants import ERROR_MESSAGES
  12. from open_webui.env import SRC_LOG_LEVELS
  13. from open_webui.utils.auth import get_admin_user, get_verified_user
  14. from open_webui.utils.access_control import has_access
  # Router that the channel endpoints below attach themselves to.
  router = APIRouter()

  # Module-level logger; its level comes from SRC_LOG_LEVELS["MODELS"].
  # NOTE(review): this reuses the "MODELS" key - confirm a channel-specific
  # key was not intended.
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["MODELS"])
  18. ############################
  19. # GetChannelList
  20. ############################
  @router.get("/", response_model=list[ChannelModel])
  async def get_channels(user=Depends(get_verified_user)):
      """List the channels visible to the calling user.

      Admins receive the result of ``Channels.get_channels()`` (every channel);
      any other user receives ``Channels.get_channels_by_user_id(user.id)``.
      """
      is_admin = user.role == "admin"
      return (
          Channels.get_channels()
          if is_admin
          else Channels.get_channels_by_user_id(user.id)
      )
  27. ############################
  28. # CreateNewChannel
  29. ############################
  @router.post("/create", response_model=Optional[ChannelModel])
  async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)):
      """Create a new channel from ``form_data`` (admin only).

      The caller's id is passed to ``Channels.insert_new_channel``. Any exception
      raised while inserting or serializing the result is logged and reported
      as HTTP 400. That includes the AttributeError raised if the insert yields
      None, which is presumably how it signals failure (TODO confirm).
      """
      try:
          created = Channels.insert_new_channel(form_data, user.id)
          payload = created.model_dump()
          return ChannelModel(**payload)
      except Exception as err:
          log.exception(err)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(),
          )
  40. ############################
  41. # GetChannelById
  42. ############################
  @router.get("/{id}", response_model=Optional[ChannelModel])
  async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
      """Return a single channel if the caller may read it.

      Admins skip the access-control check. ``get_channels`` already lists every
      channel for them, and the update/delete endpoints accept any admin.
      Other users need "read" access according to the channel's
      ``access_control``.

      Args:
          id: Channel id taken from the URL path.
          user: The authenticated caller.

      Returns:
          The channel as a ``ChannelModel``.

      Raises:
          HTTPException: 404 if no channel has this id; 403 if a non-admin caller
              lacks read access.
      """
      channel = Channels.get_channel_by_id(id)
      if not channel:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
          )

      # Without the admin bypass, an admin could see a channel in the listing
      # (and update/delete it) yet be refused when opening it.
      if user.role != "admin" and not has_access(
          user.id, type="read", access_control=channel.access_control
      ):
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
          )

      return ChannelModel(**channel.model_dump())
  55. ############################
  56. # UpdateChannelById
  57. ############################
  @router.post("/{id}/update", response_model=Optional[ChannelModel])
  async def update_channel_by_id(
      id: str, form_data: ChannelForm, user=Depends(get_admin_user)
  ):
      """Apply ``form_data`` to an existing channel (admin only).

      Returns:
          The updated channel as a ``ChannelModel``.

      Raises:
          HTTPException: 404 when no channel has this id; 400 when the update or
              the serialization of its result raises.
      """
      if not Channels.get_channel_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
          )

      try:
          updated = Channels.update_channel_by_id(id, form_data)
          return ChannelModel(**updated.model_dump())
      except Exception as err:
          log.exception(err)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
          )
  75. ############################
  76. # DeleteChannelById
  77. ############################
  @router.delete("/{id}/delete", response_model=bool)
  async def delete_channel_by_id(id: str, user=Depends(get_admin_user)):
      """Delete a channel (admin only) and answer ``True``.

      Raises:
          HTTPException: 404 when no channel has this id; 400 when the delete
              call raises.
      """
      if not Channels.get_channel_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
          )

      try:
          # NOTE(review): the return value of delete_channel_by_id is ignored;
          # only an exception is treated as failure.
          Channels.delete_channel_by_id(id)
      except Exception as err:
          log.exception(err)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
          )

      return True
  93. ############################
  94. # GetChannelMessages
  95. ############################
  class MessageUserModel(MessageModel):
      """A ``MessageModel`` extended with its author's ``UserNameResponse``."""

      # Author details; filled in per message by get_channel_messages.
      user: UserNameResponse
  @router.get("/{id}/messages", response_model=list[MessageUserModel])
  async def get_channel_messages(
      id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user)
  ):
      """Return a page of a channel's messages, each with its author attached.

      Admins skip the access-control check, consistent with ``get_channels``
      listing every channel for them. Other users need "read" access.

      Args:
          id: Channel id taken from the URL path.
          skip: Offset forwarded to ``Messages.get_messages_by_channel_id``.
          limit: Page size forwarded to ``Messages.get_messages_by_channel_id``.
          user: The authenticated caller.

      Returns:
          A list of ``MessageUserModel``. Messages whose author can no longer be
          loaded are omitted (and logged) instead of failing the whole request.

      Raises:
          HTTPException: 404 if no channel has this id; 403 if a non-admin caller
              lacks read access.
      """
      channel = Channels.get_channel_by_id(id)
      if not channel:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
          )

      if user.role != "admin" and not has_access(
          user.id, type="read", access_control=channel.access_control
      ):
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
          )

      message_list = Messages.get_messages_by_channel_id(id, skip, limit)

      # Cache author lookups so each distinct user_id is fetched once per request.
      # A separate name is used so the authenticated ``user`` is not clobbered.
      authors = {}
      messages = []
      for message in message_list:
          if message.user_id not in authors:
              authors[message.user_id] = Users.get_user_by_id(message.user_id)
          author = authors[message.user_id]

          if author is None:
              # Previously this raised AttributeError and failed the entire page.
              log.warning(
                  "Skipping message %s in channel %s: author %s not found",
                  message.id,
                  id,
                  message.user_id,
              )
              continue

          messages.append(
              MessageUserModel(
                  **{
                      **message.model_dump(),
                      "user": UserNameResponse(**author.model_dump()),
                  }
              )
          )
      return messages
  127. ############################
  128. # PostNewMessage
  129. ############################
  @router.post("/{id}/messages/post", response_model=Optional[MessageModel])
  async def post_new_message(
      id: str, form_data: MessageForm, user=Depends(get_verified_user)
  ):
      """Store a new message in a channel and broadcast it to the channel room.

      The message is inserted first, then a "channel-events" event is emitted to
      the ``channel:<id>`` room. The broadcast is best-effort: once the message
      is stored, an emit failure is logged but does not fail the request, so
      clients are not told to retry a message that was actually saved.

      Args:
          id: Channel id taken from the URL path.
          form_data: The message payload.
          user: The authenticated caller, recorded as the message author.

      Returns:
          The stored message as a ``MessageModel``.

      Raises:
          HTTPException: 404 if no channel has this id; 403 if a non-admin caller
              lacks access; 400 if the message could not be stored.
      """
      channel = Channels.get_channel_by_id(id)
      if not channel:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
          )

      # NOTE(review): posting is gated on "read" access rather than "write";
      # confirm has_access semantics before tightening this.
      if user.role != "admin" and not has_access(
          user.id, type="read", access_control=channel.access_control
      ):
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
          )

      try:
          message = Messages.insert_new_message(form_data, channel.id, user.id)
      except Exception as e:
          log.exception(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
          )

      if not message:
          # Same response the original produced implicitly via AttributeError.
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
          )

      try:
          await sio.emit(
              "channel-events",
              {
                  "channel_id": channel.id,
                  "message_id": message.id,
                  "data": {
                      "type": "message",
                      "data": {
                          **message.model_dump(),
                          "user": UserNameResponse(**user.model_dump()).model_dump(),
                      },
                  },
              },
              to=f"channel:{channel.id}",
          )
      except Exception as e:
          # The message is already persisted; report success regardless.
          log.exception(e)

      return MessageModel(**message.model_dump())