# main.py
  1. from flask import Flask, request, Response, jsonify
  2. from flask_cors import CORS
  3. import requests
  4. import json
  5. from apps.web.models.users import Users
  6. from constants import ERROR_MESSAGES
  7. from utils.utils import extract_token_from_auth_header
  8. from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
  9. app = Flask(__name__)
  10. CORS(
  11. app
  12. ) # Enable Cross-Origin Resource Sharing (CORS) to allow requests from different domains
  13. # Define the target server URL
  14. TARGET_SERVER_URL = OLLAMA_API_BASE_URL
  15. @app.route("/", defaults={"path": ""}, methods=["GET", "POST", "PUT", "DELETE"])
  16. @app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"])
  17. def proxy(path):
  18. # Combine the base URL of the target server with the requested path
  19. target_url = f"{TARGET_SERVER_URL}/{path}"
  20. print(path)
  21. # Get data from the original request
  22. data = request.get_data()
  23. headers = dict(request.headers)
  24. # Basic RBAC support
  25. if WEBUI_AUTH:
  26. if "Authorization" in headers:
  27. token = extract_token_from_auth_header(headers["Authorization"])
  28. user = Users.get_user_by_token(token)
  29. if user:
  30. # Only user and admin roles can access
  31. if user.role in ["user", "admin"]:
  32. if path in ["pull", "delete", "push", "copy", "create"]:
  33. # Only admin role can perform actions above
  34. if user.role == "admin":
  35. pass
  36. else:
  37. return (
  38. jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
  39. 401,
  40. )
  41. else:
  42. pass
  43. else:
  44. return jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401
  45. else:
  46. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  47. else:
  48. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  49. else:
  50. pass
  51. r = None
  52. try:
  53. # Make a request to the target server
  54. r = requests.request(
  55. method=request.method,
  56. url=target_url,
  57. data=data,
  58. headers=headers,
  59. stream=True, # Enable streaming for server-sent events
  60. )
  61. r.raise_for_status()
  62. # Proxy the target server's response to the client
  63. def generate():
  64. for chunk in r.iter_content(chunk_size=8192):
  65. yield chunk
  66. response = Response(generate(), status=r.status_code)
  67. # Copy headers from the target server's response to the client's response
  68. for key, value in r.headers.items():
  69. response.headers[key] = value
  70. return response
  71. except Exception as e:
  72. error_detail = "Ollama WebUI: Server Connection Error"
  73. if r != None:
  74. res = r.json()
  75. if "error" in res:
  76. error_detail = f"Ollama: {res['error']}"
  77. print(res)
  78. return (
  79. jsonify(
  80. {
  81. "detail": error_detail,
  82. "message": str(e),
  83. }
  84. ),
  85. 400,
  86. )
  87. if __name__ == "__main__":
  88. app.run(debug=True)