# main.py (5.8 KB)

import json

import requests
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY
  11. app = FastAPI()
  12. app.add_middleware(
  13. CORSMiddleware,
  14. allow_origins=["*"],
  15. allow_credentials=True,
  16. allow_methods=["*"],
  17. allow_headers=["*"],
  18. )
  19. app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
  20. app.state.OPENAI_API_KEY = OPENAI_API_KEY
  21. class UrlUpdateForm(BaseModel):
  22. url: str
  23. class KeyUpdateForm(BaseModel):
  24. key: str
  25. @app.get("/url")
  26. async def get_openai_url(user=Depends(get_current_user)):
  27. if user and user.role == "admin":
  28. return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
  29. else:
  30. raise HTTPException(status_code=401,
  31. detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  32. @app.post("/url/update")
  33. async def update_openai_url(form_data: UrlUpdateForm,
  34. user=Depends(get_current_user)):
  35. if user and user.role == "admin":
  36. app.state.OPENAI_API_BASE_URL = form_data.url
  37. return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
  38. else:
  39. raise HTTPException(status_code=401,
  40. detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  41. @app.get("/key")
  42. async def get_openai_key(user=Depends(get_current_user)):
  43. if user and user.role == "admin":
  44. return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
  45. else:
  46. raise HTTPException(status_code=401,
  47. detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  48. @app.post("/key/update")
  49. async def update_openai_key(form_data: KeyUpdateForm,
  50. user=Depends(get_current_user)):
  51. if user and user.role == "admin":
  52. app.state.OPENAI_API_KEY = form_data.key
  53. return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
  54. else:
  55. raise HTTPException(status_code=401,
  56. detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  57. @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
  58. async def proxy(path: str, request: Request, user=Depends(get_current_user)):
  59. target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
  60. print(target_url, app.state.OPENAI_API_KEY)
  61. if user.role not in ["user", "admin"]:
  62. raise HTTPException(status_code=401,
  63. detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  64. if app.state.OPENAI_API_KEY == "":
  65. raise HTTPException(status_code=401,
  66. detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  67. # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
  68. try:
  69. body_str = (await request.body()).decode('utf-8')
  70. except UnicodeDecodeError as e:
  71. print("Error decoding request body:", e)
  72. raise HTTPException(status_code=400, detail="Invalid request body")
  73. # Check if the body is not empty
  74. if body_str:
  75. try:
  76. body_dict = json.loads(body_str)
  77. except json.JSONDecodeError as e:
  78. print("Error loading request body into a dictionary:", e)
  79. raise HTTPException(status_code=400, detail="Invalid JSON in request body")
  80. # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 10000
  81. # This is a workaround until OpenAI fixes the issue with this model
  82. if body_dict.get("model") == "gpt-4-vision-preview":
  83. body_dict["max_tokens"] = 10000
  84. print("Modified body_dict:", body_dict)
  85. # Try to convert the modified body back to JSON
  86. try:
  87. # Convert the modified body back to JSON
  88. body_json = json.dumps(body_dict)
  89. except TypeError as e:
  90. print("Error converting modified body to JSON:", e)
  91. raise HTTPException(status_code=500, detail="Internal server error")
  92. else:
  93. body_json = body_str # If the body is empty, use it as is
  94. headers = {}
  95. headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
  96. headers["Content-Type"] = "application/json"
  97. try:
  98. r = requests.request(
  99. method=request.method,
  100. url=target_url,
  101. data=body_json,
  102. headers=headers,
  103. stream=True,
  104. )
  105. r.raise_for_status()
  106. # Check if response is SSE
  107. if "text/event-stream" in r.headers.get("Content-Type", ""):
  108. return StreamingResponse(
  109. r.iter_content(chunk_size=8192),
  110. status_code=r.status_code,
  111. headers=dict(r.headers),
  112. )
  113. else:
  114. # For non-SSE, read the response and return it
  115. # response_data = (
  116. # r.json()
  117. # if r.headers.get("Content-Type", "")
  118. # == "application/json"
  119. # else r.text
  120. # )
  121. response_data = r.json()
  122. print(type(response_data))
  123. if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
  124. response_data["data"] = list(
  125. filter(lambda model: "gpt" in model["id"],
  126. response_data["data"]))
  127. return response_data
  128. except Exception as e:
  129. print(e)
  130. error_detail = "Ollama WebUI: Server Connection Error"
  131. if r is not None:
  132. try:
  133. res = r.json()
  134. if "error" in res:
  135. error_detail = f"External: {res['error']}"
  136. except:
  137. error_detail = f"External: {e}"
  138. raise HTTPException(status_code=r.status_code, detail=error_detail)