test_chats.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. import uuid
  2. from test.util.abstract_integration_test import AbstractPostgresTest
  3. from test.util.mock_user import mock_webui_user
  4. class TestChats(AbstractPostgresTest):
  5. BASE_PATH = "/api/v1/chats"
  6. def setup_class(cls):
  7. super().setup_class()
  8. def setup_method(self):
  9. super().setup_method()
  10. from open_webui.apps.webui.models.chats import ChatForm, Chats
  11. self.chats = Chats
  12. self.chats.insert_new_chat(
  13. "2",
  14. ChatForm(
  15. **{
  16. "chat": {
  17. "name": "chat1",
  18. "description": "chat1 description",
  19. "tags": ["tag1", "tag2"],
  20. "history": {"currentId": "1", "messages": []},
  21. }
  22. }
  23. ),
  24. )
  25. def test_get_session_user_chat_list(self):
  26. with mock_webui_user(id="2"):
  27. response = self.fast_api_client.get(self.create_url("/"))
  28. assert response.status_code == 200
  29. first_chat = response.json()[0]
  30. assert first_chat["id"] is not None
  31. assert first_chat["title"] == "New Chat"
  32. assert first_chat["created_at"] is not None
  33. assert first_chat["updated_at"] is not None
  34. def test_delete_all_user_chats(self):
  35. with mock_webui_user(id="2"):
  36. response = self.fast_api_client.delete(self.create_url("/"))
  37. assert response.status_code == 200
  38. assert len(self.chats.get_chats()) == 0
  39. def test_get_user_chat_list_by_user_id(self):
  40. with mock_webui_user(id="3"):
  41. response = self.fast_api_client.get(self.create_url("/list/user/2"))
  42. assert response.status_code == 200
  43. first_chat = response.json()[0]
  44. assert first_chat["id"] is not None
  45. assert first_chat["title"] == "New Chat"
  46. assert first_chat["created_at"] is not None
  47. assert first_chat["updated_at"] is not None
  48. def test_create_new_chat(self):
  49. with mock_webui_user(id="2"):
  50. response = self.fast_api_client.post(
  51. self.create_url("/new"),
  52. json={
  53. "chat": {
  54. "name": "chat2",
  55. "description": "chat2 description",
  56. "tags": ["tag1", "tag2"],
  57. }
  58. },
  59. )
  60. assert response.status_code == 200
  61. data = response.json()
  62. assert data["archived"] is False
  63. assert data["chat"] == {
  64. "name": "chat2",
  65. "description": "chat2 description",
  66. "tags": ["tag1", "tag2"],
  67. }
  68. assert data["user_id"] == "2"
  69. assert data["id"] is not None
  70. assert data["share_id"] is None
  71. assert data["title"] == "New Chat"
  72. assert data["updated_at"] is not None
  73. assert data["created_at"] is not None
  74. assert len(self.chats.get_chats()) == 2
  75. def test_get_user_chats(self):
  76. self.test_get_session_user_chat_list()
  77. def test_get_user_archived_chats(self):
  78. self.chats.archive_all_chats_by_user_id("2")
  79. from open_webui.apps.webui.internal.db import Session
  80. Session.commit()
  81. with mock_webui_user(id="2"):
  82. response = self.fast_api_client.get(self.create_url("/all/archived"))
  83. assert response.status_code == 200
  84. first_chat = response.json()[0]
  85. assert first_chat["id"] is not None
  86. assert first_chat["title"] == "New Chat"
  87. assert first_chat["created_at"] is not None
  88. assert first_chat["updated_at"] is not None
  89. def test_get_all_user_chats_in_db(self):
  90. with mock_webui_user(id="4"):
  91. response = self.fast_api_client.get(self.create_url("/all/db"))
  92. assert response.status_code == 200
  93. assert len(response.json()) == 1
  94. def test_get_archived_session_user_chat_list(self):
  95. self.test_get_user_archived_chats()
  96. def test_archive_all_chats(self):
  97. with mock_webui_user(id="2"):
  98. response = self.fast_api_client.post(self.create_url("/archive/all"))
  99. assert response.status_code == 200
  100. assert len(self.chats.get_archived_chats_by_user_id("2")) == 1
  101. def test_get_shared_chat_by_id(self):
  102. chat_id = self.chats.get_chats()[0].id
  103. self.chats.update_chat_share_id_by_id(chat_id, chat_id)
  104. with mock_webui_user(id="2"):
  105. response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
  106. assert response.status_code == 200
  107. data = response.json()
  108. assert data["id"] == chat_id
  109. assert data["chat"] == {
  110. "name": "chat1",
  111. "description": "chat1 description",
  112. "tags": ["tag1", "tag2"],
  113. "history": {"currentId": "1", "messages": []},
  114. }
  115. assert data["id"] == chat_id
  116. assert data["share_id"] == chat_id
  117. assert data["title"] == "New Chat"
  118. def test_get_chat_by_id(self):
  119. chat_id = self.chats.get_chats()[0].id
  120. with mock_webui_user(id="2"):
  121. response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
  122. assert response.status_code == 200
  123. data = response.json()
  124. assert data["id"] == chat_id
  125. assert data["chat"] == {
  126. "name": "chat1",
  127. "description": "chat1 description",
  128. "tags": ["tag1", "tag2"],
  129. "history": {"currentId": "1", "messages": []},
  130. }
  131. assert data["share_id"] is None
  132. assert data["title"] == "New Chat"
  133. assert data["user_id"] == "2"
  134. def test_update_chat_by_id(self):
  135. chat_id = self.chats.get_chats()[0].id
  136. with mock_webui_user(id="2"):
  137. response = self.fast_api_client.post(
  138. self.create_url(f"/{chat_id}"),
  139. json={
  140. "chat": {
  141. "name": "chat2",
  142. "description": "chat2 description",
  143. "tags": ["tag2", "tag4"],
  144. "title": "Just another title",
  145. }
  146. },
  147. )
  148. assert response.status_code == 200
  149. data = response.json()
  150. assert data["id"] == chat_id
  151. assert data["chat"] == {
  152. "name": "chat2",
  153. "title": "Just another title",
  154. "description": "chat2 description",
  155. "tags": ["tag2", "tag4"],
  156. "history": {"currentId": "1", "messages": []},
  157. }
  158. assert data["share_id"] is None
  159. assert data["title"] == "Just another title"
  160. assert data["user_id"] == "2"
  161. def test_delete_chat_by_id(self):
  162. chat_id = self.chats.get_chats()[0].id
  163. with mock_webui_user(id="2"):
  164. response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
  165. assert response.status_code == 200
  166. assert response.json() is True
  167. def test_clone_chat_by_id(self):
  168. chat_id = self.chats.get_chats()[0].id
  169. with mock_webui_user(id="2"):
  170. response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
  171. assert response.status_code == 200
  172. data = response.json()
  173. assert data["id"] != chat_id
  174. assert data["chat"] == {
  175. "branchPointMessageId": "1",
  176. "description": "chat1 description",
  177. "history": {"currentId": "1", "messages": []},
  178. "name": "chat1",
  179. "originalChatId": chat_id,
  180. "tags": ["tag1", "tag2"],
  181. "title": "Clone of New Chat",
  182. }
  183. assert data["share_id"] is None
  184. assert data["title"] == "Clone of New Chat"
  185. assert data["user_id"] == "2"
  186. def test_archive_chat_by_id(self):
  187. chat_id = self.chats.get_chats()[0].id
  188. with mock_webui_user(id="2"):
  189. response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
  190. assert response.status_code == 200
  191. chat = self.chats.get_chat_by_id(chat_id)
  192. assert chat.archived is True
  193. def test_share_chat_by_id(self):
  194. chat_id = self.chats.get_chats()[0].id
  195. with mock_webui_user(id="2"):
  196. response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
  197. assert response.status_code == 200
  198. chat = self.chats.get_chat_by_id(chat_id)
  199. assert chat.share_id is not None
  200. def test_delete_shared_chat_by_id(self):
  201. chat_id = self.chats.get_chats()[0].id
  202. share_id = str(uuid.uuid4())
  203. self.chats.update_chat_share_id_by_id(chat_id, share_id)
  204. with mock_webui_user(id="2"):
  205. response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
  206. assert response.status_code
  207. chat = self.chats.get_chat_by_id(chat_id)
  208. assert chat.share_id is None