Procházet zdrojové kódy

Merge pull request #398 from ollama-webui/proxy-fix

fix: backend proxy
Timothy Jaeryang Baek před 1 rokem
rodič
revize
6b9453e28f
3 změnil soubory, kde provedl 120 přidání a 188 odebrání
  1. 2 3
      Dockerfile
  2. 28 44
      backend/apps/ollama/main.py
  3. 90 141
      backend/apps/ollama/old_main.py

+ 2 - 3
Dockerfile

@@ -12,10 +12,9 @@ RUN npm run build
 
 
 FROM python:3.11-slim-buster as base
 FROM python:3.11-slim-buster as base
 
 
-ARG OLLAMA_API_BASE_URL='/ollama/api'
-
 ENV ENV=prod
 ENV ENV=prod
-ENV OLLAMA_API_BASE_URL $OLLAMA_API_BASE_URL
+
+ENV OLLAMA_API_BASE_URL "/ollama/api"
 
 
 ENV OPENAI_API_BASE_URL ""
 ENV OPENAI_API_BASE_URL ""
 ENV OPENAI_API_KEY ""
 ENV OPENAI_API_KEY ""

+ 28 - 44
backend/apps/ollama/main.py

@@ -1,6 +1,7 @@
 from fastapi import FastAPI, Request, Response, HTTPException, Depends
 from fastapi import FastAPI, Request, Response, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.responses import StreamingResponse
+from fastapi.concurrency import run_in_threadpool
 
 
 import requests
 import requests
 import json
 import json
@@ -11,8 +12,6 @@ from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user
 from utils.utils import decode_token, get_current_user
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 
 
-import aiohttp
-
 app = FastAPI()
 app = FastAPI()
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
@@ -50,25 +49,9 @@ async def update_ollama_api_url(
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
 
 
-# async def fetch_sse(method, target_url, body, headers):
-#     async with aiohttp.ClientSession() as session:
-#         try:
-#             async with session.request(
-#                 method, target_url, data=body, headers=headers
-#             ) as response:
-#                 print(response.status)
-#                 async for line in response.content:
-#                     yield line
-#         except Exception as e:
-#             print(e)
-#             error_detail = "Ollama WebUI: Server Connection Error"
-#             yield json.dumps({"error": error_detail, "message": str(e)}).encode()
-
-
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_current_user)):
 async def proxy(path: str, request: Request, user=Depends(get_current_user)):
     target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
     target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
-    print(target_url)
 
 
     body = await request.body()
     body = await request.body()
     headers = dict(request.headers)
     headers = dict(request.headers)
@@ -87,41 +70,42 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
     headers.pop("Origin", None)
     headers.pop("Origin", None)
     headers.pop("Referer", None)
     headers.pop("Referer", None)
 
 
-    session = aiohttp.ClientSession()
-    response = None
-    try:
-        response = await session.request(
-            request.method, target_url, data=body, headers=headers
-        )
-
-        print(response)
-        if not response.ok:
-            data = await response.json()
-            print(data)
-            response.raise_for_status()
-
-        async def generate():
-            async for line in response.content:
-                print(line)
-                yield line
-            await session.close()
-
-        return StreamingResponse(generate(), response.status)
+    r = None
+
+    def get_request():
+        nonlocal r
+        try:
+            r = requests.request(
+                method=request.method,
+                url=target_url,
+                data=body,
+                headers=headers,
+                stream=True,
+            )
+
+            r.raise_for_status()
+
+            return StreamingResponse(
+                r.iter_content(chunk_size=8192),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        except Exception as e:
+            raise e
 
 
+    try:
+        return await run_in_threadpool(get_request)
     except Exception as e:
     except Exception as e:
-        print(e)
         error_detail = "Ollama WebUI: Server Connection Error"
         error_detail = "Ollama WebUI: Server Connection Error"
-
-        if response is not None:
+        if r is not None:
             try:
             try:
-                res = await response.json()
+                res = r.json()
                 if "error" in res:
                 if "error" in res:
                     error_detail = f"Ollama: {res['error']}"
                     error_detail = f"Ollama: {res['error']}"
             except:
             except:
                 error_detail = f"Ollama: {e}"
                 error_detail = f"Ollama: {e}"
 
 
-        await session.close()
         raise HTTPException(
         raise HTTPException(
-            status_code=response.status if response else 500,
+            status_code=r.status_code if r else 500,
             detail=error_detail,
             detail=error_detail,
         )
         )

+ 90 - 141
backend/apps/ollama/old_main.py

@@ -1,178 +1,127 @@
-from flask import Flask, request, Response, jsonify
-from flask_cors import CORS
+from fastapi import FastAPI, Request, Response, HTTPException, Depends
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 
 
 import requests
 import requests
 import json
 import json
+from pydantic import BaseModel
 
 
 from apps.web.models.users import Users
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
-from utils.utils import decode_token
+from utils.utils import decode_token, get_current_user
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 
 
-app = Flask(__name__)
-CORS(
-    app
-)  # Enable Cross-Origin Resource Sharing (CORS) to allow requests from different domains
+import aiohttp
 
 
-# Define the target server URL
-TARGET_SERVER_URL = OLLAMA_API_BASE_URL
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 
 
+app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
 
 
-@app.route("/url", methods=["GET"])
-def get_ollama_api_url():
-    headers = dict(request.headers)
-    if "Authorization" in headers:
-        _, credentials = headers["Authorization"].split()
-        token_data = decode_token(credentials)
-        if token_data is None or "email" not in token_data:
-            return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
-
-        user = Users.get_user_by_email(token_data["email"])
-        if user and user.role == "admin":
-            return (
-                jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}),
-                200,
-            )
-        else:
-            return (
-                jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
-                401,
-            )
-    else:
-        return (
-            jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}),
-            401,
-        )
+# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
 
 
 
 
-@app.route("/url/update", methods=["POST"])
-def update_ollama_api_url():
-    headers = dict(request.headers)
-    data = request.get_json(force=True)
-
-    if "Authorization" in headers:
-        _, credentials = headers["Authorization"].split()
-        token_data = decode_token(credentials)
-        if token_data is None or "email" not in token_data:
-            return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
-
-        user = Users.get_user_by_email(token_data["email"])
-        if user and user.role == "admin":
-            TARGET_SERVER_URL = data["url"]
-            return (
-                jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}),
-                200,
-            )
-        else:
-            return (
-                jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
-                401,
-            )
+@app.get("/url")
+async def get_ollama_api_url(user=Depends(get_current_user)):
+    if user and user.role == "admin":
+        return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
     else:
     else:
