123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- from fastapi import FastAPI, Request, Response, HTTPException, Depends
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import StreamingResponse
- import requests
- import json
- from pydantic import BaseModel
- from apps.web.models.users import Users
- from constants import ERROR_MESSAGES
- from utils.utils import decode_token, get_current_user
- from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
- import aiohttp
# Sub-application that exposes the Ollama-related endpoints defined below.
app = FastAPI()

# The backend base URL is kept on app.state rather than read from config on
# every request, so handlers can look it up (and replace it) at runtime.
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL

# Wide-open CORS policy: every origin, method and header is accepted, and
# credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
- # TARGET_SERVER_URL = OLLAMA_API_BASE_URL
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
    """Return the Ollama base URL currently stored on the app state.

    Only a user whose role is ``"admin"`` receives the value.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` when
            there is no user or the user is not an admin.
    """
    is_admin = bool(user) and user.role == "admin"
    if not is_admin:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
# Request body accepted by the POST /url/update endpoint.
# (Documented with comments rather than a class docstring so the generated
# schema description stays unchanged.)
class UrlUpdateForm(BaseModel):
    # Replacement value for the Ollama base URL.
    url: str
@app.post("/url/update")
async def update_ollama_api_url(
    form_data: UrlUpdateForm, user=Depends(get_current_user)
):
    """Replace the Ollama base URL stored on the app state.

    Args:
        form_data: Body carrying the new ``url``.
        user: Current user as resolved by ``get_current_user``.

    Returns:
        A dict holding the value now stored under ``OLLAMA_API_BASE_URL``.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` when
            there is no user or the user is not an admin.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    app.state.OLLAMA_API_BASE_URL = form_data.url
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
- # async def fetch_sse(method, target_url, body, headers):
- # async with aiohttp.ClientSession() as session:
- # try:
- # async with session.request(
- # method, target_url, data=body, headers=headers
- # ) as response:
- # print(response.status)
- # async for line in response.content:
- # yield line
- # except Exception as e:
- # print(e)
- # error_detail = "Open WebUI: Server Connection Error"
- # yield json.dumps({"error": error_detail, "message": str(e)}).encode()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    """Forward a request to the Ollama backend and stream its reply back.

    Args:
        path: Path appended to the configured Ollama base URL.
        request: Incoming request; its method, body and (filtered) headers
            are sent upstream.
        user: Current user as resolved by ``get_current_user``.

    Returns:
        A ``StreamingResponse`` that relays the upstream body and carries the
        upstream status code.

    Raises:
        HTTPException: 401 when the user is missing, has a role other than
            ``"user"``/``"admin"``, or is a non-admin hitting an admin-only
            path. On upstream failure, the upstream status (500 when no
            response was obtained) with a descriptive ``detail``.
    """
    # Authorize first so rejected callers cost no body read or upstream call.
    if not user or user.role not in ("user", "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    if path in ("pull", "delete", "push", "copy", "create") and user.role != "admin":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
    print(target_url)

    body = await request.body()

    # Header names arrive lowercased from the ASGI server, so the filter has
    # to be case-insensitive; popping "Host"/"Authorization" would silently
    # leave the caller's token and origin data in the upstream request.
    stripped = {"host", "authorization", "origin", "referer"}
    headers = {
        name: value
        for name, value in request.headers.items()
        if name.lower() not in stripped
    }

    session = aiohttp.ClientSession()
    response = None

    try:
        response = await session.request(
            request.method, target_url, data=body, headers=headers
        )
        print(response)

        if not response.ok:
            # Reading the JSON here caches the body, so the error handler
            # below can still extract the upstream "error" field.
            data = await response.json()
            print(data)
            response.raise_for_status()

        async def generate():
            # The session must outlive this handler while the body is being
            # relayed; `finally` also closes it when the client disconnects
            # before the stream is exhausted.
            try:
                async for line in response.content:
                    print(line)
                    yield line
            finally:
                await session.close()

        return StreamingResponse(generate(), status_code=response.status)
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"

        if response is not None:
            try:
                res = await response.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                # Upstream body was not usable JSON; fall back to the
                # original failure. (Not a bare except, so task cancellation
                # is not swallowed.)
                error_detail = f"Ollama: {e}"

        await session.close()
        raise HTTPException(
            status_code=response.status if response is not None else 500,
            detail=error_detail,
        )
|