models.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. from typing import Optional
  2. from open_webui.models.models import (
  3. ModelForm,
  4. ModelModel,
  5. ModelResponse,
  6. ModelUserResponse,
  7. Models,
  8. )
  9. from open_webui.constants import ERROR_MESSAGES
  10. from fastapi import APIRouter, Depends, HTTPException, Request, status
  11. from open_webui.utils.auth import get_admin_user, get_verified_user
  12. from open_webui.utils.access_control import has_access, has_permission
# FastAPI router on which every endpoint in this module is registered
# (via the @router.get / @router.post / @router.delete decorators below).
router = APIRouter()
  14. ###########################
  15. # GetModels
  16. ###########################
  17. @router.get("/", response_model=list[ModelUserResponse])
  18. async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
  19. if user.role == "admin":
  20. return Models.get_models()
  21. else:
  22. return Models.get_models_by_user_id(user.id)
  23. ############################
  24. # CreateNewModel
  25. ############################
  26. @router.post("/create", response_model=Optional[ModelModel])
  27. async def create_new_model(
  28. request: Request,
  29. form_data: ModelForm,
  30. user=Depends(get_verified_user),
  31. ):
  32. if user.role != "admin" and not has_permission(
  33. user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS
  34. ):
  35. raise HTTPException(
  36. status_code=status.HTTP_401_UNAUTHORIZED,
  37. detail=ERROR_MESSAGES.UNAUTHORIZED,
  38. )
  39. model = Models.get_model_by_id(form_data.id)
  40. if model:
  41. raise HTTPException(
  42. status_code=status.HTTP_401_UNAUTHORIZED,
  43. detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
  44. )
  45. else:
  46. model = Models.insert_new_model(form_data, user.id)
  47. if model:
  48. return model
  49. else:
  50. raise HTTPException(
  51. status_code=status.HTTP_401_UNAUTHORIZED,
  52. detail=ERROR_MESSAGES.DEFAULT(),
  53. )
  54. ###########################
  55. # GetModelById
  56. ###########################
  57. # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
  58. @router.get("/model", response_model=Optional[ModelResponse])
  59. async def get_model_by_id(id: str, user=Depends(get_verified_user)):
  60. model = Models.get_model_by_id(id)
  61. if model:
  62. if (
  63. user.role == "admin"
  64. or model.user_id == user.id
  65. or has_access(user.id, "read", model.access_control)
  66. ):
  67. return model
  68. else:
  69. raise HTTPException(
  70. status_code=status.HTTP_401_UNAUTHORIZED,
  71. detail=ERROR_MESSAGES.NOT_FOUND,
  72. )
  73. ############################
  74. # ToggleModelById
  75. ############################
  76. @router.post("/model/toggle", response_model=Optional[ModelResponse])
  77. async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
  78. model = Models.get_model_by_id(id)
  79. if model:
  80. if (
  81. user.role == "admin"
  82. or model.user_id == user.id
  83. or has_access(user.id, "write", model.access_control)
  84. ):
  85. model = Models.toggle_model_by_id(id)
  86. if model:
  87. return model
  88. else:
  89. raise HTTPException(
  90. status_code=status.HTTP_400_BAD_REQUEST,
  91. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  92. )
  93. else:
  94. raise HTTPException(
  95. status_code=status.HTTP_401_UNAUTHORIZED,
  96. detail=ERROR_MESSAGES.UNAUTHORIZED,
  97. )
  98. else:
  99. raise HTTPException(
  100. status_code=status.HTTP_401_UNAUTHORIZED,
  101. detail=ERROR_MESSAGES.NOT_FOUND,
  102. )
  103. ############################
  104. # UpdateModelById
  105. ############################
  106. @router.post("/model/update", response_model=Optional[ModelModel])
  107. async def update_model_by_id(
  108. id: str,
  109. form_data: ModelForm,
  110. user=Depends(get_verified_user),
  111. ):
  112. model = Models.get_model_by_id(id)
  113. if not model:
  114. raise HTTPException(
  115. status_code=status.HTTP_401_UNAUTHORIZED,
  116. detail=ERROR_MESSAGES.NOT_FOUND,
  117. )
  118. if (
  119. model.user_id != user.id
  120. and not has_access(user.id, "write", model.access_control)
  121. and user.role != "admin"
  122. ):
  123. raise HTTPException(
  124. status_code=status.HTTP_400_BAD_REQUEST,
  125. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  126. )
  127. model = Models.update_model_by_id(id, form_data)
  128. return model
  129. ############################
  130. # DeleteModelById
  131. ############################
  132. @router.delete("/model/delete", response_model=bool)
  133. async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
  134. model = Models.get_model_by_id(id)
  135. if not model:
  136. raise HTTPException(
  137. status_code=status.HTTP_401_UNAUTHORIZED,
  138. detail=ERROR_MESSAGES.NOT_FOUND,
  139. )
  140. if (
  141. user.role != "admin"
  142. and model.user_id != user.id
  143. and not has_access(user.id, "write", model.access_control)
  144. ):
  145. raise HTTPException(
  146. status_code=status.HTTP_401_UNAUTHORIZED,
  147. detail=ERROR_MESSAGES.UNAUTHORIZED,
  148. )
  149. result = Models.delete_model_by_id(id)
  150. return result
  151. @router.delete("/delete/all", response_model=bool)
  152. async def delete_all_models(user=Depends(get_admin_user)):
  153. result = Models.delete_all_models()
  154. return result