Home Assistant Unofficial Reference 2024.12.1
jwt_wrapper.py
Go to the documentation of this file.
1 """Provide a wrapper around JWT that caches decoding tokens.
2 
3 Since the same tokens are decoded over and over again,
4 we can cache the decoded result of valid tokens
5 to speed up the process.
6 """
7 
8 from __future__ import annotations
9 
10 from datetime import timedelta
11 from functools import lru_cache, partial
12 from typing import Any
13 
14 from jwt import DecodeError, PyJWS, PyJWT
15 
16 from homeassistant.util.json import json_loads
17 
# Maximum number of entries kept by each LRU cache in this module.
JWT_TOKEN_CACHE_SIZE = 16
# Tokens longer than this many characters are rejected before decoding.
MAX_TOKEN_SIZE = 8192

# Claim/feature names that have a matching "verify_<name>" decode option.
_VERIFY_KEYS = ("signature", "exp", "nbf", "iat", "aud", "iss", "sub", "jti")

# Every verify_* switch turned on; "require" is left empty.
_VERIFY_OPTIONS: dict[str, Any] = {
    **dict.fromkeys((f"verify_{key}" for key in _VERIFY_KEYS), True),
    "require": [],
}
# Every verify_* switch turned off.
_NO_VERIFY_OPTIONS = dict.fromkeys((f"verify_{key}" for key in _VERIFY_KEYS), False)
27 
28 
class _PyJWSWithLoadCache(PyJWS):
    """PyJWS with a dedicated load implementation.

    ``_load`` is wrapped in an LRU cache so that a token presented
    repeatedly is only handed to ``PyJWS._load`` once while it stays
    in the cache.
    """

    # We only ever have a global instance of this class
    # so we do not have to worry about the LRU growing
    # each time we create a new instance.
    @lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)
    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
        """Load a JWS, reusing the cached result for a recently seen token.

        The cache is keyed on ``(self, jwt)`` and holds at most
        JWT_TOKEN_CACHE_SIZE entries.
        """
        return super()._load(jwt)


# The single shared instance of the caching JWS loader; the module's
# payload decoding goes through it so all callers share one cache.
_jws = _PyJWSWithLoadCache()
42 
43 
@lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)
def _decode_payload(json_payload: str) -> dict[str, Any]:
    """Parse the JSON payload of a JWS into a dictionary.

    Results are memoized in an LRU cache of JWT_TOKEN_CACHE_SIZE entries,
    keyed on the raw payload.

    Raises:
        DecodeError: If the payload is not valid JSON, or if it parses
            to something other than a JSON object.
    """
    try:
        claims = json_loads(json_payload)
    except ValueError as err:
        raise DecodeError(f"Invalid payload string: {err}") from err
    if isinstance(claims, dict):
        return claims
    raise DecodeError("Invalid payload string: must be a json object")
54 
55 
class _PyJWTWithVerify(PyJWT):
    """PyJWT with a fast decode implementation."""

    def decode_payload(
        self, jwt: str, key: str, options: dict[str, Any], algorithms: list[str]
    ) -> dict[str, Any]:
        """Decode a JWT's payload.

        The token is decoded through the module-level ``_jws`` instance and
        its "payload" entry is parsed by ``_decode_payload``. ``key``,
        ``options`` and ``algorithms`` are forwarded to
        ``_jws.decode_complete`` unchanged.

        Raises:
            DecodeError: If the token is longer than MAX_TOKEN_SIZE, or if
                the payload cannot be parsed into a JSON object.
        """
        if len(jwt) > MAX_TOKEN_SIZE:
            # Avoid caching impossible tokens
            raise DecodeError("Token too large")
        return _decode_payload(
            _jws.decode_complete(
                jwt=jwt,
                key=key,
                algorithms=algorithms,
                options=options,
            )["payload"]
        )

    def verify_and_decode(
        self,
        jwt: str,
        key: str,
        algorithms: list[str],
        issuer: str | None = None,
        leeway: float | timedelta = 0,
        options: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Verify a JWT's signature and claims.

        Caller-supplied ``options`` are layered on top of ``_VERIFY_OPTIONS``
        (caller values win on conflict). The merged options are used both
        for decoding the payload and for validating its claims.

        Returns:
            The decoded payload dictionary.
        """
        merged_options = {**_VERIFY_OPTIONS, **(options or {})}
        payload = self.decode_payload(
            jwt=jwt,
            key=key,
            options=merged_options,
            algorithms=algorithms,
        )
        # These should never be missing since we verify them
        # but this is an additional safeguard to make sure
        # nothing slips through.
        # NOTE(review): assert statements are removed when Python runs
        # with -O, so this safeguard is inactive in optimized mode.
        assert "exp" in payload, "exp claim is required"
        assert "iat" in payload, "iat claim is required"
        self._validate_claims(
            payload=payload,
            options=merged_options,
            issuer=issuer,
            leeway=leeway,
        )
        return payload
104 
105 
# Single module-wide instance; the public helpers below are bound to it.
_jwt = _PyJWTWithVerify()

# Public entry point that verifies the token before returning its payload.
verify_and_decode = _jwt.verify_and_decode
# Public entry point that decodes an HS256 token with every verify_* option
# disabled and an empty key; results are kept in their own LRU cache.
unverified_hs256_token_decode = lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)(
    partial(
        _jwt.decode_payload, key="", algorithms=["HS256"], options=_NO_VERIFY_OPTIONS
    )
)

__all__ = [
    "unverified_hs256_token_decode",
    "verify_and_decode",
]
tuple[bytes, bytes, dict, bytes] _load(self, str|bytes jwt)
Definition: jwt_wrapper.py:36
dict[str, Any] verify_and_decode(self, str jwt, str key, list[str] algorithms, str|None issuer=None, float|timedelta leeway=0, dict[str, Any]|None options=None)
Definition: jwt_wrapper.py:83
dict[str, Any] decode_payload(self, str jwt, str key, dict[str, Any] options, list[str] algorithms)
Definition: jwt_wrapper.py:61
dict[str, Any] _decode_payload(str json_payload)
Definition: jwt_wrapper.py:45