main.py

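"""FastAPI sub-application that launches a local LiteLLM proxy on startup and
forwards Open WebUI API requests to it at http://localhost:4000/v1."""
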
import asyncio
import json
import logging
import subprocess

import requests
from fastapi import FastAPI, Request, Response, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse

from utils.utils import get_verified_user, get_current_user
from config import SRC_LOG_LEVELS, ENV, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
from constants import ERROR_MESSAGES

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

async def run_background_process(command):
    process = await asyncio.create_subprocess_exec(
        *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    return process


async def start_litellm_background():
    # Command to run in the background
    command = "litellm --telemetry False --config ./data/litellm/config.yaml"
    await run_background_process(command)


@app.on_event("startup")
async def startup_event():
    # TODO: Check config.yaml file and create one
    asyncio.create_task(start_litellm_background())


app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

@app.get("/")
async def get_status():
    return {"status": True}

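# /models and /v1/models mirror the model list reported by the local LiteLLM
# server; when MODEL_FILTER_ENABLED is set, users with the "user" role only see
# models whose ids appear in MODEL_FILTER_LIST.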
@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
    url = "http://localhost:4000/v1"
    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models")
        r.raise_for_status()

        data = r.json()

        if app.state.MODEL_FILTER_ENABLED:
            if user and user.role == "user":
                data["data"] = list(
                    filter(
                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
                        data["data"],
                    )
                )

        return data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

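# Catch-all route: every other request is forwarded as-is to the local LiteLLM
# server; "text/event-stream" responses are streamed back to the client in
# 8 KB chunks, everything else is returned as parsed JSON.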
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    body = await request.body()

    url = "http://localhost:4000/v1"
    target_url = f"{url}/{path}"

    headers = {}
    # headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500, detail=error_detail
        )

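# The middleware below is an alternative, currently disabled approach that
# applied the same model filtering by rewriting the /models response body; it
# is kept commented out for reference.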
# class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
#     async def dispatch(
#         self, request: Request, call_next: RequestResponseEndpoint
#     ) -> Response:
#         response = await call_next(request)
#         user = request.state.user
#
#         if "/models" in request.url.path:
#             if isinstance(response, StreamingResponse):
#                 # Read the content of the streaming response
#                 body = b""
#                 async for chunk in response.body_iterator:
#                     body += chunk
#
#                 data = json.loads(body.decode("utf-8"))
#
#                 if app.state.MODEL_FILTER_ENABLED:
#                     if user and user.role == "user":
#                         data["data"] = list(
#                             filter(
#                                 lambda model: model["id"]
#                                 in app.state.MODEL_FILTER_LIST,
#                                 data["data"],
#                             )
#                         )
#
#                 # Modified Flag
#                 data["modified"] = True
#                 return JSONResponse(content=data)
#
#         return response
#
#
# app.add_middleware(ModifyModelsResponseMiddleware)
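

# Minimal sketch of how a parent FastAPI application could mount this sub-app;
# the import path and mount prefix below are assumptions for illustration, not
# taken from this file.
#
#   from fastapi import FastAPI
#   from apps.litellm.main import app as litellm_app
#
#   root_app = FastAPI()
#   root_app.mount("/litellm/api", litellm_app)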