main.py

import asyncio
import json
import logging
import subprocess

import requests
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse

from config import ENV, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, SRC_LOG_LEVELS
from constants import ERROR_MESSAGES
from utils.utils import get_current_user, get_verified_user

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
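
# NOTE: allow_origins=["*"] with allow_credentials=True is fully permissive;
# tighten these settings for anything beyond local development.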

# Global variable to store the subprocess reference
background_process = None


async def run_background_process(command):
    global background_process
    print("run_background_process")
    try:
        # Log the command to be executed
        print(f"Executing command: {command}")
        # Execute the command and create a subprocess
        process = await asyncio.create_subprocess_exec(
            *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        background_process = process
        print("Subprocess started successfully.")

        async def read_stream(stream, label):
            # Drain a pipe line by line so its buffer never fills up and
            # blocks the subprocess.
            async for line in stream:
                print(f"{label}{line.decode().strip()}")

        # Read stdout and stderr concurrently; awaiting stderr.read() up front
        # would block until the long-running process exits.
        await asyncio.gather(
            read_stream(process.stdout, ""),
            read_stream(process.stderr, "Subprocess STDERR: "),
        )

        # Wait for the process to finish
        returncode = await process.wait()
        print(f"Subprocess exited with return code {returncode}")
    except Exception as e:
        log.error(f"Failed to start subprocess: {e}")
        raise  # Re-raise so the caller can observe the failure


async def start_litellm_background():
    print("start_litellm_background")
    # Command to run in the background
    command = "litellm --telemetry False --config ./data/litellm/config.yaml"
    await run_background_process(command)


async def shutdown_litellm_background():
    print("shutdown_litellm_background")
    global background_process
    if background_process:
        background_process.terminate()
        await background_process.wait()  # Ensure the process has terminated
        print("Subprocess terminated")
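
# No shutdown hook is registered on this sub-app; the embedding application is
# expected to call shutdown_litellm_background() itself when it stops.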


@app.on_event("startup")
async def startup_event():
    print("startup_event")
    # TODO: Check config.yaml file and create one
    asyncio.create_task(start_litellm_background())


app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
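
# The filter state above is consulted by get_models below to restrict which
# model IDs are visible to users with the "user" role.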


@app.get("/")
async def get_status():
    return {"status": True}


@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
    url = "http://localhost:4000/v1"
    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models")
        r.raise_for_status()

        data = r.json()

        if app.state.MODEL_FILTER_ENABLED:
            if user and user.role == "user":
                data["data"] = list(
                    filter(
                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
                        data["data"],
                    )
                )

        return data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            # A non-2xx requests.Response is falsy, so compare against None
            # explicitly to preserve the upstream status code.
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )
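
# Catch-all proxy: Starlette matches routes in registration order, so the
# specific routes above take precedence and everything else is forwarded
# verbatim to the local LiteLLM server.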


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    body = await request.body()

    url = "http://localhost:4000"
    target_url = f"{url}/{path}"

    headers = {}
    # headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status_code if r is not None else 500, detail=error_detail
        )
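
# Alternative approach, kept for reference: filter the /models response in a
# middleware instead of inside the route handler.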

# class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
#     async def dispatch(
#         self, request: Request, call_next: RequestResponseEndpoint
#     ) -> Response:
#         response = await call_next(request)
#         user = request.state.user
#
#         if "/models" in request.url.path:
#             if isinstance(response, StreamingResponse):
#                 # Read the content of the streaming response
#                 body = b""
#                 async for chunk in response.body_iterator:
#                     body += chunk
#
#                 data = json.loads(body.decode("utf-8"))
#
#                 if app.state.MODEL_FILTER_ENABLED:
#                     if user and user.role == "user":
#                         data["data"] = list(
#                             filter(
#                                 lambda model: model["id"]
#                                 in app.state.MODEL_FILTER_LIST,
#                                 data["data"],
#                             )
#                         )
#
#                 # Modified Flag
#                 data["modified"] = True
#                 return JSONResponse(content=data)
#
#         return response
#
#
# app.add_middleware(ModifyModelsResponseMiddleware)
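
# To run this sub-app standalone (assuming uvicorn is installed), something
# like the following should work:
#   uvicorn main:app --host 0.0.0.0 --port 8000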