main.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. from fastapi import FastAPI, Request, Response, HTTPException, Depends
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.concurrency import run_in_threadpool
  5. import requests
  6. import json
  7. from pydantic import BaseModel
  8. from apps.web.models.users import Users
  9. from constants import ERROR_MESSAGES
  10. from utils.utils import decode_token, get_current_user
  11. from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
  12. app = FastAPI()
  13. app.add_middleware(
  14. CORSMiddleware,
  15. allow_origins=["*"],
  16. allow_credentials=True,
  17. allow_methods=["*"],
  18. allow_headers=["*"],
  19. )
  20. app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
  21. # TARGET_SERVER_URL = OLLAMA_API_BASE_URL
  22. @app.get("/url")
  23. async def get_ollama_api_url(user=Depends(get_current_user)):
  24. if user and user.role == "admin":
  25. return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
  26. else:
  27. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  28. class UrlUpdateForm(BaseModel):
  29. url: str
  30. @app.post("/url/update")
  31. async def update_ollama_api_url(
  32. form_data: UrlUpdateForm, user=Depends(get_current_user)
  33. ):
  34. if user and user.role == "admin":
  35. app.state.OLLAMA_API_BASE_URL = form_data.url
  36. return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
  37. else:
  38. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  39. @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
  40. async def proxy(path: str, request: Request, user=Depends(get_current_user)):
  41. target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
  42. body = await request.body()
  43. headers = dict(request.headers)
  44. if user.role in ["user", "admin"]:
  45. if path in ["pull", "delete", "push", "copy", "create"]:
  46. if user.role != "admin":
  47. raise HTTPException(
  48. status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
  49. )
  50. else:
  51. raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
  52. headers.pop("Host", None)
  53. headers.pop("Authorization", None)
  54. headers.pop("Origin", None)
  55. headers.pop("Referer", None)
  56. r = None
  57. def get_request():
  58. nonlocal r
  59. try:
  60. r = requests.request(
  61. method=request.method,
  62. url=target_url,
  63. data=body,
  64. headers=headers,
  65. stream=True,
  66. )
  67. r.raise_for_status()
  68. return StreamingResponse(
  69. r.iter_content(chunk_size=8192),
  70. status_code=r.status_code,
  71. headers=dict(r.headers),
  72. )
  73. except Exception as e:
  74. raise e
  75. try:
  76. return await run_in_threadpool(get_request)
  77. except Exception as e:
  78. error_detail = "Ollama WebUI: Server Connection Error"
  79. if r is not None:
  80. try:
  81. res = r.json()
  82. if "error" in res:
  83. error_detail = f"Ollama: {res['error']}"
  84. except:
  85. error_detail = f"Ollama: {e}"
  86. raise HTTPException(
  87. status_code=r.status_code if r else 500,
  88. detail=error_detail,
  89. )