1 """Middleware to add some basic security filtering to requests."""
3 from __future__
import annotations
5 from collections.abc
import Awaitable, Callable
6 from functools
import lru_cache
9 from typing
import Final
10 from urllib.parse
import unquote
12 from aiohttp.web
import Application, HTTPBadRequest, Request, StreamResponse, middleware
# Module-level logger, named after this module.
16 _LOGGER = logging.getLogger(__name__)
# Single compiled pattern made of alternatives, each describing one exploit
# signature looked for in the request path / query string.
# NOTE(review): every fragment visible here starts with "|", so at least one
# leading alternative, plus the closing of this call (and any regex flags),
# are outside this excerpt -- confirm against the full file.
19 FILTERS: Final = re.compile(
# "<" ... "script" ... ">", where the angle brackets may be literal or
# percent-encoded (%3C / %3E).
24 r"|(<|%3C).*script.*(>|%3E)"
# A word character, "=", "/", then one or more single characters from
# [a-z0-9_.] each followed by one or two slashes (e.g. "v=/.//").
28 r"|[a-zA-Z0-9_]=/([a-z0-9_.]//?)+"
# "union" ... "select" ... "(" in that order.
31 r"|union.*select.*\("
# "union" ... "all" ... "select" in that order.
32 r"|union.*all.*select.*"
# Characters (tab, carriage return, line feed) that the security filter
# middleware looks for in the request path and query string.
UNSAFE_URL_BYTES = ["\t", "\r", "\n"]
46 """Create security filter middleware for the app."""
def _recursive_unquote(value: str) -> str:
    """Handle values that are encoded multiple times.

    Percent-decodes ``value`` repeatedly until another decoding pass no
    longer changes it, so that layered encodings such as
    ``%253C`` -> ``%3C`` -> ``<`` are fully unwrapped.

    Returns the fully decoded string; input without percent-encoding is
    returned unchanged.
    """
    # Loop instead of recursing: every encoding layer would otherwise cost
    # one stack frame, so a long "%2525...25" chain could raise
    # RecursionError instead of being decoded.
    while (unquoted := unquote(value)) != value:
        value = unquoted
    return value
# Request middleware: inspects the request path and query string for unsafe
# characters and for matches of FILTERS, and otherwise passes the request on
# to `handler`.
# NOTE(review): the end of the signature and the statements that emit the log
# messages below and reject the request are not visible in this excerpt --
# presumably `_LOGGER.warning(...)` followed by `raise HTTPBadRequest`;
# confirm against the full file.
56 async
def security_filter_middleware(
57 request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
59 """Process request and block commonly known exploit attempts."""
# Path and query string are joined so one scan covers both.
60 path_with_query_string = f
"{request.path}?{request.query_string}"
# First pass: look for any of the UNSAFE_URL_BYTES characters.
62 for unsafe_byte
in UNSAFE_URL_BYTES:
63 if unsafe_byte
in path_with_query_string:
# The query string is re-checked on its own so the log message can say
# whether the character was in the query string or in the path.
64 if unsafe_byte
in request.query_string:
66 "Filtered a request with unsafe byte query string: %s",
71 "Filtered a request with an unsafe byte in path: %s",
# Second pass: match the fully percent-decoded path + query string against
# the FILTERS pattern.
76 if FILTERS.search(_recursive_unquote(path_with_query_string)):
# On a hit, the query string alone is tested so the more specific message
# can be chosen.
80 if FILTERS.search(_recursive_unquote(request.query_string)):
82 "Filtered a request with a potential harmful query string: %s",
88 "Filtered a potential harmful request to: %s", request.raw_path
# Nothing matched: delegate to the next handler.
92 return await handler(request)
94 app.middlewares.append(security_filter_middleware)
None setup_security_filter(Application app)