# modelfiles.py (3.8 KB)

  1. from fastapi import Depends, FastAPI, HTTPException, status
  2. from datetime import datetime, timedelta
  3. from typing import List, Union, Optional
  4. from fastapi import APIRouter
  5. from pydantic import BaseModel
  6. import json
  7. from apps.web.models.users import Users
  8. from apps.web.models.modelfiles import (
  9. Modelfiles,
  10. ModelfileForm,
  11. ModelfileTagNameForm,
  12. ModelfileUpdateForm,
  13. ModelfileResponse,
  14. )
  15. from utils.utils import bearer_scheme, get_current_user
  16. from constants import ERROR_MESSAGES
# Router that every endpoint in this module registers on via the
# @router.get / @router.post / @router.delete decorators below.
router = APIRouter()
  18. ############################
  19. # GetModelfiles
  20. ############################
  21. @router.get("/", response_model=List[ModelfileResponse])
  22. async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
  23. return Modelfiles.get_modelfiles(skip, limit)
  24. ############################
  25. # CreateNewModelfile
  26. ############################
  27. @router.post("/create", response_model=Optional[ModelfileResponse])
  28. async def create_new_modelfile(
  29. form_data: ModelfileForm, user=Depends(get_current_user)
  30. ):
  31. if user.role != "admin":
  32. raise HTTPException(
  33. status_code=status.HTTP_401_UNAUTHORIZED,
  34. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  35. )
  36. modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
  37. if modelfile:
  38. return ModelfileResponse(
  39. **{
  40. **modelfile.model_dump(),
  41. "modelfile": json.loads(modelfile.modelfile),
  42. }
  43. )
  44. else:
  45. raise HTTPException(
  46. status_code=status.HTTP_401_UNAUTHORIZED,
  47. detail=ERROR_MESSAGES.DEFAULT(),
  48. )
  49. ############################
  50. # GetModelfileByTagName
  51. ############################
  52. @router.post("/", response_model=Optional[ModelfileResponse])
  53. async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm):
  54. modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
  55. if modelfile:
  56. return ModelfileResponse(
  57. **{
  58. **modelfile.model_dump(),
  59. "modelfile": json.loads(modelfile.modelfile),
  60. }
  61. )
  62. else:
  63. raise HTTPException(
  64. status_code=status.HTTP_401_UNAUTHORIZED,
  65. detail=ERROR_MESSAGES.NOT_FOUND,
  66. )
  67. ############################
  68. # UpdateModelfileByTagName
  69. ############################
  70. @router.post("/update", response_model=Optional[ModelfileResponse])
  71. async def update_modelfile_by_tag_name(
  72. form_data: ModelfileUpdateForm, user=Depends(get_current_user)
  73. ):
  74. if user.role != "admin":
  75. raise HTTPException(
  76. status_code=status.HTTP_401_UNAUTHORIZED,
  77. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  78. )
  79. modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
  80. if modelfile:
  81. updated_modelfile = {
  82. **json.loads(modelfile.modelfile),
  83. **form_data.modelfile,
  84. }
  85. modelfile = Modelfiles.update_modelfile_by_tag_name(
  86. form_data.tag_name, updated_modelfile
  87. )
  88. return ModelfileResponse(
  89. **{
  90. **modelfile.model_dump(),
  91. "modelfile": json.loads(modelfile.modelfile),
  92. }
  93. )
  94. else:
  95. raise HTTPException(
  96. status_code=status.HTTP_401_UNAUTHORIZED,
  97. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  98. )
  99. ############################
  100. # DeleteModelfileByTagName
  101. ############################
  102. @router.delete("/delete", response_model=bool)
  103. async def delete_modelfile_by_tag_name(
  104. form_data: ModelfileTagNameForm, user=Depends(get_current_user)
  105. ):
  106. if user.role != "admin":
  107. raise HTTPException(
  108. status_code=status.HTTP_401_UNAUTHORIZED,
  109. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  110. )
  111. result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
  112. return result