Skip to content

Commit b2c9fde

Browse files
committed
remove websocket stuff
1 parent c64813c commit b2c9fde

5 files changed

Lines changed: 52 additions & 190 deletions

File tree

homeassistant/auth/jwt_wrapper.py

Lines changed: 44 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,23 @@
77

88
from __future__ import annotations
99

10-
from collections.abc import Container, Iterable, Sequence
1110
from datetime import timedelta
12-
from functools import lru_cache
13-
from typing import Any, override
11+
from functools import lru_cache, partial
12+
from typing import Any
1413

15-
from jwt import DecodeError, PyJWK, PyJWS, PyJWT
16-
from jwt.algorithms import AllowedPublicKeys
17-
from jwt.types import Options
14+
from jwt import DecodeError, PyJWS, PyJWT
1815

1916
from homeassistant.util.json import json_loads
2017

2118
JWT_TOKEN_CACHE_SIZE = 16
2219
MAX_TOKEN_SIZE = 8192
2320

24-
_NO_VERIFY_OPTIONS = Options(
25-
verify_signature=False,
26-
verify_exp=False,
27-
verify_nbf=False,
28-
verify_iat=False,
29-
verify_aud=False,
30-
verify_iss=False,
31-
verify_sub=False,
32-
verify_jti=False,
33-
require=[],
34-
)
21+
_VERIFY_KEYS = ("signature", "exp", "nbf", "iat", "aud", "iss", "sub", "jti")
22+
23+
_VERIFY_OPTIONS: dict[str, Any] = {f"verify_{key}": True for key in _VERIFY_KEYS} | {
24+
"require": []
25+
}
26+
_NO_VERIFY_OPTIONS = {f"verify_{key}": False for key in _VERIFY_KEYS}
3527

3628

3729
class _PyJWSWithLoadCache(PyJWS):
@@ -46,6 +38,9 @@ def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
4638
return super()._load(jwt)
4739

4840

41+
_jws = _PyJWSWithLoadCache()
42+
43+
4944
@lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)
5045
def _decode_payload(json_payload: str) -> dict[str, Any]:
5146
"""Decode the payload from a JWS dictionary."""
@@ -61,12 +56,21 @@ def _decode_payload(json_payload: str) -> dict[str, Any]:
6156
class _PyJWTWithVerify(PyJWT):
6257
"""PyJWT with a fast decode implementation."""
6358

64-
def __init__(self) -> None:
65-
"""Initialize the PyJWT instance."""
66-
# We require exp and iat claims to be present
67-
super().__init__(Options(require=["exp", "iat"]))
68-
# Override the _jws instance with our cached version
69-
self._jws = _PyJWSWithLoadCache()
59+
def decode_payload(
60+
self, jwt: str, key: str, options: dict[str, Any], algorithms: list[str]
61+
) -> dict[str, Any]:
62+
"""Decode a JWT's payload."""
63+
if len(jwt) > MAX_TOKEN_SIZE:
64+
# Avoid caching impossible tokens
65+
raise DecodeError("Token too large")
66+
return _decode_payload(
67+
_jws.decode_complete(
68+
jwt=jwt,
69+
key=key,
70+
algorithms=algorithms,
71+
options=options,
72+
)["payload"]
73+
)
7074

