knowledge.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. import json
  2. from typing import Optional, Union
  3. from pydantic import BaseModel
  4. from fastapi import APIRouter, Depends, HTTPException, status
  5. import logging
  6. from open_webui.apps.webui.models.knowledge import (
  7. Knowledges,
  8. KnowledgeUpdateForm,
  9. KnowledgeForm,
  10. KnowledgeResponse,
  11. )
  12. from open_webui.apps.webui.models.files import Files, FileModel
  13. from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
  14. from open_webui.apps.retrieval.main import process_file, ProcessFileForm
  15. from open_webui.constants import ERROR_MESSAGES
  16. from open_webui.utils.utils import get_admin_user, get_verified_user
  17. from open_webui.env import SRC_LOG_LEVELS
# Module-level logger; its level is read from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

# Router on which every endpoint in this module is registered.
router = APIRouter()
############################
# GetKnowledgeItems
############################
  24. @router.get(
  25. "/", response_model=Optional[Union[list[KnowledgeResponse], KnowledgeResponse]]
  26. )
  27. async def get_knowledge_items(
  28. id: Optional[str] = None, user=Depends(get_verified_user)
  29. ):
  30. if id:
  31. knowledge = Knowledges.get_knowledge_by_id(id=id)
  32. if knowledge:
  33. return knowledge
  34. else:
  35. raise HTTPException(
  36. status_code=status.HTTP_401_UNAUTHORIZED,
  37. detail=ERROR_MESSAGES.NOT_FOUND,
  38. )
  39. else:
  40. return [
  41. KnowledgeResponse(
  42. **knowledge.model_dump(),
  43. files=Files.get_file_metadatas_by_ids(
  44. knowledge.data.get("file_ids", []) if knowledge.data else []
  45. ),
  46. )
  47. for knowledge in Knowledges.get_knowledge_items()
  48. ]
############################
# CreateNewKnowledge
############################
  52. @router.post("/create", response_model=Optional[KnowledgeResponse])
  53. async def create_new_knowledge(form_data: KnowledgeForm, user=Depends(get_admin_user)):
  54. knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
  55. if knowledge:
  56. return knowledge
  57. else:
  58. raise HTTPException(
  59. status_code=status.HTTP_400_BAD_REQUEST,
  60. detail=ERROR_MESSAGES.FILE_EXISTS,
  61. )
############################
# GetKnowledgeById
############################
# Response model used by the per-item endpoints in this module: a
# KnowledgeResponse whose `files` field is declared as a list of FileModel.
class KnowledgeFilesResponse(KnowledgeResponse):
    # File records attached to the knowledge item.
    files: list[FileModel]
  67. @router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
  68. async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
  69. knowledge = Knowledges.get_knowledge_by_id(id=id)
  70. if knowledge:
  71. file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
  72. files = Files.get_files_by_ids(file_ids)
  73. return KnowledgeFilesResponse(
  74. **knowledge.model_dump(),
  75. files=files,
  76. )
  77. else:
  78. raise HTTPException(
  79. status_code=status.HTTP_401_UNAUTHORIZED,
  80. detail=ERROR_MESSAGES.NOT_FOUND,
  81. )
############################
# UpdateKnowledgeById
############################
  85. @router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse])
  86. async def update_knowledge_by_id(
  87. id: str,
  88. form_data: KnowledgeUpdateForm,
  89. user=Depends(get_admin_user),
  90. ):
  91. knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
  92. if knowledge:
  93. file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
  94. files = Files.get_files_by_ids(file_ids)
  95. return KnowledgeFilesResponse(
  96. **knowledge.model_dump(),
  97. files=files,
  98. )
  99. else:
  100. raise HTTPException(
  101. status_code=status.HTTP_400_BAD_REQUEST,
  102. detail=ERROR_MESSAGES.ID_TAKEN,
  103. )
