models.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from fastapi import Depends, FastAPI, HTTPException, status, Request
  2. from datetime import datetime, timedelta
  3. from typing import List, Union, Optional
  4. from fastapi import APIRouter
  5. from pydantic import BaseModel
  6. import json
  7. from apps.webui.internal.db import get_db
  8. from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
  9. from utils.utils import get_verified_user, get_admin_user
  10. from constants import ERROR_MESSAGES
  11. router = APIRouter()
  12. ###########################
  13. # getModels
  14. ###########################
  15. @router.get("/", response_model=List[ModelResponse])
  16. async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
  17. return Models.get_all_models(db)
  18. ############################
  19. # AddNewModel
  20. ############################
  21. @router.post("/add", response_model=Optional[ModelModel])
  22. async def add_new_model(
  23. request: Request,
  24. form_data: ModelForm,
  25. user=Depends(get_admin_user),
  26. db=Depends(get_db),
  27. ):
  28. if form_data.id in request.app.state.MODELS:
  29. raise HTTPException(
  30. status_code=status.HTTP_401_UNAUTHORIZED,
  31. detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
  32. )
  33. else:
  34. model = Models.insert_new_model(db, form_data, user.id)
  35. if model:
  36. return model
  37. else:
  38. raise HTTPException(
  39. status_code=status.HTTP_401_UNAUTHORIZED,
  40. detail=ERROR_MESSAGES.DEFAULT(),
  41. )
  42. ############################
  43. # GetModelById
  44. ############################
  45. @router.get("/{id}", response_model=Optional[ModelModel])
  46. async def get_model_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
  47. model = Models.get_model_by_id(db, id)
  48. if model:
  49. return model
  50. else:
  51. raise HTTPException(
  52. status_code=status.HTTP_401_UNAUTHORIZED,
  53. detail=ERROR_MESSAGES.NOT_FOUND,
  54. )
  55. ############################
  56. # UpdateModelById
  57. ############################
  58. @router.post("/update", response_model=Optional[ModelModel])
  59. async def update_model_by_id(
  60. request: Request,
  61. id: str,
  62. form_data: ModelForm,
  63. user=Depends(get_admin_user),
  64. db=Depends(get_db),
  65. ):
  66. model = Models.get_model_by_id(db, id)
  67. if model:
  68. model = Models.update_model_by_id(db, id, form_data)
  69. return model
  70. else:
  71. if form_data.id in request.app.state.MODELS:
  72. model = Models.insert_new_model(db, form_data, user.id)
  73. if model:
  74. return model
  75. else:
  76. raise HTTPException(
  77. status_code=status.HTTP_401_UNAUTHORIZED,
  78. detail=ERROR_MESSAGES.DEFAULT(),
  79. )
  80. else:
  81. raise HTTPException(
  82. status_code=status.HTTP_401_UNAUTHORIZED,
  83. detail=ERROR_MESSAGES.DEFAULT(),
  84. )
  85. ############################
  86. # DeleteModelById
  87. ############################
  88. @router.delete("/delete", response_model=bool)
  89. async def delete_model_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
  90. result = Models.delete_model_by_id(db, id)
  91. return result