7175
def verify_and_decode(
7276
self,
@@ -75,70 +79,37 @@ def verify_and_decode(
7579
algorithms: list[str],
7680
issuer: str | None = None,
7781
leeway: float | timedelta = 0,
78-
options: Options | None = None,
82+
options: dict[str, Any] | None = None,
7983
) -> dict[str, Any]:
8084
"""Verify a JWT's signature and claims."""
81-
return self.decode(
85+
merged_options = {**_VERIFY_OPTIONS, **(options or {})}
86+
payload = self.decode_payload(
8287
jwt=jwt,
8388
key=key,
89+
options=merged_options,
8490
algorithms=algorithms,
85-
issuer=issuer,
86-
leeway=leeway,
87-
options=options,
8891
)
89-
90-
@override
91-
def decode(
92-
self,
93-
jwt: str | bytes,
94-
key: AllowedPublicKeys | PyJWK | str | bytes = "",
95-
algorithms: Sequence[str] | None = None,
96-
options: Options | None = None,
97-
verify: bool | None = None,
98-
detached_payload: bytes | None = None,
99-
audience: str | Iterable[str] | None = None,
100-
subject: str | None = None,
101-
issuer: str | Container[str] | None = None,
102-
leeway: float | timedelta = 0,
103-
**kwargs: Any,
104-
) -> dict[str, Any]:
105-
"""Decode a JWT, verifying the signature and claims."""
106-
if len(jwt) > MAX_TOKEN_SIZE:
107-
# Avoid caching impossible tokens
108-
raise DecodeError("Token too large")
109-
return super().decode(
110-
jwt=jwt,
111-
key=key,
112-
algorithms=algorithms,
113-
options=options,
114-
verify=verify,
115-
detached_payload=detached_payload,
116-
audience=audience,
117-
subject=subject,
92+
# These should never be missing since we verify them
93+
# but this is an additional safeguard to make sure
94+
# nothing slips through.
95+
assert "exp" in payload, "exp claim is required"
96+
assert "iat" in payload, "iat claim is required"
97+
self._validate_claims(
98+
payload=payload,
99+
options=merged_options,
118100
issuer=issuer,
119101
leeway=leeway,
120-
**kwargs,
121102
)
122-
123-
@override
124-
def _decode_payload(self, decoded: dict[str, Any]) -> dict[str, Any]:
125-
return _decode_payload(decoded["payload"])
103+
return payload
126104

127105

128106
_jwt = _PyJWTWithVerify()
129107
verify_and_decode = _jwt.verify_and_decode
130-
131-
132-
@lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)
133-
def unverified_hs256_token_decode(jwt: str) -> dict[str, Any]:
134-
"""Decode a JWT without verifying the signature."""
135-
return _jwt.decode(
136-
jwt=jwt,
137-
key="",
138-
algorithms=["HS256"],
139-
options=_NO_VERIFY_OPTIONS,
108+
unverified_hs256_token_decode = lru_cache(maxsize=JWT_TOKEN_CACHE_SIZE)(
109+
partial(
110+
_jwt.decode_payload, key="", algorithms=["HS256"], options=_NO_VERIFY_OPTIONS
140111
)
141-
112+
)
142113

143114
__all__ = [
144115
"unverified_hs256_token_decode",

homeassistant/components/backup/util.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
SecureTarFile,
2323
SecureTarReadError,
2424
SecureTarRootKeyContext,
25-
get_archive_max_ciphertext_size,
2625
)
2726

2827
from homeassistant.core import HomeAssistant
@@ -384,12 +383,9 @@ def _encrypt_backup(
384383
if prefix not in expected_archives:
385384
LOGGER.debug("Unknown inner tar file %s will not be encrypted", obj.name)
386385
continue
387-
if (fileobj := input_tar.extractfile(obj)) is None:
388-
LOGGER.debug(
389-
"Non regular inner tar file %s will not be encrypted", obj.name
390-
)
391-
continue
392-
output_archive.import_tar(fileobj, obj, derived_key_id=inner_tar_idx)
386+
output_archive.import_tar(
387+
input_tar.extractfile(obj), obj, derived_key_id=inner_tar_idx
388+
)
393389
inner_tar_idx += 1
394390

395391

@@ -423,7 +419,7 @@ def __init__(
423419
hass: HomeAssistant,
424420
backup: AgentBackup,
425421
open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
426-
password: str,
422+
password: str | None,
427423
) -> None:
428424
"""Initialize."""
429425
self._workers: list[_CipherWorkerStatus] = []
@@ -435,9 +431,7 @@ def __init__(
435431

436432
def size(self) -> int:
437433
"""Return the maximum size of the decrypted or encrypted backup."""
438-
return get_archive_max_ciphertext_size(
439-
self._backup.size, SECURETAR_CREATE_VERSION, self._num_tar_files()
440-
)
434+
return self._backup.size + self._num_tar_files() * tarfile.RECORDSIZE
441435

442436
def _num_tar_files(self) -> int:
443437
"""Return the number of inner tar files."""

homeassistant/components/weather/websocket_api.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
"daily": WeatherEntityFeature.FORECAST_DAILY,
1818
"hourly": WeatherEntityFeature.FORECAST_HOURLY,
1919
"twice_daily": WeatherEntityFeature.FORECAST_TWICE_DAILY,
20-
"minutely": WeatherEntityFeature.FORECAST_MINUTELY,
2120
}
2221

2322

@@ -48,7 +47,7 @@ def ws_convertible_units(
4847
{
4948
vol.Required("type"): "weather/subscribe_forecast",
5049
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
51-
vol.Required("forecast_type"): vol.In(FORECAST_TYPE_TO_FLAG),
50+
vol.Required("forecast_type"): vol.In(["daily", "hourly", "twice_daily"]),
5251
}
5352
)
5453
@websocket_api.async_response
@@ -57,9 +56,7 @@ async def ws_subscribe_forecast(
5756
) -> None:
5857
"""Subscribe to weather forecasts."""
5958
entity_id: str = msg["entity_id"]
60-
forecast_type: Literal["daily", "hourly", "twice_daily", "minutely"] = msg[
61-
"forecast_type"
62-
]
59+
forecast_type: Literal["daily", "hourly", "twice_daily"] = msg["forecast_type"]
6360

6461
if not (entity := hass.data[DATA_COMPONENT].get_entity(msg["entity_id"])):
6562
connection.send_error(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ per-file-ignores = [
418418
# redefined-outer-name: Tests reference fixtures in the test function
419419
# use-implicit-booleaness-not-comparison: Tests need to validate that a list
420420
# or a dict is returned
421-
"tests/**:redefined-outer-name,use-implicit-booleaness-not-comparison",
421+
"/tests/:redefined-outer-name,use-implicit-booleaness-not-comparison",
422422
]
423423

424424
[tool.pylint.REPORTS]

tests/components/weather/test_websocket_api.py

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -159,103 +159,3 @@ class MockWeatherMock(MockWeatherTest):
159159
"code": "forecast_not_supported",
160160
"message": "The weather entity does not support forecast type: daily",
161161
}
162-
163-
164-
async def test_subscribe_forecast_minutely(
165-
hass: HomeAssistant,
166-
hass_ws_client: WebSocketGenerator,
167-
config_flow_fixture: None,
168-
) -> None:
169-
"""Test subscribing to minutely forecast via websocket."""
170-
171-
class MockWeatherMockMinutelyForecast(MockWeatherTest):
172-
"""Mock weather class."""
173-
174-
async def async_forecast_minutely(self) -> list[Forecast] | None:
175-
"""Return the forecast_minutely."""
176-
return self.forecast_list
177-
178-
kwargs = {
179-
"native_temperature": 38,
180-
"native_temperature_unit": UnitOfTemperature.CELSIUS,
181-
"supported_features": WeatherEntityFeature.FORECAST_MINUTELY,
182-
}
183-
weather_entity = await create_entity(
184-
hass, MockWeatherMockMinutelyForecast, None, **kwargs
185-
)
186-
187-
client = await hass_ws_client(hass)
188-
189-
await client.send_json_auto_id(
190-
{
191-
"type": "weather/subscribe_forecast",
192-
"forecast_type": "minutely",
193-
"entity_id": weather_entity.entity_id,
194-
}
195-
)
196-
msg = await client.receive_json()
197-
assert msg["success"]
198-
assert msg["result"] is None
199-
subscription_id = msg["id"]
200-
201-
msg = await client.receive_json()
202-
assert msg["id"] == subscription_id
203-
assert msg["type"] == "event"
204-
forecast = msg["event"]
205-
assert forecast == {
206-
"type": "minutely",
207-
"forecast": [
208-
{
209-
"cloud_coverage": None,
210-
"temperature": 38.0,
211-
"templow": 38.0,
212-
"uv_index": None,
213-
"wind_bearing": None,
214-
}
215-
],
216-
}
217-
218-
await weather_entity.async_update_listeners(None)
219-
msg = await client.receive_json()
220-
assert msg["event"] == forecast
221-
222-
await weather_entity.async_update_listeners(["minutely"])
223-
msg = await client.receive_json()
224-
assert msg["event"] == forecast
225-
226-
weather_entity.forecast_list = None
227-
await weather_entity.async_update_listeners(None)
228-
msg = await client.receive_json()
229-
assert msg["event"] == {"type": "minutely", "forecast": None}
230-
231-
232-
async def test_subscribe_forecast_minutely_unsupported(
233-
hass: HomeAssistant,
234-
hass_ws_client: WebSocketGenerator,
235-
config_flow_fixture: None,
236-
) -> None:
237-
"""Test subscribing to minutely forecast when not supported."""
238-
239-
class MockWeatherMock(MockWeatherTest):
240-
"""Mock weather class."""
241-
242-
kwargs = {
243-
"native_temperature": 38,
244-
"native_temperature_unit": UnitOfTemperature.CELSIUS,
245-
}
246-
weather_entity = await create_entity(hass, MockWeatherMock, None, **kwargs)
247-
client = await hass_ws_client(hass)
248-
249-
await client.send_json_auto_id(
250-
{
251-
"type": "weather/subscribe_forecast",
252-
"forecast_type": "minutely",
253-
"entity_id": weather_entity.entity_id,
254-
}
255-
)
256-
msg = await client.receive_json()
257-
assert not msg["success"]
258-
assert msg["error"] == {
259-
"code": "forecast_not_supported",
260-
"message": "The weather entity does not support forecast type: minutely",
261-
}

0 commit comments

Comments
 (0)