security_headers.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import re
  2. import os
  3. from fastapi import Request
  4. from starlette.middleware.base import BaseHTTPMiddleware
  5. from typing import Dict
  6. class SecurityHeadersMiddleware(BaseHTTPMiddleware):
  7. async def dispatch(self, request: Request, call_next):
  8. response = await call_next(request)
  9. response.headers.update(set_security_headers())
  10. return response
  11. def set_security_headers() -> Dict[str, str]:
  12. """
  13. Sets security headers based on environment variables.
  14. This function reads specific environment variables and uses their values
  15. to set corresponding security headers. The headers that can be set are:
  16. - cache-control
  17. - strict-transport-security
  18. - referrer-policy
  19. - x-content-type-options
  20. - x-download-options
  21. - x-frame-options
  22. Each environment variable is associated with a specific setter function
  23. that constructs the header. If the environment variable is set, the
  24. corresponding header is added to the options dictionary.
  25. Returns:
  26. dict: A dictionary containing the security headers and their values.
  27. """
  28. options = {}
  29. header_setters = {
  30. 'CACHE_CONTROL': set_cache_control,
  31. 'HSTS': set_hsts,
  32. 'REFERRER_POLICY': set_referrer,
  33. 'XCONTENT_TYPE': set_xcontent_type,
  34. 'XDOWNLOAD_OPTIONS': set_xdownload_options,
  35. 'XFRAME_OPTIONS': set_xframe,
  36. 'XPERMITTED_CROSS_DOMAIN_POLICIES': set_xpermitted_cross_domain_policies,
  37. }
  38. for env_var, setter in header_setters.items():
  39. value = os.environ.get(env_var, None)
  40. if value:
  41. header = setter(value)
  42. if header:
  43. options.update(header)
  44. return options
  45. # Set HTTP Strict Transport Security(HSTS) response header
  46. def set_hsts(value: str):
  47. pattern = r'^max-age=(\d+)(;includeSubDomains)?(;preload)?$'
  48. match = re.match(pattern, value, re.IGNORECASE)
  49. if not match:
  50. return 'max-age=31536000;includeSubDomains'
  51. return {
  52. 'Strict-Transport-Security': value
  53. }
  54. # Set X-Frame-Options response header
  55. def set_xframe(value: str):
  56. pattern = r'^(DENY|SAMEORIGIN)$'
  57. match = re.match(pattern, value, re.IGNORECASE)
  58. if not match:
  59. value = 'DENY'
  60. return {
  61. "X-Frame-Options": value
  62. }
  63. # Set Referrer-Policy response header
  64. def set_referrer(value: str):
  65. pattern = r'^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$'
  66. match = re.match(pattern, value, re.IGNORECASE)
  67. if not match:
  68. value = 'no-referrer'
  69. return {
  70. 'Referrer-Policy': value
  71. }
  72. # Set Cache-Control response header
  73. def set_cache_control(value: str):
  74. pattern = r'^(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable)(,\s*(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable))*$'
  75. match = re.match(pattern, value, re.IGNORECASE)
  76. if not match:
  77. value = 'no-store, max-age=0'
  78. return {
  79. 'Cache-Control': value
  80. }
  81. # Set X-Download-Options response header
  82. def set_xdownload_options(value: str):
  83. if value != 'noopen':
  84. value = 'noopen'
  85. return {
  86. 'X-Download-Options': value
  87. }
  88. # Set X-Content-Type-Options response header
  89. def set_xcontent_type(value: str):
  90. if value != 'nosniff':
  91. value = 'nosniff'
  92. return {
  93. 'X-Content-Type-Options': value
  94. }
  95. # Set X-Permitted-Cross-Domain-Policies response header
  96. def set_xpermitted_cross_domain_policies(value: str):
  97. pattern = r'^(none|master-only|by-content-type|by-ftp-filename)$'
  98. match = re.match(pattern, value, re.IGNORECASE)
  99. if not match:
  100. value = 'none'
  101. return {
  102. 'X-Permitted-Cross-Domain-Policies': value
  103. }