models.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from fastapi import Depends, FastAPI, HTTPException, status, Request
  2. from datetime import datetime, timedelta
  3. from typing import List, Union, Optional
  4. from fastapi import APIRouter
  5. from pydantic import BaseModel
  6. import json
  7. from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
  8. from utils.utils import get_verified_user, get_admin_user
  9. from constants import ERROR_MESSAGES
  router = APIRouter()


  ###########################
  # getModels
  ###########################


  @router.get("/", response_model=List[ModelResponse])
  async def get_models(user=Depends(get_verified_user)):
      """Return every model record reported by ``Models.get_all_models``.

      Access requires a user accepted by ``get_verified_user``; the resolved
      ``user`` is not otherwise consulted.
      """
      all_models = Models.get_all_models()
      return all_models
  ############################
  # AddNewModel
  ############################


  @router.post("/add", response_model=Optional[ModelModel])
  async def add_new_model(
      request: Request,
      form_data: ModelForm,
      user=Depends(get_admin_user),
  ):
      """Create a new model record owned by the requesting admin.

      Args:
          request: Incoming request; ``request.app.state.MODELS`` is checked
              for an existing entry matching ``form_data.id`` (presumably a
              collection keyed by model id -- confirm against the app setup).
          form_data: Definition of the model to create.
          user: Admin user resolved by ``get_admin_user``; its ``id`` is
              passed to ``Models.insert_new_model``.

      Returns:
          The record returned by ``Models.insert_new_model``.

      Raises:
          HTTPException: 409 with ``MODEL_ID_TAKEN`` if ``form_data.id`` is
              already in ``request.app.state.MODELS``; 400 with the default
              error message if the insert returns a falsy result.
      """
      # The caller has already passed admin authentication, so 401
      # UNAUTHORIZED was the wrong code here: a clashing id is a conflict.
      if form_data.id in request.app.state.MODELS:
          raise HTTPException(
              status_code=status.HTTP_409_CONFLICT,
              detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
          )

      model = Models.insert_new_model(form_data, user.id)
      if not model:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(),
          )
      return model
  ############################
  # GetModelById
  ############################


  # NOTE(review): this handler is registered on the same method and path
  # (GET "/") as ``get_models``, which this file registers earlier. Starlette
  # dispatches to the first fully matching route, so GET "/" is always served
  # by ``get_models`` and this handler is unreachable as written. It needs a
  # distinct path, or its lookup should be folded into ``get_models`` behind
  # an optional ``id`` query parameter. The path is left unchanged because
  # the correct URL depends on client code not visible here.
  @router.get("/", response_model=Optional[ModelModel])
  async def get_model_by_id(id: str, user=Depends(get_verified_user)):
      """Return the model record matching the ``id`` query parameter.

      Args:
          id: Identifier passed to ``Models.get_model_by_id``.
          user: User resolved by ``get_verified_user``; not otherwise used.

      Returns:
          The record returned by ``Models.get_model_by_id``.

      Raises:
          HTTPException: 404 with ``NOT_FOUND`` if the lookup returns a falsy
              result.
      """
      model = Models.get_model_by_id(id)
      if not model:
          # A missing record is "not found", not an authentication failure
          # (the original responded 401 UNAUTHORIZED here).
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )
      return model
  ############################
  # UpdateModelById
  ############################


  @router.post("/update", response_model=Optional[ModelModel])
  async def update_model_by_id(
      request: Request,
      id: str,
      form_data: ModelForm,
      user=Depends(get_admin_user),
  ):
      """Update a stored model record, or create one for a model the app knows.

      If ``Models.get_model_by_id(id)`` finds a record, it is updated from
      ``form_data``. Otherwise, if ``form_data.id`` is present in
      ``request.app.state.MODELS``, a new record is inserted for it, owned by
      ``user``.

      Args:
          request: Incoming request; used to read ``request.app.state.MODELS``.
          id: Identifier of the record to update (query parameter).
          form_data: New model definition.
          user: Admin user resolved by ``get_admin_user``.

      Returns:
          The updated or newly inserted record.

      Raises:
          HTTPException: 404 with ``NOT_FOUND`` if no record exists for ``id``
              and ``form_data.id`` is not in ``request.app.state.MODELS``;
              400 with the default error message if the update or insert
              returns a falsy result.
      """
      # NOTE(review): the lookup uses ``id`` while the insert branch checks and
      # writes ``form_data.id``; if the two differ, a record is created under
      # ``form_data.id``. Confirm whether callers always send matching ids.
      if Models.get_model_by_id(id):
          model = Models.update_model_by_id(id, form_data)
      elif form_data.id in request.app.state.MODELS:
          model = Models.insert_new_model(form_data, user.id)
      else:
          # Nothing to update and nothing to base a new record on. This is
          # "not found", not an authentication failure (was 401 UNAUTHORIZED).
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      # Treat a failed update like a failed insert. Previously a falsy update
      # result was returned as a 200 response with a null body.
      if not model:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(),
          )
      return model
  ############################
  # DeleteModelById
  ############################


  @router.delete("/delete", response_model=bool)
  async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
      """Delete the model record identified by the ``id`` query parameter.

      Restricted to users accepted by ``get_admin_user``. The response is
      whatever ``Models.delete_model_by_id`` reports, declared as ``bool``.
      """
      return Models.delete_model_by_id(id)