# main.py

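# FastAPI proxy in front of a locally running LiteLLM server: on startup the
# app launches `litellm` as a background subprocess, then forwards incoming
# requests to http://localhost:4000/v1, filtering the model list for
# non-admin users when MODEL_FILTER_ENABLED is set.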
from fastapi import FastAPI, Request, Response, Depends, HTTPException, status
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse

import asyncio
import json
import logging
import subprocess

import requests

from utils.utils import get_verified_user, get_current_user
from config import (
    SRC_LOG_LEVELS,
    ENV,
    MODEL_FILTER_ENABLED,
    MODEL_FILTER_LIST,
)
from constants import ERROR_MESSAGES

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

async def run_background_process(command):
    # Start the process
    process = await asyncio.create_subprocess_exec(
        *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )

    # Read output asynchronously
    async for line in process.stdout:
        print(line.decode().strip())  # Print stdout line by line

    await process.wait()  # Wait for the subprocess to finish


async def start_litellm_background():
    print("start_litellm_background")
    # Command to run in the background
    command = "litellm --telemetry False --config ./data/litellm/config.yaml"

    await run_background_process(command)

@app.on_event("startup")
async def startup_event():
    print("startup_event")
    # TODO: Check config.yaml file and create one
    asyncio.create_task(start_litellm_background())


app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

@app.get("/")
async def get_status():
    return {"status": True}

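# Model listing: fetch the model list from the local LiteLLM instance and,
# when MODEL_FILTER_ENABLED is set, restrict non-admin users to the models
# in MODEL_FILTER_LIST.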
@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
    url = "http://localhost:4000/v1"

    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models")
        r.raise_for_status()

        data = r.json()

        if app.state.MODEL_FILTER_ENABLED:
            if user and user.role == "user":
                data["data"] = list(
                    filter(
                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
                        data["data"],
                    )
                )

        return data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            # A failed requests.Response is falsy, so compare against None to
            # preserve the upstream status code instead of always returning 500.
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )

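# Catch-all proxy: forward every other request (GET/POST/PUT/DELETE) to the
# local LiteLLM server, streaming SSE responses back to the client as-is.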
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    body = await request.body()

    url = "http://localhost:4000/v1"
    target_url = f"{url}/{path}"

    headers = {}
    # headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            # Same falsy-Response caveat as above: compare against None.
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )

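# Alternative approach (currently disabled): filter the /models response in a
# response middleware instead of inside the route handler.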
# class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
#     async def dispatch(
#         self, request: Request, call_next: RequestResponseEndpoint
#     ) -> Response:
#         response = await call_next(request)
#         user = request.state.user
#
#         if "/models" in request.url.path:
#             if isinstance(response, StreamingResponse):
#                 # Read the content of the streaming response
#                 body = b""
#                 async for chunk in response.body_iterator:
#                     body += chunk
#
#                 data = json.loads(body.decode("utf-8"))
#
#                 if app.state.MODEL_FILTER_ENABLED:
#                     if user and user.role == "user":
#                         data["data"] = list(
#                             filter(
#                                 lambda model: model["id"]
#                                 in app.state.MODEL_FILTER_LIST,
#                                 data["data"],
#                             )
#                         )
#
#                 # Modified Flag
#                 data["modified"] = True
#                 return JSONResponse(content=data)
#
#         return response
#
#
# app.add_middleware(ModifyModelsResponseMiddleware)