test_chats.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. import uuid
  2. from test.util.abstract_integration_test import AbstractPostgresTest
  3. from test.util.mock_user import mock_webui_user
  4. class TestChats(AbstractPostgresTest):
  5. BASE_PATH = "/api/v1/chats"
  6. def setup_class(cls):
  7. super().setup_class()
  8. def setup_method(self):
  9. super().setup_method()
  10. from apps.webui.models.chats import ChatForm
  11. from apps.webui.models.chats import Chats
  12. self.chats = Chats
  13. self.chats.insert_new_chat(
  14. self.db_session,
  15. "2",
  16. ChatForm(
  17. **{
  18. "chat": {
  19. "name": "chat1",
  20. "description": "chat1 description",
  21. "tags": ["tag1", "tag2"],
  22. "history": {"currentId": "1", "messages": []},
  23. }
  24. }
  25. ),
  26. )
  27. def test_get_session_user_chat_list(self):
  28. with mock_webui_user(id="2"):
  29. response = self.fast_api_client.get(self.create_url("/"))
  30. assert response.status_code == 200
  31. first_chat = response.json()[0]
  32. assert first_chat["id"] is not None
  33. assert first_chat["title"] == "New Chat"
  34. assert first_chat["created_at"] is not None
  35. assert first_chat["updated_at"] is not None
  36. def test_delete_all_user_chats(self):
  37. with mock_webui_user(id="2"):
  38. response = self.fast_api_client.delete(self.create_url("/"))
  39. assert response.status_code == 200
  40. assert len(self.chats.get_chats(self.db_session)) == 0
  41. def test_get_user_chat_list_by_user_id(self):
  42. with mock_webui_user(id="3"):
  43. response = self.fast_api_client.get(self.create_url("/list/user/2"))
  44. assert response.status_code == 200
  45. first_chat = response.json()[0]
  46. assert first_chat["id"] is not None
  47. assert first_chat["title"] == "New Chat"
  48. assert first_chat["created_at"] is not None
  49. assert first_chat["updated_at"] is not None
  50. def test_create_new_chat(self):
  51. with mock_webui_user(id="2"):
  52. response = self.fast_api_client.post(
  53. self.create_url("/new"),
  54. json={
  55. "chat": {
  56. "name": "chat2",
  57. "description": "chat2 description",
  58. "tags": ["tag1", "tag2"],
  59. }
  60. },
  61. )
  62. assert response.status_code == 200
  63. data = response.json()
  64. assert data["archived"] is False
  65. assert data["chat"] == {
  66. "name": "chat2",
  67. "description": "chat2 description",
  68. "tags": ["tag1", "tag2"],
  69. }
  70. assert data["user_id"] == "2"
  71. assert data["id"] is not None
  72. assert data["share_id"] is None
  73. assert data["title"] == "New Chat"
  74. assert data["updated_at"] is not None
  75. assert data["created_at"] is not None
  76. assert len(self.chats.get_chats(self.db_session)) == 2
  77. def test_get_user_chats(self):
  78. self.test_get_session_user_chat_list()
  79. def test_get_user_archived_chats(self):
  80. self.chats.archive_all_chats_by_user_id(self.db_session, "2")
  81. self.db_session.commit()
  82. with mock_webui_user(id="2"):
  83. response = self.fast_api_client.get(self.create_url("/all/archived"))
  84. assert response.status_code == 200
  85. first_chat = response.json()[0]
  86. assert first_chat["id"] is not None
  87. assert first_chat["title"] == "New Chat"
  88. assert first_chat["created_at"] is not None
  89. assert first_chat["updated_at"] is not None
  90. def test_get_all_user_chats_in_db(self):
  91. with mock_webui_user(id="4"):
  92. response = self.fast_api_client.get(self.create_url("/all/db"))
  93. assert response.status_code == 200
  94. assert len(response.json()) == 1
  95. def test_get_archived_session_user_chat_list(self):
  96. self.test_get_user_archived_chats()
  97. def test_archive_all_chats(self):
  98. with mock_webui_user(id="2"):
  99. response = self.fast_api_client.post(self.create_url("/archive/all"))
  100. assert response.status_code == 200
  101. assert len(self.chats.get_archived_chats_by_user_id(self.db_session, "2")) == 1
  102. def test_get_shared_chat_by_id(self):
  103. chat_id = self.chats.get_chats(self.db_session)[0].id
  104. self.chats.update_chat_share_id_by_id(self.db_session, chat_id, chat_id)
  105. self.db_session.commit()
  106. with mock_webui_user(id="2"):
  107. response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
  108. assert response.status_code == 200
  109. data = response.json()
  110. assert data["id"] == chat_id
  111. assert data["chat"] == {
  112. "name": "chat1",
  113. "description": "chat1 description",
  114. "tags": ["tag1", "tag2"],
  115. "history": {"currentId": "1", "messages": []},
  116. }
  117. assert data["id"] == chat_id
  118. assert data["share_id"] == chat_id
  119. assert data["title"] == "New Chat"
  120. def test_get_chat_by_id(self):
  121. chat_id = self.chats.get_chats(self.db_session)[0].id
  122. with mock_webui_user(id="2"):
  123. response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
  124. assert response.status_code == 200
  125. data = response.json()
  126. assert data["id"] == chat_id
  127. assert data["chat"] == {
  128. "name": "chat1",
  129. "description": "chat1 description",
  130. "tags": ["tag1", "tag2"],
  131. "history": {"currentId": "1", "messages": []},
  132. }
  133. assert data["share_id"] is None
  134. assert data["title"] == "New Chat"
  135. assert data["user_id"] == "2"
  136. def test_update_chat_by_id(self):
  137. chat_id = self.chats.get_chats(self.db_session)[0].id
  138. with mock_webui_user(id="2"):
  139. response = self.fast_api_client.post(
  140. self.create_url(f"/{chat_id}"),
  141. json={
  142. "chat": {
  143. "name": "chat2",
  144. "description": "chat2 description",
  145. "tags": ["tag2", "tag4"],
  146. "title": "Just another title",
  147. }
  148. },
  149. )
  150. assert response.status_code == 200
  151. data = response.json()
  152. assert data["id"] == chat_id
  153. assert data["chat"] == {
  154. "name": "chat2",
  155. "title": "Just another title",
  156. "description": "chat2 description",
  157. "tags": ["tag2", "tag4"],
  158. "history": {"currentId": "1", "messages": []},
  159. }
  160. assert data["share_id"] is None
  161. assert data["title"] == "Just another title"
  162. assert data["user_id"] == "2"
  163. def test_delete_chat_by_id(self):
  164. chat_id = self.chats.get_chats(self.db_session)[0].id
  165. with mock_webui_user(id="2"):
  166. response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
  167. assert response.status_code == 200
  168. assert response.json() is True
  169. def test_clone_chat_by_id(self):
  170. chat_id = self.chats.get_chats(self.db_session)[0].id
  171. with mock_webui_user(id="2"):
  172. response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
  173. assert response.status_code == 200
  174. data = response.json()
  175. assert data["id"] != chat_id
  176. assert data["chat"] == {
  177. "branchPointMessageId": "1",
  178. "description": "chat1 description",
  179. "history": {"currentId": "1", "messages": []},
  180. "name": "chat1",
  181. "originalChatId": chat_id,
  182. "tags": ["tag1", "tag2"],
  183. "title": "Clone of New Chat",
  184. }
  185. assert data["share_id"] is None
  186. assert data["title"] == "Clone of New Chat"
  187. assert data["user_id"] == "2"
  188. def test_archive_chat_by_id(self):
  189. chat_id = self.chats.get_chats(self.db_session)[0].id
  190. with mock_webui_user(id="2"):
  191. response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
  192. assert response.status_code == 200
  193. chat = self.chats.get_chat_by_id(self.db_session, chat_id)
  194. assert chat.archived is True
  195. def test_share_chat_by_id(self):
  196. chat_id = self.chats.get_chats(self.db_session)[0].id
  197. with mock_webui_user(id="2"):
  198. response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
  199. assert response.status_code == 200
  200. chat = self.chats.get_chat_by_id(self.db_session, chat_id)
  201. assert chat.share_id is not None
  202. def test_delete_shared_chat_by_id(self):
  203. chat_id = self.chats.get_chats(self.db_session)[0].id
  204. share_id = str(uuid.uuid4())
  205. self.chats.update_chat_share_id_by_id(self.db_session, chat_id, share_id)
  206. self.db_session.commit()
  207. with mock_webui_user(id="2"):
  208. response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
  209. assert response.status_code
  210. chat = self.chats.get_chat_by_id(self.db_session, chat_id)
  211. assert chat.share_id is None