security_headers.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import re
  2. import os
  3. from fastapi import Request
  4. from starlette.middleware.base import BaseHTTPMiddleware
  5. from typing import Dict
  6. class SecurityHeadersMiddleware(BaseHTTPMiddleware):
  7. async def dispatch(self, request: Request, call_next):
  8. response = await call_next(request)
  9. response.headers.update(set_security_headers())
  10. return response
  11. def set_security_headers() -> Dict[str, str]:
  12. """
  13. Sets security headers based on environment variables.
  14. This function reads specific environment variables and uses their values
  15. to set corresponding security headers. The headers that can be set are:
  16. - cache-control
  17. - strict-transport-security
  18. - referrer-policy
  19. - x-content-type-options
  20. - x-download-options
  21. - x-frame-options
  22. - x-permitted-cross-domain-policies
  23. Each environment variable is associated with a specific setter function
  24. that constructs the header. If the environment variable is set, the
  25. corresponding header is added to the options dictionary.
  26. Returns:
  27. dict: A dictionary containing the security headers and their values.
  28. """
  29. options = {}
  30. header_setters = {
  31. 'CACHE_CONTROL': set_cache_control,
  32. 'HSTS': set_hsts,
  33. 'REFERRER_POLICY': set_referrer,
  34. 'XCONTENT_TYPE': set_xcontent_type,
  35. 'XDOWNLOAD_OPTIONS': set_xdownload_options,
  36. 'XFRAME_OPTIONS': set_xframe,
  37. 'XPERMITTED_CROSS_DOMAIN_POLICIES': set_xpermitted_cross_domain_policies,
  38. }
  39. for env_var, setter in header_setters.items():
  40. value = os.environ.get(env_var, None)
  41. if value:
  42. header = setter(value)
  43. if header:
  44. options.update(header)
  45. return options
  46. # Set HTTP Strict Transport Security(HSTS) response header
  47. def set_hsts(value: str):
  48. pattern = r'^max-age=(\d+)(;includeSubDomains)?(;preload)?$'
  49. match = re.match(pattern, value, re.IGNORECASE)
  50. if not match:
  51. return 'max-age=31536000;includeSubDomains'
  52. return {
  53. 'Strict-Transport-Security': value
  54. }
  55. # Set X-Frame-Options response header
  56. def set_xframe(value: str):
  57. pattern = r'^(DENY|SAMEORIGIN)$'
  58. match = re.match(pattern, value, re.IGNORECASE)
  59. if not match:
  60. value = 'DENY'
  61. return {
  62. "X-Frame-Options": value
  63. }
  64. # Set Referrer-Policy response header
  65. def set_referrer(value: str):
  66. pattern = r'^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$'
  67. match = re.match(pattern, value, re.IGNORECASE)
  68. if not match:
  69. value = 'no-referrer'
  70. return {
  71. 'Referrer-Policy': value
  72. }
  73. # Set Cache-Control response header
  74. def set_cache_control(value: str):
  75. pattern = r'^(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable)(,\s*(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable))*$'
  76. match = re.match(pattern, value, re.IGNORECASE)
  77. if not match:
  78. value = 'no-store, max-age=0'
  79. return {
  80. 'Cache-Control': value
  81. }
  82. # Set X-Download-Options response header
  83. def set_xdownload_options(value: str):
  84. if value != 'noopen':
  85. value = 'noopen'
  86. return {
  87. 'X-Download-Options': value
  88. }
  89. # Set X-Content-Type-Options response header
  90. def set_xcontent_type(value: str):
  91. if value != 'nosniff':
  92. value = 'nosniff'
  93. return {
  94. 'X-Content-Type-Options': value
  95. }
  96. # Set X-Permitted-Cross-Domain-Policies response header
  97. def set_xpermitted_cross_domain_policies(value: str):
  98. pattern = r'^(none|master-only|by-content-type|by-ftp-filename)$'
  99. match = re.match(pattern, value, re.IGNORECASE)
  100. if not match:
  101. value = 'none'
  102. return {
  103. 'X-Permitted-Cross-Domain-Policies': value
  104. }