main.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. from fastapi import FastAPI, Request, Response, HTTPException, Depends
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.concurrency import run_in_threadpool
  5. import requests
  6. import json
  7. import uuid
  8. from pydantic import BaseModel
  9. from apps.web.models.users import Users
  10. from constants import ERROR_MESSAGES
  11. from utils.utils import decode_token, get_current_user
  12. from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
  13. app = FastAPI()
  14. app.add_middleware(
  15. CORSMiddleware,
  16. allow_origins=["*"],
  17. allow_credentials=True,
  18. allow_methods=["*"],
  19. allow_headers=["*"],
  20. )
  21. app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
  22. # TARGET_SERVER_URL = OLLAMA_API_BASE_URL
  23. REQUEST_POOL = []
  24. @app.get("/url")
  25. async def get_ollama_api_url(user=Depends(get_current_user)):
  26. if user and user.role == "admin":
  27. return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
  28. else:
  29. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  30. class UrlUpdateForm(BaseModel):
  31. url: str
  32. @app.post("/url/update")
  33. async def update_ollama_api_url(
  34. form_data: UrlUpdateForm, user=Depends(get_current_user)
  35. ):
  36. if user and user.role == "admin":
  37. app.state.OLLAMA_API_BASE_URL = form_data.url
  38. return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
  39. else:
  40. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  41. @app.get("/cancel/{request_id}")
  42. async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
  43. if user:
  44. if request_id in REQUEST_POOL:
  45. REQUEST_POOL.remove(request_id)
  46. return True
  47. else:
  48. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  49. @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
  50. async def proxy(path: str, request: Request, user=Depends(get_current_user)):
  51. target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
  52. body = await request.body()
  53. headers = dict(request.headers)
  54. if user.role in ["user", "admin"]:
  55. if path in ["pull", "delete", "push", "copy", "create"]:
  56. if user.role != "admin":
  57. raise HTTPException(
  58. status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
  59. )
  60. else:
  61. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  62. headers.pop("host", None)
  63. headers.pop("authorization", None)
  64. headers.pop("origin", None)
  65. headers.pop("referer", None)
  66. r = None
  67. def get_request():
  68. nonlocal r
  69. request_id = str(uuid.uuid4())
  70. try:
  71. REQUEST_POOL.append(request_id)
  72. def stream_content():
  73. try:
  74. if path in ["chat"]:
  75. yield json.dumps({"id": request_id, "done": False}) + "\n"
  76. for chunk in r.iter_content(chunk_size=8192):
  77. if request_id in REQUEST_POOL:
  78. yield chunk
  79. else:
  80. print("User: canceled request")
  81. break
  82. finally:
  83. if hasattr(r, "close"):
  84. r.close()
  85. REQUEST_POOL.remove(request_id)
  86. r = requests.request(
  87. method=request.method,
  88. url=target_url,
  89. data=body,
  90. headers=headers,
  91. stream=True,
  92. )
  93. r.raise_for_status()
  94. # r.close()
  95. return StreamingResponse(
  96. stream_content(),
  97. status_code=r.status_code,
  98. headers=dict(r.headers),
  99. )
  100. except Exception as e:
  101. raise e
  102. try:
  103. return await run_in_threadpool(get_request)
  104. except Exception as e:
  105. error_detail = "Ollama WebUI: Server Connection Error"
  106. if r is not None:
  107. try:
  108. res = r.json()
  109. if "error" in res:
  110. error_detail = f"Ollama: {res['error']}"
  111. except:
  112. error_detail = f"Ollama: {e}"
  113. raise HTTPException(
  114. status_code=r.status_code if r else 500,
  115. detail=error_detail,
  116. )