main.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.concurrency import run_in_threadpool
  5. import requests
  6. import json
  7. import uuid
  8. from pydantic import BaseModel
  9. from apps.web.models.users import Users
  10. from constants import ERROR_MESSAGES
  11. from utils.utils import decode_token, get_current_user, get_admin_user
  12. from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
  13. app = FastAPI()
  14. app.add_middleware(
  15. CORSMiddleware,
  16. allow_origins=["*"],
  17. allow_credentials=True,
  18. allow_methods=["*"],
  19. allow_headers=["*"],
  20. )
  21. app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
  22. # TARGET_SERVER_URL = OLLAMA_API_BASE_URL
  23. REQUEST_POOL = []
  24. @app.get("/url")
  25. async def get_ollama_api_url(user=Depends(get_admin_user)):
  26. return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
  27. class UrlUpdateForm(BaseModel):
  28. url: str
  29. @app.post("/url/update")
  30. async def update_ollama_api_url(
  31. form_data: UrlUpdateForm, user=Depends(get_admin_user)
  32. ):
  33. app.state.OLLAMA_API_BASE_URL = form_data.url
  34. return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
  35. @app.get("/cancel/{request_id}")
  36. async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
  37. if user:
  38. if request_id in REQUEST_POOL:
  39. REQUEST_POOL.remove(request_id)
  40. return True
  41. else:
  42. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  43. @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
  44. async def proxy(path: str, request: Request, user=Depends(get_current_user)):
  45. target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
  46. body = await request.body()
  47. headers = dict(request.headers)
  48. if user.role in ["user", "admin"]:
  49. if path in ["pull", "delete", "push", "copy", "create"]:
  50. if user.role != "admin":
  51. raise HTTPException(
  52. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
  53. )
  54. else:
  55. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  56. headers.pop("host", None)
  57. headers.pop("authorization", None)
  58. headers.pop("origin", None)
  59. headers.pop("referer", None)
  60. r = None
  61. def get_request():
  62. nonlocal r
  63. request_id = str(uuid.uuid4())
  64. try:
  65. REQUEST_POOL.append(request_id)
  66. def stream_content():
  67. try:
  68. if path in ["chat"]:
  69. yield json.dumps({"id": request_id, "done": False}) + "\n"
  70. for chunk in r.iter_content(chunk_size=8192):
  71. if request_id in REQUEST_POOL:
  72. yield chunk
  73. else:
  74. print("User: canceled request")
  75. break
  76. finally:
  77. if hasattr(r, "close"):
  78. r.close()
  79. REQUEST_POOL.remove(request_id)
  80. r = requests.request(
  81. method=request.method,
  82. url=target_url,
  83. data=body,
  84. headers=headers,
  85. stream=True,
  86. )
  87. r.raise_for_status()
  88. # r.close()
  89. return StreamingResponse(
  90. stream_content(),
  91. status_code=r.status_code,
  92. headers=dict(r.headers),
  93. )
  94. except Exception as e:
  95. raise e
  96. try:
  97. return await run_in_threadpool(get_request)
  98. except Exception as e:
  99. error_detail = "Ollama WebUI: Server Connection Error"
  100. if r is not None:
  101. try:
  102. res = r.json()
  103. if "error" in res:
  104. error_detail = f"Ollama: {res['error']}"
  105. except:
  106. error_detail = f"Ollama: {e}"
  107. raise HTTPException(
  108. status_code=r.status_code if r else 500,
  109. detail=error_detail,
  110. )