main.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.concurrency import run_in_threadpool
  5. import requests
  6. import json
  7. import uuid
  8. from pydantic import BaseModel
  9. from apps.web.models.users import Users
  10. from constants import ERROR_MESSAGES
  11. from utils.utils import decode_token, get_current_user, get_admin_user
  12. from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
  app = FastAPI()

  # Cross-origin access for browser clients: any origin, method and header,
  # with credentials allowed.
  # NOTE(review): a wildcard origin list together with allow_credentials=True is
  # very permissive -- confirm this is intended for this sub-app.
  app.add_middleware(
      CORSMiddleware,
      allow_credentials=True,
      allow_origins=["*"],
      allow_headers=["*"],
      allow_methods=["*"],
  )

  # Base URL every proxied request is sent to; admins can change it at runtime
  # through POST /url/update.
  app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL

  # Ids of in-flight proxied requests. A streaming proxy response keeps sending
  # chunks only while its id is still present, so removing an id (GET
  # /cancel/{request_id}) stops that stream.
  REQUEST_POOL = []
  @app.get("/url")
  async def get_ollama_api_url(user=Depends(get_admin_user)):
      """Report the Ollama base URL that requests are currently proxied to.

      Access is gated by the ``get_admin_user`` dependency.
      """
      current_url = app.state.OLLAMA_API_BASE_URL
      return {"OLLAMA_API_BASE_URL": current_url}
  class UrlUpdateForm(BaseModel):
      """Request body for ``POST /url/update``: the new Ollama base URL."""

      url: str
  @app.post("/url/update")
  async def update_ollama_api_url(
      form_data: UrlUpdateForm, user=Depends(get_admin_user)
  ):
      """Point the proxy at a new Ollama base URL and echo the stored value.

      The URL is kept in ``app.state`` and is used by every request proxied after
      this call. Access is gated by the ``get_admin_user`` dependency.
      """
      new_url = form_data.url
      app.state.OLLAMA_API_BASE_URL = new_url
      return {"OLLAMA_API_BASE_URL": new_url}
  @app.get("/cancel/{request_id}")
  async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
      """Ask an in-flight proxied stream to stop.

      Removes ``request_id`` from ``REQUEST_POOL``. The proxy's streaming
      generator checks the pool between chunks and stops once its id is gone.
      Unknown or already-finished ids are ignored.

      Returns:
          True whenever the caller is authenticated, whether or not the id was
          still in flight.

      Raises:
          HTTPException: 401 when the dependency yields no user.
      """
      if not user:
          raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

      # EAFP instead of "if id in pool: pool.remove(id)". The streaming
      # generator (iterated by Starlette, possibly in a worker thread) also
      # removes ids when it finishes. It can do so between a membership test and
      # the removal, which would make ``remove`` raise ValueError here.
      try:
          REQUEST_POOL.remove(request_id)
      except ValueError:
          pass  # already finished or cancelled; nothing left to do
      return True
  # Ollama endpoints that only users with the "admin" role may call.
  _ADMIN_ONLY_PATHS = frozenset({"pull", "delete", "push", "copy", "create"})

  # Upstream response headers that are not copied onto our StreamingResponse.
  # The body we stream can differ from Ollama's: requests' iter_content decodes
  # any Content-Encoding, and "chat" gets an extra id line prepended. The
  # upstream length/encoding/framing headers would therefore describe the
  # wrong body.
  _DROPPED_RESPONSE_HEADERS = frozenset(
      {"content-length", "content-encoding", "transfer-encoding", "connection"}
  )


  def _sanitize_proxy_path(path):
      """Canonicalize the client-supplied *path* before it is checked and forwarded.

      - "." / ".." segments and repeated, leading or trailing slashes are
        collapsed, so "pull/", "./pull" and "x/../pull" all become "pull", and
        ".." can never climb above the configured base URL.
      - Characters that would end the URL path (such as "?" and "#") are
        percent-encoded, so "pull?x" cannot reach "pull" with a query attached.

      Both the permission check and the upstream URL use the result, so what is
      checked is exactly what is forwarded.
      """
      import posixpath
      from urllib.parse import quote

      collapsed = posixpath.normpath("/" + path).strip("/")
      # Keep "/" plus the RFC 3986 path characters so ordinary paths are unchanged.
      return quote(collapsed, safe="/:@!$&'()*+,;=")


  def _forwardable_response_headers(upstream_headers):
      """Copy upstream response headers, minus ``_DROPPED_RESPONSE_HEADERS``."""
      return {
          name: value
          for name, value in upstream_headers.items()
          if name.lower() not in _DROPPED_RESPONSE_HEADERS
      }


  def _discard_request_id(request_id):
      """Remove *request_id* from ``REQUEST_POOL`` if it is still present.

      Both the cancel endpoint and the stream's own cleanup remove ids, so
      whichever runs second must be a no-op rather than a ValueError.
      """
      try:
          REQUEST_POOL.remove(request_id)
      except ValueError:
          pass


  @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
  async def proxy(path: str, request: Request, user=Depends(get_current_user)):
      """Forward a request to the configured Ollama API and stream the reply back.

      Only users whose role is "user" or "admin" may proxy. The endpoints in
      ``_ADMIN_ONLY_PATHS`` additionally require "admin". For the "chat"
      endpoint, a ``{"id": <request id>, "done": false}`` line is streamed first
      so the client can later cancel the stream via ``/cancel/{request_id}``.

      Raises:
          HTTPException: 401 when the user may not call this endpoint. Otherwise,
              if the upstream call fails: the upstream status code (500 when no
              response was received), with Ollama's ``error`` message as the
              detail when its body provides one.
      """
      # Canonicalize before the permission check. Otherwise variants such as
      # "pull/" or "pull%3Fx" slip past the exact-name comparison while the
      # upstream server may still route them to the admin-only endpoint.
      path = _sanitize_proxy_path(path)

      if user.role not in ("user", "admin"):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
          )
      if path in _ADMIN_ONLY_PATHS and user.role != "admin":
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
          )

      target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
      body = await request.body()

      # Forward the client's headers except those that must not reach Ollama.
      headers = dict(request.headers)
      for name in ("host", "authorization", "origin", "referer"):
          headers.pop(name, None)

      # Set inside get_request (worker thread); read by the error handler below.
      r = None

      def get_request():
          """Issue the blocking upstream request and wrap it in a StreamingResponse."""
          nonlocal r
          request_id = str(uuid.uuid4())

          def stream_content():
              try:
                  if path == "chat":
                      # Tell the client which id to pass to /cancel/{request_id}.
                      yield json.dumps({"id": request_id, "done": False}) + "\n"
                  for chunk in r.iter_content(chunk_size=8192):
                      if request_id not in REQUEST_POOL:
                          print("User: canceled request")
                          break
                      yield chunk
              finally:
                  if hasattr(r, "close"):
                      r.close()
                  # The cancel endpoint may already have removed the id.
                  _discard_request_id(request_id)

          r = requests.request(
              method=request.method,
              url=target_url,
              data=body,
              headers=headers,
              stream=True,
          )
          r.raise_for_status()

          response = StreamingResponse(
              stream_content(),
              status_code=r.status_code,
              headers=_forwardable_response_headers(r.headers),
          )
          # Register only once everything above succeeded. On failure no stream
          # would ever run stream_content's cleanup, so the id would leak.
          REQUEST_POOL.append(request_id)
          return response

      try:
          return await run_in_threadpool(get_request)
      except Exception as e:
          error_detail = "Open WebUI: Server Connection Error"
          if r is not None:
              try:
                  res = r.json()
                  if "error" in res:
                      error_detail = f"Ollama: {res['error']}"
              except Exception:
                  error_detail = f"Ollama: {e}"
              finally:
                  r.close()
          raise HTTPException(
              # Compare against None explicitly: a requests.Response is falsy for
              # 4xx/5xx statuses, so "if r" would turn every upstream error into 500.
              status_code=r.status_code if r is not None else 500,
              detail=error_detail,
          ) from e