main.py

from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse

import requests
import json
import hashlib
from pathlib import Path

from pydantic import BaseModel

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = OPENAI_API_KEY
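
# The base URL and API key are kept on app.state so the admin endpoints below
# can change them at runtime without restarting the server.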

class UrlUpdateForm(BaseModel):
    url: str


class KeyUpdateForm(BaseModel):
    key: str


@app.get("/url")
async def get_openai_url(user=Depends(get_admin_user)):
    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}


@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
    app.state.OPENAI_API_BASE_URL = form_data.url
    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}


@app.get("/key")
async def get_openai_key(user=Depends(get_admin_user)):
    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}


@app.post("/key/update")
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)):
    app.state.OPENAI_API_KEY = form_data.key
    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
  46. @app.post("/audio/speech")
  47. async def speech(request: Request, user=Depends(get_verified_user)):
  48. target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
  49. if app.state.OPENAI_API_KEY == "":
  50. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  51. body = await request.body()
  52. name = hashlib.sha256(body).hexdigest()
  53. SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
  54. SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  55. file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
  56. file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
  57. # Check if the file already exists in the cache
  58. if file_path.is_file():
  59. return FileResponse(file_path)
  60. headers = {}
  61. headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
  62. headers["Content-Type"] = "application/json"
  63. try:
  64. print("openai")
  65. r = requests.post(
  66. url=target_url,
  67. data=body,
  68. headers=headers,
  69. stream=True,
  70. )
  71. r.raise_for_status()
  72. # Save the streaming content to a file
  73. with open(file_path, "wb") as f:
  74. for chunk in r.iter_content(chunk_size=8192):
  75. f.write(chunk)
  76. with open(file_body_path, "w") as f:
  77. json.dump(json.loads(body.decode("utf-8")), f)
  78. # Return the saved file
  79. return FileResponse(file_path)
  80. except Exception as e:
  81. print(e)
  82. error_detail = "Open WebUI: Server Connection Error"
  83. if r is not None:
  84. try:
  85. res = r.json()
  86. if "error" in res:
  87. error_detail = f"External: {res['error']}"
  88. except:
  89. error_detail = f"External: {e}"
  90. raise HTTPException(status_code=r.status_code, detail=error_detail)
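
# Catch-all proxy: forwards any other request to the configured OpenAI-compatible
# backend, streaming SSE responses through unchanged and returning JSON otherwise.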
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
    print(target_url)

    if app.state.OPENAI_API_KEY == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    body = await request.body()

    # TODO: Remove below after gpt-4-vision fix from OpenAI
    # Try to decode the request body from bytes to a UTF-8 string
    # (gpt-4-vision-preview needs max_tokens set explicitly)
    try:
        body = body.decode("utf-8")
        body = json.loads(body)

        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
        # This is a workaround until OpenAI fixes the issue with this model
        if body.get("model") == "gpt-4-vision-preview":
            if "max_tokens" not in body:
                body["max_tokens"] = 4000
            print("Modified body_dict:", body)

        # Convert the modified body back to JSON
        body = json.dumps(body)
    except json.JSONDecodeError as e:
        print("Error loading request body into a dictionary:", e)

    headers = {}
    headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
    headers["Content-Type"] = "application/json"

    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            # For non-SSE responses, read the body and return it as JSON
            response_data = r.json()

            # When talking to the official OpenAI API, only expose GPT models
            if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
                response_data["data"] = list(
                    filter(lambda model: "gpt" in model["id"], response_data["data"])
                )

            return response_data
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )
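
# Note: this module exposes its own FastAPI instance and is meant to be mounted
# by a parent application. A minimal sketch (the import path and mount path are
# assumptions, not taken from this file):
#
#     from apps.openai.main import app as openai_app
#     parent_app.mount("/openai/api", openai_app)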