old_main.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. from flask import Flask, request, Response, jsonify
  2. from flask_cors import CORS
  3. import requests
  4. import json
  5. from apps.web.models.users import Users
  6. from constants import ERROR_MESSAGES
  7. from utils.utils import decode_token
  8. from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
  9. app = Flask(__name__)
  10. CORS(
  11. app
  12. ) # Enable Cross-Origin Resource Sharing (CORS) to allow requests from different domains
  13. # Define the target server URL
  14. TARGET_SERVER_URL = OLLAMA_API_BASE_URL
  15. @app.route("/url", methods=["GET"])
  16. def get_ollama_api_url():
  17. headers = dict(request.headers)
  18. if "Authorization" in headers:
  19. _, credentials = headers["Authorization"].split()
  20. token_data = decode_token(credentials)
  21. if token_data is None or "email" not in token_data:
  22. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  23. user = Users.get_user_by_email(token_data["email"])
  24. if user and user.role == "admin":
  25. return (
  26. jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}),
  27. 200,
  28. )
  29. else:
  30. return (
  31. jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
  32. 401,
  33. )
  34. else:
  35. return (
  36. jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}),
  37. 401,
  38. )
  39. @app.route("/url/update", methods=["POST"])
  40. def update_ollama_api_url():
  41. headers = dict(request.headers)
  42. data = request.get_json(force=True)
  43. if "Authorization" in headers:
  44. _, credentials = headers["Authorization"].split()
  45. token_data = decode_token(credentials)
  46. if token_data is None or "email" not in token_data:
  47. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  48. user = Users.get_user_by_email(token_data["email"])
  49. if user and user.role == "admin":
  50. TARGET_SERVER_URL = data["url"]
  51. return (
  52. jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}),
  53. 200,
  54. )
  55. else:
  56. return (
  57. jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
  58. 401,
  59. )
  60. else:
  61. return (
  62. jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}),
  63. 401,
  64. )
  65. @app.route("/",
  66. defaults={"path": ""},
  67. methods=["GET", "POST", "PUT", "DELETE"])
  68. @app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"])
  69. def proxy(path):
  70. # Combine the base URL of the target server with the requested path
  71. target_url = f"{TARGET_SERVER_URL}/{path}"
  72. print(target_url)
  73. # Get data from the original request
  74. data = request.get_data()
  75. headers = dict(request.headers)
  76. # Basic RBAC support
  77. if WEBUI_AUTH:
  78. if "Authorization" in headers:
  79. _, credentials = headers["Authorization"].split()
  80. token_data = decode_token(credentials)
  81. if token_data is None or "email" not in token_data:
  82. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  83. user = Users.get_user_by_email(token_data["email"])
  84. if user:
  85. # Only user and admin roles can access
  86. if user.role in ["user", "admin"]:
  87. if path in ["pull", "delete", "push", "copy", "create"]:
  88. # Only admin role can perform actions above
  89. if user.role == "admin":
  90. pass
  91. else:
  92. return (
  93. jsonify({
  94. "detail":
  95. ERROR_MESSAGES.ACCESS_PROHIBITED
  96. }),
  97. 401,
  98. )
  99. else:
  100. pass
  101. else:
  102. return jsonify(
  103. {"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401
  104. else:
  105. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  106. else:
  107. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  108. else:
  109. pass
  110. r = None
  111. headers.pop("Host", None)
  112. headers.pop("Authorization", None)
  113. headers.pop("Origin", None)
  114. headers.pop("Referer", None)
  115. try:
  116. # Make a request to the target server
  117. r = requests.request(
  118. method=request.method,
  119. url=target_url,
  120. data=data,
  121. headers=headers,
  122. stream=True, # Enable streaming for server-sent events
  123. )
  124. r.raise_for_status()
  125. # Proxy the target server's response to the client
  126. def generate():
  127. for chunk in r.iter_content(chunk_size=8192):
  128. yield chunk
  129. response = Response(generate(), status=r.status_code)
  130. # Copy headers from the target server's response to the client's response
  131. for key, value in r.headers.items():
  132. response.headers[key] = value
  133. return response
  134. except Exception as e:
  135. print(e)
  136. error_detail = "Ollama WebUI: Server Connection Error"
  137. if r != None:
  138. print(r.text)
  139. res = r.json()
  140. if "error" in res:
  141. error_detail = f"Ollama: {res['error']}"
  142. print(res)
  143. return (
  144. jsonify({
  145. "detail": error_detail,
  146. "message": str(e),
  147. }),
  148. 400,
  149. )
  150. if __name__ == "__main__":
  151. app.run(debug=True)