security_headers.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import re
  2. import os
  3. from fastapi import Request
  4. from starlette.middleware.base import BaseHTTPMiddleware
  5. from typing import Dict
  6. class SecurityHeadersMiddleware(BaseHTTPMiddleware):
  7. async def dispatch(self, request: Request, call_next):
  8. response = await call_next(request)
  9. response.headers.update(set_security_headers())
  10. return response
  11. def set_security_headers() -> Dict[str, str]:
  12. """
  13. Sets security headers based on environment variables.
  14. This function reads specific environment variables and uses their values
  15. to set corresponding security headers. The headers that can be set are:
  16. - cache-control
  17. - strict-transport-security
  18. - referrer-policy
  19. - x-content-type-options
  20. - x-download-options
  21. - x-frame-options
  22. - x-permitted-cross-domain-policies
  23. Each environment variable is associated with a specific setter function
  24. that constructs the header. If the environment variable is set, the
  25. corresponding header is added to the options dictionary.
  26. Returns:
  27. dict: A dictionary containing the security headers and their values.
  28. """
  29. options = {}
  30. header_setters = {
  31. "CACHE_CONTROL": set_cache_control,
  32. "HSTS": set_hsts,
  33. "REFERRER_POLICY": set_referrer,
  34. "XCONTENT_TYPE": set_xcontent_type,
  35. "XDOWNLOAD_OPTIONS": set_xdownload_options,
  36. "XFRAME_OPTIONS": set_xframe,
  37. "XPERMITTED_CROSS_DOMAIN_POLICIES": set_xpermitted_cross_domain_policies,
  38. }
  39. for env_var, setter in header_setters.items():
  40. value = os.environ.get(env_var, None)
  41. if value:
  42. header = setter(value)
  43. if header:
  44. options.update(header)
  45. return options
  46. # Set HTTP Strict Transport Security(HSTS) response header
  47. def set_hsts(value: str):
  48. pattern = r"^max-age=(\d+)(;includeSubDomains)?(;preload)?$"
  49. match = re.match(pattern, value, re.IGNORECASE)
  50. if not match:
  51. return "max-age=31536000;includeSubDomains"
  52. return {"Strict-Transport-Security": value}
  53. # Set X-Frame-Options response header
  54. def set_xframe(value: str):
  55. pattern = r"^(DENY|SAMEORIGIN)$"
  56. match = re.match(pattern, value, re.IGNORECASE)
  57. if not match:
  58. value = "DENY"
  59. return {"X-Frame-Options": value}
  60. # Set Referrer-Policy response header
  61. def set_referrer(value: str):
  62. pattern = r"^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$"
  63. match = re.match(pattern, value, re.IGNORECASE)
  64. if not match:
  65. value = "no-referrer"
  66. return {"Referrer-Policy": value}
  67. # Set Cache-Control response header
  68. def set_cache_control(value: str):
  69. pattern = r"^(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable)(,\s*(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable))*$"
  70. match = re.match(pattern, value, re.IGNORECASE)
  71. if not match:
  72. value = "no-store, max-age=0"
  73. return {"Cache-Control": value}
  74. # Set X-Download-Options response header
  75. def set_xdownload_options(value: str):
  76. if value != "noopen":
  77. value = "noopen"
  78. return {"X-Download-Options": value}
  79. # Set X-Content-Type-Options response header
  80. def set_xcontent_type(value: str):
  81. if value != "nosniff":
  82. value = "nosniff"
  83. return {"X-Content-Type-Options": value}
  84. # Set X-Permitted-Cross-Domain-Policies response header
  85. def set_xpermitted_cross_domain_policies(value: str):
  86. pattern = r"^(none|master-only|by-content-type|by-ftp-filename)$"
  87. match = re.match(pattern, value, re.IGNORECASE)
  88. if not match:
  89. value = "none"
  90. return {"X-Permitted-Cross-Domain-Policies": value}