main.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from fastapi import FastAPI, Request, Response, HTTPException, Depends
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
  4. import requests
  5. import json
  6. from pydantic import BaseModel
  7. from apps.web.models.users import Users
  8. from constants import ERROR_MESSAGES
  9. from utils.utils import decode_token, get_current_user
  10. from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
  11. import hashlib
  12. from pathlib import Path
  13. app = FastAPI()
  14. app.add_middleware(
  15. CORSMiddleware,
  16. allow_origins=["*"],
  17. allow_credentials=True,
  18. allow_methods=["*"],
  19. allow_headers=["*"],
  20. )
  21. app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
  22. app.state.OPENAI_API_KEY = OPENAI_API_KEY
  23. class UrlUpdateForm(BaseModel):
  24. url: str
  25. class KeyUpdateForm(BaseModel):
  26. key: str
  27. @app.get("/url")
  28. async def get_openai_url(user=Depends(get_current_user)):
  29. if user and user.role == "admin":
  30. return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
  31. else:
  32. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  33. @app.post("/url/update")
  34. async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)):
  35. if user and user.role == "admin":
  36. app.state.OPENAI_API_BASE_URL = form_data.url
  37. return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
  38. else:
  39. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  40. @app.get("/key")
  41. async def get_openai_key(user=Depends(get_current_user)):
  42. if user and user.role == "admin":
  43. return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
  44. else:
  45. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  46. @app.post("/key/update")
  47. async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)):
  48. if user and user.role == "admin":
  49. app.state.OPENAI_API_KEY = form_data.key
  50. return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
  51. else:
  52. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  53. @app.post("/audio/speech")
  54. async def speech(request: Request, user=Depends(get_current_user)):
  55. target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
  56. if user.role not in ["user", "admin"]:
  57. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  58. if app.state.OPENAI_API_KEY == "":
  59. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  60. body = await request.body()
  61. print(body)
  62. print(type(body))
  63. name = hashlib.sha256(body).hexdigest()
  64. SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
  65. SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  66. file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
  67. file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
  68. print(file_path)
  69. # Check if the file already exists in the cache
  70. if file_path.is_file():
  71. return FileResponse(file_path)
  72. headers = {}
  73. headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
  74. headers["Content-Type"] = "application/json"
  75. try:
  76. print("openai")
  77. r = requests.post(
  78. url=target_url,
  79. data=body,
  80. headers=headers,
  81. stream=True,
  82. )
  83. print(r)
  84. r.raise_for_status()
  85. # Save the streaming content to a file
  86. with open(file_path, "wb") as f:
  87. for chunk in r.iter_content(chunk_size=8192):
  88. f.write(chunk)
  89. with open(file_body_path, "w") as f:
  90. json.dump(json.loads(body.decode("utf-8")), f)
  91. # Return the saved file
  92. return FileResponse(file_path)
  93. # return StreamingResponse(
  94. # r.iter_content(chunk_size=8192),
  95. # status_code=r.status_code,
  96. # headers=dict(r.headers),
  97. # )
  98. except Exception as e:
  99. print(e)
  100. error_detail = "Ollama WebUI: Server Connection Error"
  101. if r is not None:
  102. try:
  103. res = r.json()
  104. if "error" in res:
  105. error_detail = f"External: {res['error']}"
  106. except:
  107. error_detail = f"External: {e}"
  108. raise HTTPException(status_code=r.status_code, detail=error_detail)
  109. @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
  110. async def proxy(path: str, request: Request, user=Depends(get_current_user)):
  111. target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
  112. print(target_url, app.state.OPENAI_API_KEY)
  113. if user.role not in ["user", "admin"]:
  114. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  115. if app.state.OPENAI_API_KEY == "":
  116. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  117. body = await request.body()
  118. # TODO: Remove below after gpt-4-vision fix from Open AI
  119. # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
  120. try:
  121. body = body.decode("utf-8")
  122. body = json.loads(body)
  123. # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
  124. # This is a workaround until OpenAI fixes the issue with this model
  125. if body.get("model") == "gpt-4-vision-preview":
  126. if "max_tokens" not in body:
  127. body["max_tokens"] = 4000
  128. print("Modified body_dict:", body)
  129. # Convert the modified body back to JSON
  130. body = json.dumps(body)
  131. except json.JSONDecodeError as e:
  132. print("Error loading request body into a dictionary:", e)
  133. headers = {}
  134. headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
  135. headers["Content-Type"] = "application/json"
  136. try:
  137. r = requests.request(
  138. method=request.method,
  139. url=target_url,
  140. data=body,
  141. headers=headers,
  142. stream=True,
  143. )
  144. r.raise_for_status()
  145. # Check if response is SSE
  146. if "text/event-stream" in r.headers.get("Content-Type", ""):
  147. return StreamingResponse(
  148. r.iter_content(chunk_size=8192),
  149. status_code=r.status_code,
  150. headers=dict(r.headers),
  151. )
  152. else:
  153. # For non-SSE, read the response and return it
  154. # response_data = (
  155. # r.json()
  156. # if r.headers.get("Content-Type", "")
  157. # == "application/json"
  158. # else r.text
  159. # )
  160. response_data = r.json()
  161. if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
  162. response_data["data"] = list(
  163. filter(lambda model: "gpt" in model["id"], response_data["data"])
  164. )
  165. return response_data
  166. except Exception as e:
  167. print(e)
  168. error_detail = "Ollama WebUI: Server Connection Error"
  169. if r is not None:
  170. try:
  171. res = r.json()
  172. if "error" in res:
  173. error_detail = f"External: {res['error']}"
  174. except:
  175. error_detail = f"External: {e}"
  176. raise HTTPException(status_code=r.status_code, detail=error_detail)