-        return (
-            jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}),
-            401,
-        )
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+class UrlUpdateForm(BaseModel):
+    url: str
 
 
 
 
-@app.route("/",
-           defaults={"path": ""},
-           methods=["GET", "POST", "PUT", "DELETE"])
-@app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"])
-def proxy(path):
-    # Combine the base URL of the target server with the requested path
-    target_url = f"{TARGET_SERVER_URL}/{path}"
+@app.post("/url/update")
+async def update_ollama_api_url(
+    form_data: UrlUpdateForm, user=Depends(get_current_user)
+):
+    if user and user.role == "admin":
+        app.state.OLLAMA_API_BASE_URL = form_data.url
+        return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
+    else:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+# async def fetch_sse(method, target_url, body, headers):
+#     async with aiohttp.ClientSession() as session:
+#         try:
+#             async with session.request(
+#                 method, target_url, data=body, headers=headers
+#             ) as response:
+#                 print(response.status)
+#                 async for line in response.content:
+#                     yield line
+#         except Exception as e:
+#             print(e)
+#             error_detail = "Ollama WebUI: Server Connection Error"
+#             yield json.dumps({"error": error_detail, "message": str(e)}).encode()
+
+
+@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(path: str, request: Request, user=Depends(get_current_user)):
+    target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
     print(target_url)
     print(target_url)
 
 
-    # Get data from the original request
-    data = request.get_data()
+    body = await request.body()
     headers = dict(request.headers)
     headers = dict(request.headers)
 
 
-    # Basic RBAC support
-    if WEBUI_AUTH:
-        if "Authorization" in headers:
-            _, credentials = headers["Authorization"].split()
-            token_data = decode_token(credentials)
-            if token_data is None or "email" not in token_data:
-                return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
-
-            user = Users.get_user_by_email(token_data["email"])
-            if user:
-                # Only user and admin roles can access
-                if user.role in ["user", "admin"]:
-                    if path in ["pull", "delete", "push", "copy", "create"]:
-                        # Only admin role can perform actions above
-                        if user.role == "admin":
-                            pass
-                        else:
-                            return (
-                                jsonify({
-                                    "detail":
-                                    ERROR_MESSAGES.ACCESS_PROHIBITED
-                                }),
-                                401,
-                            )
-                    else:
-                        pass
-                else:
-                    return jsonify(
-                        {"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401
-            else:
-                return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
-        else:
-            return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
+    if user.role in ["user", "admin"]:
+        if path in ["pull", "delete", "push", "copy", "create"]:
+            if user.role != "admin":
+                raise HTTPException(
+                    status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
+                )
     else:
     else:
-        pass
-
-    r = None
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
     headers.pop("Host", None)
     headers.pop("Host", None)
     headers.pop("Authorization", None)
     headers.pop("Authorization", None)
     headers.pop("Origin", None)
     headers.pop("Origin", None)
     headers.pop("Referer", None)
     headers.pop("Referer", None)
 
 
+    session = aiohttp.ClientSession()
+    response = None
     try:
     try:
-        # Make a request to the target server
-        r = requests.request(
-            method=request.method,
-            url=target_url,
-            data=data,
-            headers=headers,
-            stream=True,  # Enable streaming for server-sent events
+        response = await session.request(
+            request.method, target_url, data=body, headers=headers
         )
         )
 
 
-        r.raise_for_status()
-
-        # Proxy the target server's response to the client
-        def generate():
-            for chunk in r.iter_content(chunk_size=8192):
-                yield chunk
+        print(response)
+        if not response.ok:
+            data = await response.json()
+            print(data)
+            response.raise_for_status()
 
 
-        response = Response(generate(), status=r.status_code)
+        async def generate():
+            async for line in response.content:
+                print(line)
+                yield line
+            await session.close()
 
 
-        # Copy headers from the target server's response to the client's response
-        for key, value in r.headers.items():
-            response.headers[key] = value
+        return StreamingResponse(generate(), response.status)
 
 
-        return response
     except Exception as e:
     except Exception as e:
         print(e)
         print(e)
         error_detail = "Ollama WebUI: Server Connection Error"
         error_detail = "Ollama WebUI: Server Connection Error"
-        if r != None:
-            print(r.text)
-            res = r.json()
-            if "error" in res:
-                error_detail = f"Ollama: {res['error']}"
-            print(res)
-
-        return (
-            jsonify({
-                "detail": error_detail,
-                "message": str(e),
-            }),
-            400,
-        )
-
 
 
-if __name__ == "__main__":
-    app.run(debug=True)
+        if response is not None:
+            try:
+                res = await response.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        await session.close()
+        raise HTTPException(
+            status_code=response.status if response else 500,
+            detail=error_detail,
+        )