############################
# AddFileToKnowledge
############################
# Request body for the file add/update/remove endpoints in this module.
class KnowledgeFileIdForm(BaseModel):
    # Id of the file the request refers to.
    file_id: str
  109. @router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse])
  110. def add_file_to_knowledge_by_id(
  111. id: str,
  112. form_data: KnowledgeFileIdForm,
  113. user=Depends(get_admin_user),
  114. ):
  115. knowledge = Knowledges.get_knowledge_by_id(id=id)
  116. file = Files.get_file_by_id(form_data.file_id)
  117. if not file:
  118. raise HTTPException(
  119. status_code=status.HTTP_400_BAD_REQUEST,
  120. detail=ERROR_MESSAGES.NOT_FOUND,
  121. )
  122. if not file.data:
  123. raise HTTPException(
  124. status_code=status.HTTP_400_BAD_REQUEST,
  125. detail=ERROR_MESSAGES.FILE_NOT_PROCESSED,
  126. )
  127. # Add content to the vector database
  128. try:
  129. process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
  130. except Exception as e:
  131. log.debug(e)
  132. raise HTTPException(
  133. status_code=status.HTTP_400_BAD_REQUEST,
  134. detail=str(e),
  135. )
  136. if knowledge:
  137. data = knowledge.data or {}
  138. file_ids = data.get("file_ids", [])
  139. if form_data.file_id not in file_ids:
  140. file_ids.append(form_data.file_id)
  141. data["file_ids"] = file_ids
  142. knowledge = Knowledges.update_knowledge_by_id(
  143. id=id, form_data=KnowledgeUpdateForm(data=data)
  144. )
  145. if knowledge:
  146. files = Files.get_files_by_ids(file_ids)
  147. return KnowledgeFilesResponse(
  148. **knowledge.model_dump(),
  149. files=files,
  150. )
  151. else:
  152. raise HTTPException(
  153. status_code=status.HTTP_400_BAD_REQUEST,
  154. detail=ERROR_MESSAGES.DEFAULT("knowledge"),
  155. )
  156. else:
  157. raise HTTPException(
  158. status_code=status.HTTP_400_BAD_REQUEST,
  159. detail=ERROR_MESSAGES.DEFAULT("file_id"),
  160. )
  161. else:
  162. raise HTTPException(
  163. status_code=status.HTTP_400_BAD_REQUEST,
  164. detail=ERROR_MESSAGES.NOT_FOUND,
  165. )
  166. @router.post("/{id}/file/update", response_model=Optional[KnowledgeFilesResponse])
  167. def update_file_from_knowledge_by_id(
  168. id: str,
  169. form_data: KnowledgeFileIdForm,
  170. user=Depends(get_admin_user),
  171. ):
  172. knowledge = Knowledges.get_knowledge_by_id(id=id)
  173. file = Files.get_file_by_id(form_data.file_id)
  174. if not file:
  175. raise HTTPException(
  176. status_code=status.HTTP_400_BAD_REQUEST,
  177. detail=ERROR_MESSAGES.NOT_FOUND,
  178. )
  179. # Remove content from the vector database
  180. VECTOR_DB_CLIENT.delete(
  181. collection_name=knowledge.id, filter={"file_id": form_data.file_id}
  182. )
  183. # Add content to the vector database
  184. try:
  185. process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
  186. except Exception as e:
  187. raise HTTPException(
  188. status_code=status.HTTP_400_BAD_REQUEST,
  189. detail=str(e),
  190. )
  191. if knowledge:
  192. data = knowledge.data or {}
  193. file_ids = data.get("file_ids", [])
  194. files = Files.get_files_by_ids(file_ids)
  195. return KnowledgeFilesResponse(
  196. **knowledge.model_dump(),
  197. files=files,
  198. )
  199. else:
  200. raise HTTPException(
  201. status_code=status.HTTP_400_BAD_REQUEST,
  202. detail=ERROR_MESSAGES.NOT_FOUND,
  203. )
############################
# RemoveFileFromKnowledge
############################
  207. @router.post("/{id}/file/remove", response_model=Optional[KnowledgeFilesResponse])
  208. def remove_file_from_knowledge_by_id(
  209. id: str,
  210. form_data: KnowledgeFileIdForm,
  211. user=Depends(get_admin_user),
  212. ):
  213. knowledge = Knowledges.get_knowledge_by_id(id=id)
  214. file = Files.get_file_by_id(form_data.file_id)
  215. if not file:
  216. raise HTTPException(
  217. status_code=status.HTTP_400_BAD_REQUEST,
  218. detail=ERROR_MESSAGES.NOT_FOUND,
  219. )
  220. # Remove content from the vector database
  221. VECTOR_DB_CLIENT.delete(
  222. collection_name=knowledge.id, filter={"file_id": form_data.file_id}
  223. )
  224. result = VECTOR_DB_CLIENT.query(
  225. collection_name=knowledge.id,
  226. filter={"file_id": form_data.file_id},
  227. )
  228. Files.delete_file_by_id(form_data.file_id)
  229. if knowledge:
  230. data = knowledge.data or {}
  231. file_ids = data.get("file_ids", [])
  232. if form_data.file_id in file_ids:
  233. file_ids.remove(form_data.file_id)
  234. data["file_ids"] = file_ids
  235. knowledge = Knowledges.update_knowledge_by_id(
  236. id=id, form_data=KnowledgeUpdateForm(data=data)
  237. )
  238. if knowledge:
  239. files = Files.get_files_by_ids(file_ids)
  240. return KnowledgeFilesResponse(
  241. **knowledge.model_dump(),
  242. files=files,
  243. )
  244. else:
  245. raise HTTPException(
  246. status_code=status.HTTP_400_BAD_REQUEST,
  247. detail=ERROR_MESSAGES.DEFAULT("knowledge"),
  248. )
  249. else:
  250. raise HTTPException(
  251. status_code=status.HTTP_400_BAD_REQUEST,
  252. detail=ERROR_MESSAGES.DEFAULT("file_id"),
  253. )
  254. else:
  255. raise HTTPException(
  256. status_code=status.HTTP_400_BAD_REQUEST,
  257. detail=ERROR_MESSAGES.NOT_FOUND,
  258. )
############################
# ResetKnowledgeById
############################
  262. @router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
  263. async def reset_knowledge_by_id(id: str, user=Depends(get_admin_user)):
  264. try:
  265. VECTOR_DB_CLIENT.delete_collection(collection_name=id)
  266. except Exception as e:
  267. log.debug(e)
  268. pass
  269. knowledge = Knowledges.update_knowledge_by_id(
  270. id=id, form_data=KnowledgeUpdateForm(data={"file_ids": []})
  271. )
  272. return knowledge
############################
# DeleteKnowledgeById
############################
  276. @router.delete("/{id}/delete", response_model=bool)
  277. async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)):
  278. try:
  279. VECTOR_DB_CLIENT.delete_collection(collection_name=id)
  280. except Exception as e:
  281. log.debug(e)
  282. pass
  283. result = Knowledges.delete_knowledge_by_id(id=id)
  284. return result