old_main.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. from fastapi import FastAPI, Request, Response, HTTPException, Depends
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from fastapi.responses import StreamingResponse
  4. import requests
  5. import json
  6. from pydantic import BaseModel
  7. from apps.web.models.users import Users
  8. from constants import ERROR_MESSAGES
  9. from utils.utils import decode_token, get_current_user
  10. from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
  11. import aiohttp
  # Application object and process-wide proxy configuration.
app = FastAPI()

# Wide-open CORS policy: any origin, any method, any header, with credentials.
# NOTE(review): allow_credentials=True combined with a "*" origin is very
# permissive for an app that proxies authenticated traffic — confirm intended.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Seed the upstream URL from config; POST /url/update replaces it at runtime
# and the catch-all proxy reads it on every request.
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
  # Admin-only read of the current upstream Ollama URL.
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
    """Return the Ollama base URL the proxy currently forwards to.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` unless
            the caller is an admin.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
  # Request schema for the URL update endpoint.
class UrlUpdateForm(BaseModel):
    """JSON body accepted by ``POST /url/update``."""

    url: str  # the new Ollama API base URL
  # Admin-only endpoint that repoints the proxy at a different Ollama server.
@app.post("/url/update")
async def update_ollama_api_url(
    form_data: UrlUpdateForm, user=Depends(get_current_user)
):
    """Replace the Ollama base URL used by the proxy and return the stored value.

    Surrounding whitespace and trailing slashes are removed before storing.
    Upstream URLs are built as ``f"{base}/{path}"``, so a trailing slash would
    otherwise put ``//`` into every forwarded path.

    Args:
        form_data: request body carrying the new ``url``.
        user: the authenticated caller resolved by ``get_current_user``.

    Returns:
        ``{"OLLAMA_API_BASE_URL": <stored url>}``.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` for
            non-admin callers.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    app.state.OLLAMA_API_BASE_URL = form_data.url.strip().rstrip("/")
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
  39. # async def fetch_sse(method, target_url, body, headers):
  40. # async with aiohttp.ClientSession() as session:
  41. # try:
  42. # async with session.request(
  43. # method, target_url, data=body, headers=headers
  44. # ) as response:
  45. # print(response.status)
  46. # async for line in response.content:
  47. # yield line
  48. # except Exception as e:
  49. # print(e)
  50. # error_detail = "Open WebUI: Server Connection Error"
  51. # yield json.dumps({"error": error_detail, "message": str(e)}).encode()
  # Catch-all reverse proxy to the Ollama API.
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    """Forward an authenticated request to the Ollama API and stream its reply.

    The incoming method and body are replayed against
    ``{app.state.OLLAMA_API_BASE_URL}/{path}`` using a fresh
    ``aiohttp.ClientSession`` per request. The upstream body is streamed back
    with the upstream status code.

    Access rules:
      * only users whose role is ``"user"`` or ``"admin"`` may use the proxy;
      * ``pull``, ``delete``, ``push``, ``copy`` and ``create`` are admin-only.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` when a
            role check fails. Otherwise it carries the upstream status (or 500
            when no upstream response was received), with the detail taken
            from the upstream JSON ``"error"`` field when one is present.
    """
    if user.role not in ("user", "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    # Compare without surrounding slashes so variants such as "pull/" cannot
    # slip past the admin-only check and still reach the same upstream route.
    admin_only = ("pull", "delete", "push", "copy", "create")
    if path.strip("/") in admin_only and user.role != "admin":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
    body = await request.body()

    # ASGI delivers header names in lower case, so filter case-insensitively.
    # These headers describe the client's connection to (and credentials for)
    # this server and must not be forwarded upstream.
    dropped = {"host", "authorization", "origin", "referer"}
    headers = {
        name: value
        for name, value in dict(request.headers).items()
        if name.lower() not in dropped
    }

    session = aiohttp.ClientSession()
    response = None
    try:
        response = await session.request(
            request.method, target_url, data=body, headers=headers
        )
        if not response.ok:
            # aiohttp's raise_for_status() releases the connection, so read the
            # error body first (aiohttp caches it); the handler below parses it.
            await response.json()
        response.raise_for_status()
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if response is not None:
            try:
                res = await response.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"
        await session.close()
        raise HTTPException(
            status_code=response.status if response is not None else 500,
            detail=error_detail,
        ) from e

    async def generate():
        # Close the session however the stream ends (fully consumed, client
        # disconnect, or upstream error), not only after the final chunk.
        try:
            async for line in response.content:
                yield line
        finally:
            await session.close()

    return StreamingResponse(generate(), response.status)