models.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. from typing import Optional
  2. from open_webui.models.models import (
  3. ModelForm,
  4. ModelModel,
  5. ModelResponse,
  6. ModelUserResponse,
  7. Models,
  8. )
  9. from open_webui.constants import ERROR_MESSAGES
  10. from fastapi import APIRouter, Depends, HTTPException, Request, status
  11. from open_webui.utils.auth import get_admin_user, get_verified_user
  12. from open_webui.utils.access_control import has_access, has_permission
  13. router = APIRouter()
  14. ###########################
  15. # GetModels
  16. ###########################
  17. @router.get("/", response_model=list[ModelUserResponse])
  18. async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
  19. if user.role == "admin":
  20. return Models.get_models()
  21. else:
  22. return Models.get_models_by_user_id(user.id)
  23. ###########################
  24. # GetBaseModels
  25. ###########################
  26. @router.get("/base", response_model=list[ModelResponse])
  27. async def get_base_models(user=Depends(get_admin_user)):
  28. return Models.get_base_models()
  29. ############################
  30. # CreateNewModel
  31. ############################
  32. @router.post("/create", response_model=Optional[ModelModel])
  33. async def create_new_model(
  34. request: Request,
  35. form_data: ModelForm,
  36. user=Depends(get_verified_user),
  37. ):
  38. if user.role != "admin" and not has_permission(
  39. user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS
  40. ):
  41. raise HTTPException(
  42. status_code=status.HTTP_401_UNAUTHORIZED,
  43. detail=ERROR_MESSAGES.UNAUTHORIZED,
  44. )
  45. model = Models.get_model_by_id(form_data.id)
  46. if model:
  47. raise HTTPException(
  48. status_code=status.HTTP_401_UNAUTHORIZED,
  49. detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
  50. )
  51. else:
  52. model = Models.insert_new_model(form_data, user.id)
  53. if model:
  54. return model
  55. else:
  56. raise HTTPException(
  57. status_code=status.HTTP_401_UNAUTHORIZED,
  58. detail=ERROR_MESSAGES.DEFAULT(),
  59. )
  60. ###########################
  61. # GetModelById
  62. ###########################
  63. # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
  64. @router.get("/model", response_model=Optional[ModelResponse])
  65. async def get_model_by_id(id: str, user=Depends(get_verified_user)):
  66. model = Models.get_model_by_id(id)
  67. if model:
  68. if (
  69. user.role == "admin"
  70. or model.user_id == user.id
  71. or has_access(user.id, "read", model.access_control)
  72. ):
  73. return model
  74. else:
  75. raise HTTPException(
  76. status_code=status.HTTP_401_UNAUTHORIZED,
  77. detail=ERROR_MESSAGES.NOT_FOUND,
  78. )
  79. ############################
  80. # ToggelModelById
  81. ############################
  82. @router.post("/model/toggle", response_model=Optional[ModelResponse])
  83. async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
  84. model = Models.get_model_by_id(id)
  85. if model:
  86. if (
  87. user.role == "admin"
  88. or model.user_id == user.id
  89. or has_access(user.id, "write", model.access_control)
  90. ):
  91. model = Models.toggle_model_by_id(id)
  92. if model:
  93. return model
  94. else:
  95. raise HTTPException(
  96. status_code=status.HTTP_400_BAD_REQUEST,
  97. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  98. )
  99. else:
  100. raise HTTPException(
  101. status_code=status.HTTP_401_UNAUTHORIZED,
  102. detail=ERROR_MESSAGES.UNAUTHORIZED,
  103. )
  104. else:
  105. raise HTTPException(
  106. status_code=status.HTTP_401_UNAUTHORIZED,
  107. detail=ERROR_MESSAGES.NOT_FOUND,
  108. )
  109. ############################
  110. # UpdateModelById
  111. ############################
  112. @router.post("/model/update", response_model=Optional[ModelModel])
  113. async def update_model_by_id(
  114. id: str,
  115. form_data: ModelForm,
  116. user=Depends(get_verified_user),
  117. ):
  118. model = Models.get_model_by_id(id)
  119. if not model:
  120. raise HTTPException(
  121. status_code=status.HTTP_401_UNAUTHORIZED,
  122. detail=ERROR_MESSAGES.NOT_FOUND,
  123. )
  124. if (
  125. model.user_id != user.id
  126. and not has_access(user.id, "write", model.access_control)
  127. and user.role != "admin"
  128. ):
  129. raise HTTPException(
  130. status_code=status.HTTP_400_BAD_REQUEST,
  131. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  132. )
  133. model = Models.update_model_by_id(id, form_data)
  134. return model
  135. ############################
  136. # DeleteModelById
  137. ############################
  138. @router.delete("/model/delete", response_model=bool)
  139. async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
  140. model = Models.get_model_by_id(id)
  141. if not model:
  142. raise HTTPException(
  143. status_code=status.HTTP_401_UNAUTHORIZED,
  144. detail=ERROR_MESSAGES.NOT_FOUND,
  145. )
  146. if (
  147. user.role != "admin"
  148. and model.user_id != user.id
  149. and not has_access(user.id, "write", model.access_control)
  150. ):
  151. raise HTTPException(
  152. status_code=status.HTTP_401_UNAUTHORIZED,
  153. detail=ERROR_MESSAGES.UNAUTHORIZED,
  154. )
  155. result = Models.delete_model_by_id(id)
  156. return result
  157. @router.delete("/delete/all", response_model=bool)
  158. async def delete_all_models(user=Depends(get_admin_user)):
  159. result = Models.delete_all_models()
  160. return result