# main.py

import json
import logging

import requests
from flask import Flask, request, Response, jsonify
from flask_cors import CORS

from apps.web.models.users import Users
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
from constants import ERROR_MESSAGES
from utils.utils import decode_token
  9. app = Flask(__name__)
  10. CORS(
  11. app
  12. ) # Enable Cross-Origin Resource Sharing (CORS) to allow requests from different domains
  13. # Define the target server URL
  14. TARGET_SERVER_URL = OLLAMA_API_BASE_URL
  15. @app.route("/",
  16. defaults={"path": ""},
  17. methods=["GET", "POST", "PUT", "DELETE"])
  18. @app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"])
  19. def proxy(path):
  20. # Combine the base URL of the target server with the requested path
  21. target_url = f"{TARGET_SERVER_URL}/{path}"
  22. print(target_url)
  23. # Get data from the original request
  24. data = request.get_data()
  25. headers = dict(request.headers)
  26. # Basic RBAC support
  27. if WEBUI_AUTH:
  28. if "Authorization" in headers:
  29. _, credentials = headers["Authorization"].split()
  30. token_data = decode_token(credentials)
  31. if token_data is None or "email" not in token_data:
  32. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  33. user = Users.get_user_by_email(token_data["email"])
  34. if user:
  35. # Only user and admin roles can access
  36. if user.role in ["user", "admin"]:
  37. if path in ["pull", "delete", "push", "copy", "create"]:
  38. # Only admin role can perform actions above
  39. if user.role == "admin":
  40. pass
  41. else:
  42. return (
  43. jsonify({
  44. "detail":
  45. ERROR_MESSAGES.ACCESS_PROHIBITED
  46. }),
  47. 401,
  48. )
  49. else:
  50. pass
  51. else:
  52. return jsonify(
  53. {"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401
  54. else:
  55. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  56. else:
  57. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  58. else:
  59. pass
  60. r = None
  61. headers.pop("Host", None)
  62. headers.pop("Authorization", None)
  63. headers.pop("Origin", None)
  64. headers.pop("Referer", None)
  65. try:
  66. # Make a request to the target server
  67. r = requests.request(
  68. method=request.method,
  69. url=target_url,
  70. data=data,
  71. headers=headers,
  72. stream=True, # Enable streaming for server-sent events
  73. )
  74. r.raise_for_status()
  75. # Proxy the target server's response to the client
  76. def generate():
  77. for chunk in r.iter_content(chunk_size=8192):
  78. yield chunk
  79. response = Response(generate(), status=r.status_code)
  80. # Copy headers from the target server's response to the client's response
  81. for key, value in r.headers.items():
  82. response.headers[key] = value
  83. return response
  84. except Exception as e:
  85. print(e)
  86. error_detail = "Ollama WebUI: Server Connection Error"
  87. if r != None:
  88. print(r.text)
  89. res = r.json()
  90. if "error" in res:
  91. error_detail = f"Ollama: {res['error']}"
  92. print(res)
  93. return (
  94. jsonify({
  95. "detail": error_detail,
  96. "message": str(e),
  97. }),
  98. 400,
  99. )
  100. if __name__ == "__main__":
  101. app.run(debug=True)