Browse Source

feat: restart subprocess route

Timothy J. Baek 1 year ago
parent
commit
51191168bc
1 changed files with 27 additions and 38 deletions
  1. 27 38
      backend/apps/litellm/main.py

+ 27 - 38
backend/apps/litellm/main.py

@@ -11,7 +11,7 @@ from starlette.responses import StreamingResponse
 import json
 import requests
 
-from utils.utils import get_verified_user, get_current_user
+from utils.utils import get_verified_user, get_current_user, get_admin_user
 from config import SRC_LOG_LEVELS, ENV
 from constants import ERROR_MESSAGES
 
@@ -112,6 +112,32 @@ async def get_status():
     return {"status": True}
 
 
+@app.get("/restart")
+async def restart_litellm(user=Depends(get_admin_user)):
+    """
+    Endpoint to restart the litellm background service.
+    """
+    log.info("Requested restart of litellm service.")
+    try:
+        # Shut down the existing process if it is running
+        await shutdown_litellm_background()
+        log.info("litellm service shutdown complete.")
+
+        # Restart the background service
+        await start_litellm_background()
+        log.info("litellm service restart complete.")
+
+        return {
+            "status": "success",
+            "message": "litellm service restarted successfully.",
+        }
+    except Exception as e:
+        log.error(f"Error restarting litellm service: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
+        )
+
+
 @app.get("/models")
 @app.get("/v1/models")
 async def get_models(user=Depends(get_current_user)):
@@ -199,40 +225,3 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         raise HTTPException(
             status_code=r.status_code if r else 500, detail=error_detail
         )
-
-
-# class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
-#     async def dispatch(
-#         self, request: Request, call_next: RequestResponseEndpoint
-#     ) -> Response:
-
-#         response = await call_next(request)
-#         user = request.state.user
-
-#         if "/models" in request.url.path:
-#             if isinstance(response, StreamingResponse):
-#                 # Read the content of the streaming response
-#                 body = b""
-#                 async for chunk in response.body_iterator:
-#                     body += chunk
-
-#                 data = json.loads(body.decode("utf-8"))
-
-#                 if app.state.MODEL_FILTER_ENABLED:
-#                     if user and user.role == "user":
-#                         data["data"] = list(
-#                             filter(
-#                                 lambda model: model["id"]
-#                                 in app.state.MODEL_FILTER_LIST,
-#                                 data["data"],
-#                             )
-#                         )
-
-#                 # Modified Flag
-#                 data["modified"] = True
-#                 return JSONResponse(content=data)
-
-#         return response
-
-
-# app.add_middleware(ModifyModelsResponseMiddleware)