test_chats.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import uuid
  2. from test.util.abstract_integration_test import AbstractPostgresTest
  3. from test.util.mock_user import mock_webui_user
  4. class TestChats(AbstractPostgresTest):
  5. BASE_PATH = "/api/v1/chats"
  6. def setup_class(cls):
  7. super().setup_class()
  8. def setup_method(self):
  9. super().setup_method()
  10. from apps.webui.models.chats import ChatForm
  11. from apps.webui.models.chats import Chats
  12. self.chats = Chats
  13. self.chats.insert_new_chat(
  14. "2",
  15. ChatForm(
  16. **{
  17. "chat": {
  18. "name": "chat1",
  19. "description": "chat1 description",
  20. "tags": ["tag1", "tag2"],
  21. "history": {"currentId": "1", "messages": []},
  22. }
  23. }
  24. ),
  25. )
  26. def test_get_session_user_chat_list(self):
  27. with mock_webui_user(id="2"):
  28. response = self.fast_api_client.get(self.create_url("/"))
  29. assert response.status_code == 200
  30. first_chat = response.json()[0]
  31. assert first_chat["id"] is not None
  32. assert first_chat["title"] == "New Chat"
  33. assert first_chat["created_at"] is not None
  34. assert first_chat["updated_at"] is not None
  35. def test_delete_all_user_chats(self):
  36. with mock_webui_user(id="2"):
  37. response = self.fast_api_client.delete(self.create_url("/"))
  38. assert response.status_code == 200
  39. assert len(self.chats.get_chats()) == 0
  40. def test_get_user_chat_list_by_user_id(self):
  41. with mock_webui_user(id="3"):
  42. response = self.fast_api_client.get(self.create_url("/list/user/2"))
  43. assert response.status_code == 200
  44. first_chat = response.json()[0]
  45. assert first_chat["id"] is not None
  46. assert first_chat["title"] == "New Chat"
  47. assert first_chat["created_at"] is not None
  48. assert first_chat["updated_at"] is not None
  49. def test_create_new_chat(self):
  50. with mock_webui_user(id="2"):
  51. response = self.fast_api_client.post(
  52. self.create_url("/new"),
  53. json={
  54. "chat": {
  55. "name": "chat2",
  56. "description": "chat2 description",
  57. "tags": ["tag1", "tag2"],
  58. }
  59. },
  60. )
  61. assert response.status_code == 200
  62. data = response.json()
  63. assert data["archived"] is False
  64. assert data["chat"] == {
  65. "name": "chat2",
  66. "description": "chat2 description",
  67. "tags": ["tag1", "tag2"],
  68. }
  69. assert data["user_id"] == "2"
  70. assert data["id"] is not None
  71. assert data["share_id"] is None
  72. assert data["title"] == "New Chat"
  73. assert data["updated_at"] is not None
  74. assert data["created_at"] is not None
  75. assert len(self.chats.get_chats()) == 2
  76. def test_get_user_chats(self):
  77. self.test_get_session_user_chat_list()
  78. def test_get_user_archived_chats(self):
  79. self.chats.archive_all_chats_by_user_id("2")
  80. from apps.webui.internal.db import Session
  81. Session.commit()
  82. with mock_webui_user(id="2"):
  83. response = self.fast_api_client.get(self.create_url("/all/archived"))
  84. assert response.status_code == 200
  85. first_chat = response.json()[0]
  86. assert first_chat["id"] is not None
  87. assert first_chat["title"] == "New Chat"
  88. assert first_chat["created_at"] is not None
  89. assert first_chat["updated_at"] is not None
  90. def test_get_all_user_chats_in_db(self):
  91. with mock_webui_user(id="4"):
  92. response = self.fast_api_client.get(self.create_url("/all/db"))
  93. assert response.status_code == 200
  94. assert len(response.json()) == 1
  95. def test_get_archived_session_user_chat_list(self):
  96. self.test_get_user_archived_chats()
  97. def test_archive_all_chats(self):
  98. with mock_webui_user(id="2"):
  99. response = self.fast_api_client.post(self.create_url("/archive/all"))
  100. assert response.status_code == 200
  101. assert len(self.chats.get_archived_chats_by_user_id("2")) == 1
  102. def test_get_shared_chat_by_id(self):
  103. chat_id = self.chats.get_chats()[0].id
  104. self.chats.update_chat_share_id_by_id(chat_id, chat_id)
  105. with mock_webui_user(id="2"):
  106. response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
  107. assert response.status_code == 200
  108. data = response.json()
  109. assert data["id"] == chat_id
  110. assert data["chat"] == {
  111. "name": "chat1",
  112. "description": "chat1 description",
  113. "tags": ["tag1", "tag2"],
  114. "history": {"currentId": "1", "messages": []},
  115. }
  116. assert data["id"] == chat_id
  117. assert data["share_id"] == chat_id
  118. assert data["title"] == "New Chat"
  119. def test_get_chat_by_id(self):
  120. chat_id = self.chats.get_chats()[0].id
  121. with mock_webui_user(id="2"):
  122. response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
  123. assert response.status_code == 200
  124. data = response.json()
  125. assert data["id"] == chat_id
  126. assert data["chat"] == {
  127. "name": "chat1",
  128. "description": "chat1 description",
  129. "tags": ["tag1", "tag2"],
  130. "history": {"currentId": "1", "messages": []},
  131. }
  132. assert data["share_id"] is None
  133. assert data["title"] == "New Chat"
  134. assert data["user_id"] == "2"
  135. def test_update_chat_by_id(self):
  136. chat_id = self.chats.get_chats()[0].id
  137. with mock_webui_user(id="2"):
  138. response = self.fast_api_client.post(
  139. self.create_url(f"/{chat_id}"),
  140. json={
  141. "chat": {
  142. "name": "chat2",
  143. "description": "chat2 description",
  144. "tags": ["tag2", "tag4"],
  145. "title": "Just another title",
  146. }
  147. },
  148. )
  149. assert response.status_code == 200
  150. data = response.json()
  151. assert data["id"] == chat_id
  152. assert data["chat"] == {
  153. "name": "chat2",
  154. "title": "Just another title",
  155. "description": "chat2 description",
  156. "tags": ["tag2", "tag4"],
  157. "history": {"currentId": "1", "messages": []},
  158. }
  159. assert data["share_id"] is None
  160. assert data["title"] == "Just another title"
  161. assert data["user_id"] == "2"
  162. def test_delete_chat_by_id(self):
  163. chat_id = self.chats.get_chats()[0].id
  164. with mock_webui_user(id="2"):
  165. response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
  166. assert response.status_code == 200
  167. assert response.json() is True
  168. def test_clone_chat_by_id(self):
  169. chat_id = self.chats.get_chats()[0].id
  170. with mock_webui_user(id="2"):
  171. response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
  172. assert response.status_code == 200
  173. data = response.json()
  174. assert data["id"] != chat_id
  175. assert data["chat"] == {
  176. "branchPointMessageId": "1",
  177. "description": "chat1 description",
  178. "history": {"currentId": "1", "messages": []},
  179. "name": "chat1",
  180. "originalChatId": chat_id,
  181. "tags": ["tag1", "tag2"],
  182. "title": "Clone of New Chat",
  183. }
  184. assert data["share_id"] is None
  185. assert data["title"] == "Clone of New Chat"
  186. assert data["user_id"] == "2"
  187. def test_archive_chat_by_id(self):
  188. chat_id = self.chats.get_chats()[0].id
  189. with mock_webui_user(id="2"):
  190. response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
  191. assert response.status_code == 200
  192. chat = self.chats.get_chat_by_id(chat_id)
  193. assert chat.archived is True
  194. def test_share_chat_by_id(self):
  195. chat_id = self.chats.get_chats()[0].id
  196. with mock_webui_user(id="2"):
  197. response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
  198. assert response.status_code == 200
  199. chat = self.chats.get_chat_by_id(chat_id)
  200. assert chat.share_id is not None
  201. def test_delete_shared_chat_by_id(self):
  202. chat_id = self.chats.get_chats()[0].id
  203. share_id = str(uuid.uuid4())
  204. self.chats.update_chat_share_id_by_id(chat_id, share_id)
  205. with mock_webui_user(id="2"):
  206. response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
  207. assert response.status_code
  208. chat = self.chats.get_chat_by_id(chat_id)
  209. assert chat.share_id is None