# main.py (2.1 KB)
  1. from flask import Flask, request, Response, jsonify
  2. from flask_cors import CORS
  3. import requests
  4. import json
  5. from apps.web.models.users import Users
  6. from constants import ERROR_MESSAGES
  7. from utils import extract_token_from_auth_header
  8. from config import OLLAMA_API_BASE_URL, OLLAMA_WEBUI_AUTH
  9. app = Flask(__name__)
  10. CORS(
  11. app
  12. ) # Enable Cross-Origin Resource Sharing (CORS) to allow requests from different domains
  13. # Define the target server URL
  14. TARGET_SERVER_URL = OLLAMA_API_BASE_URL
  15. @app.route("/", defaults={"path": ""}, methods=["GET", "POST", "PUT", "DELETE"])
  16. @app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"])
  17. def proxy(path):
  18. # Combine the base URL of the target server with the requested path
  19. target_url = f"{TARGET_SERVER_URL}/{path}"
  20. print(target_url)
  21. # Get data from the original request
  22. data = request.get_data()
  23. headers = dict(request.headers)
  24. if OLLAMA_WEBUI_AUTH:
  25. if "Authorization" in headers:
  26. token = extract_token_from_auth_header(headers["Authorization"])
  27. user = Users.get_user_by_token(token)
  28. if user:
  29. print(user)
  30. pass
  31. else:
  32. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  33. else:
  34. return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
  35. else:
  36. pass
  37. # Make a request to the target server
  38. target_response = requests.request(
  39. method=request.method,
  40. url=target_url,
  41. data=data,
  42. headers=headers,
  43. stream=True, # Enable streaming for server-sent events
  44. )
  45. # Proxy the target server's response to the client
  46. def generate():
  47. for chunk in target_response.iter_content(chunk_size=8192):
  48. yield chunk
  49. response = Response(generate(), status=target_response.status_code)
  50. # Copy headers from the target server's response to the client's response
  51. for key, value in target_response.headers.items():
  52. response.headers[key] = value
  53. return response
  54. if __name__ == "__main__":
  55. app.run(debug=True)