Skip to content

Commit d9c98f6

Browse files
committed
refactor(internal): parse projection params once; unify history projection; remove stale percentage claim
- _project_entity now accepts pre-parsed list[str] | None; callers parse fields/attribute_keys once before the entity loop (tools_search.py).
- ha_get_state warns when attribute_keys is supplied but "attributes" is not in fields= (the keys would otherwise be ignored silently).
- _fetch_history/_fetch_statistics return unwrapped inner dicts; projection (project_fields) and wrapping (add_timezone_metadata) both happen at the ha_get_history call site, consistent with every other tool.
- Removed the stale "94% token reduction" claim from the ha_get_overview docstring.
- Updated tests to match the new signatures and shapes.
1 parent f355c44 commit d9c98f6

7 files changed

Lines changed: 219 additions & 66 deletions

File tree

src/ha_mcp/tools/tools_history.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import logging
1313
import re
1414
from datetime import UTC, datetime, timedelta
15-
from typing import Annotated, Any, Literal, cast
15+
from typing import Annotated, Any, Literal
1616

1717
from fastmcp import Context
1818
from fastmcp.exceptions import ToolError
@@ -34,6 +34,7 @@
3434
build_pagination_metadata,
3535
coerce_int_param,
3636
parse_string_list_param,
37+
project_fields,
3738
)
3839

3940
logger = logging.getLogger(__name__)
@@ -339,14 +340,14 @@ async def ha_get_history(
339340

340341
try:
341342
if source == "statistics":
342-
result = await _fetch_statistics(
343-
ws_client, self._client, entity_id_list,
343+
inner = await _fetch_statistics(
344+
ws_client, entity_id_list,
344345
start_dt, end_dt, period, statistic_types,
345346
limit, offset,
346347
)
347348
else:
348-
result = await _fetch_history(
349-
ws_client, self._client, entity_id_list,
349+
inner = await _fetch_history(
350+
ws_client, entity_id_list,
350351
start_dt, end_dt, minimal_response,
351352
significant_changes_only, limit, offset,
352353
_DEFAULT_HISTORY_LIMIT, _MAX_HISTORY_LIMIT,
@@ -357,16 +358,12 @@ async def ha_get_history(
357358
total=3,
358359
message="recorder query complete",
359360
)
360-
if fields is not None:
361-
parsed = parse_string_list_param(fields, "fields", allow_csv=True) or []
362-
keep = set(parsed) | {"success"}
363-
inner = result.get("data", result)
364-
if isinstance(inner, dict):
365-
result = {
366-
**result,
367-
"data": cast(dict[str, Any], {k: v for k, v in inner.items() if k in keep}),
368-
}
369-
return result
361+
# Project BEFORE wrapping so the helper applies at the same shape
362+
# as every other tool (raw response dict). add_timezone_metadata
363+
# wraps the result in {"data": ..., "metadata": ...} which would
364+
# otherwise force a bespoke unwrap-project-rewrap site.
365+
projected = project_fields(inner, fields)
366+
return await add_timezone_metadata(self._client, projected)
370367
finally:
371368
if ws_client:
372369
await ws_client.disconnect()
@@ -464,7 +461,6 @@ def _parse_time_range(
464461

465462
async def _fetch_history(
466463
ws_client: Any,
467-
client: Any,
468464
entity_id_list: list[str],
469465
start_dt: datetime,
470466
end_dt: datetime,
@@ -475,7 +471,11 @@ async def _fetch_history(
475471
default_limit: int,
476472
max_limit: int,
477473
) -> dict[str, Any]:
478-
"""Execute the history/history_during_period WebSocket call."""
474+
"""Execute the history/history_during_period WebSocket call.
475+
476+
Returns the unwrapped history dict; the caller is responsible for projection
477+
and wrapping with ``add_timezone_metadata``.
478+
"""
479479
try:
480480
effective_limit = coerce_int_param(
481481
limit,
@@ -588,12 +588,11 @@ async def _fetch_history(
588588
},
589589
}
590590

591-
return await add_timezone_metadata(client, history_data)
591+
return history_data
592592

593593

594594
async def _fetch_statistics(
595595
ws_client: Any,
596-
client: Any,
597596
entity_id_list: list[str],
598597
start_dt: datetime,
599598
end_dt: datetime,
@@ -602,7 +601,11 @@ async def _fetch_statistics(
602601
limit: int | str | None,
603602
offset: int | str | None,
604603
) -> dict[str, Any]:
605-
"""Execute the recorder/statistics_during_period WebSocket call."""
604+
"""Execute the recorder/statistics_during_period WebSocket call.
605+
606+
Returns the unwrapped statistics dict; the caller is responsible for projection
607+
and wrapping with ``add_timezone_metadata``.
608+
"""
606609
try:
607610
effective_limit = coerce_int_param(
608611
limit,
@@ -762,4 +765,4 @@ async def _fetch_statistics(
762765
"These entities may not have state_class attribute or may not have recorded data yet."
763766
]
764767

765-
return await add_timezone_metadata(client, statistics_data)
768+
return statistics_data

src/ha_mcp/tools/tools_search.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -153,27 +153,30 @@ async def _exact_match_search(
153153

154154
def _project_entity(
155155
record: dict[str, Any],
156-
fields: str | list[str] | None,
157-
attribute_keys: str | list[str] | None,
156+
fields: list[str] | None,
157+
attribute_keys: list[str] | None,
158158
) -> dict[str, Any]:
159159
"""Apply optional field projection to a HA entity record.
160160
161161
``fields`` filters which top-level keys to keep (e.g. ["state", "attributes"]).
162162
``attribute_keys`` further filters the ``attributes`` sub-dict.
163163
Both default None = full payload (no-op).
164-
Accepts a list or a CSV/JSON-array string for both parameters.
164+
165+
Both parameters are already parsed into ``list[str] | None`` — string/CSV inputs
166+
must be normalised at the call site via ``parse_string_list_param`` (see
167+
``ha_get_state`` which parses once before the bulk loop to avoid re-parsing per
168+
entity record).
165169
"""
166170
if not isinstance(record, dict):
167171
return record # non-dict (e.g. error path returning None) — skip projection
168172
if fields is not None:
169-
parsed_fields = parse_string_list_param(fields, "fields", allow_csv=True) or []
170-
keep = set(parsed_fields)
173+
keep = set(fields)
171174
record = {k: v for k, v in record.items() if k in keep}
172175
if attribute_keys is not None:
173-
parsed_attr_keys = parse_string_list_param(attribute_keys, "attribute_keys", allow_csv=True) or []
174176
attrs = record.get("attributes")
175177
if isinstance(attrs, dict):
176-
record = {**record, "attributes": {k: v for k, v in attrs.items() if k in parsed_attr_keys}}
178+
attr_keep = set(attribute_keys)
179+
record = {**record, "attributes": {k: v for k, v in attrs.items() if k in attr_keep}}
177180
return record
178181

179182

@@ -920,8 +923,9 @@ async def ha_get_overview(
920923
Standard/full modes paginate entities (default 200 per page) — use offset
921924
to fetch more. Use 'domains' filter to narrow scope.
922925
923-
Use fields= to project the response to only the keys you need — up to 94%
924-
token reduction when fetching a single sub-section (e.g. fields=["system_info"]).
926+
Use fields= to project the response to only the keys you need — a
927+
significantly smaller payload when fetching a single sub-section (e.g.
928+
fields=["system_info"] returns just that section instead of the full overview).
925929
"""
926930
# Validate fields= early so a malformed value returns VALIDATION_INVALID_PARAMETER
927931
# (ha_get_overview has no outer try/except, so ValueError would escape uncaught)
@@ -1250,17 +1254,61 @@ async def ha_get_state(
12501254
Returns success=True if at least one entity state was retrieved.
12511255
Check 'error_count' for any failed lookups in partial-success scenarios.
12521256
1257+
FIELDS PROJECTION:
1258+
`fields=` projects the per-entity record (`entity_id`, `state`, `attributes`,
1259+
`last_changed`, `last_updated`, `context`), NOT the outer bulk response wrapper.
1260+
In single-entity mode it filters keys of the returned record directly. In bulk
1261+
mode it filters keys of each record inside `states[entity_id]`; outer keys
1262+
(`success`, `count`, `states`, `errors`, ...) are always preserved.
1263+
`attribute_keys=` further narrows the `attributes` sub-dict and is only applied
1264+
when `"attributes"` is in `fields=` (or `fields=None`); otherwise it is a no-op.
1265+
12531266
EXAMPLES:
12541267
- Single: ha_get_state("light.kitchen")
12551268
- Multiple: ha_get_state(["light.kitchen", "light.living_room", "sensor.temperature"])
12561269
- State only: ha_get_state("light.kitchen", fields=["state"])
12571270
- Slim bulk: ha_get_state(["light.kitchen", "sensor.temperature"], fields=["state", "attributes"], attribute_keys=["brightness"])
12581271
"""
1272+
# Parse projection params once up front so the bulk loop doesn't re-parse
1273+
# the same string/CSV input per entity (100 entities → 200 parses pre-fix).
1274+
# parse_string_list_param raises ValueError on bad input; surface as
1275+
# VALIDATION_INVALID_PARAMETER via the normal ToolError flow.
1276+
try:
1277+
parsed_fields = parse_string_list_param(fields, "fields", allow_csv=True)
1278+
parsed_attribute_keys = parse_string_list_param(
1279+
attribute_keys, "attribute_keys", allow_csv=True
1280+
)
1281+
except ValueError as e:
1282+
raise_tool_error(
1283+
create_validation_error(
1284+
str(e),
1285+
parameter=(
1286+
"attribute_keys" if "attribute_keys" in str(e) else "fields"
1287+
),
1288+
)
1289+
)
1290+
1291+
# `attribute_keys` only takes effect when `attributes` is in the projected
1292+
# field set (or `fields=None`). Surface a warning rather than silently
1293+
# ignoring it — caller likely intended to slim attributes and would
1294+
# otherwise see the `attributes` key missing entirely, with no signal.
1295+
attribute_keys_no_effect = (
1296+
parsed_attribute_keys is not None
1297+
and parsed_fields is not None
1298+
and "attributes" not in parsed_fields
1299+
)
1300+
12591301
# Single entity path
12601302
if isinstance(entity_id, str):
12611303
try:
12621304
result = await client.get_entity_state(entity_id)
1263-
result = _project_entity(result, fields, attribute_keys)
1305+
result = _project_entity(result, parsed_fields, parsed_attribute_keys)
1306+
if attribute_keys_no_effect and isinstance(result, dict):
1307+
result["warning"] = (
1308+
"attribute_keys was ignored because 'attributes' is not in "
1309+
"fields=. Add 'attributes' to fields= (or omit fields=) to "
1310+
"apply attribute_keys."
1311+
)
12641312
return await add_timezone_metadata(client, result)
12651313
except ToolError:
12661314
raise
@@ -1332,7 +1380,9 @@ async def _fetch_state(eid: str) -> dict[str, Any]:
13321380

13331381
for eid, result in zip(unique_ids, results, strict=True):
13341382
if result.get("success") is True and "state" in result:
1335-
states[eid] = _project_entity(result["state"], fields, attribute_keys)
1383+
states[eid] = _project_entity(
1384+
result["state"], parsed_fields, parsed_attribute_keys
1385+
)
13361386
else:
13371387
error_detail = result.get("error")
13381388
if error_detail is None:
@@ -1353,6 +1403,13 @@ async def _fetch_state(eid: str) -> dict[str, Any]:
13531403
"states": states,
13541404
}
13551405

1406+
if attribute_keys_no_effect:
1407+
response["warning"] = (
1408+
"attribute_keys was ignored because 'attributes' is not in "
1409+
"fields=. Add 'attributes' to fields= (or omit fields=) to "
1410+
"apply attribute_keys."
1411+
)
1412+
13561413
if errors:
13571414
response["errors"] = errors
13581415
response["error_count"] = len(errors)

tests/src/unit/test_context_injection.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,10 @@ async def test_ha_get_history_works_without_ctx() -> None:
140140
):
141141
result = await history_tool(entity_ids="sensor.test")
142142

143-
assert result is fake_result
143+
# ha_get_history wraps the inner _fetch_history result via add_timezone_metadata
144+
# — the inner payload must round-trip unchanged under fields=None.
145+
assert result["data"] == fake_result
146+
assert "metadata" in result
144147
fake_ws.disconnect.assert_awaited_once()
145148

146149

@@ -167,7 +170,8 @@ async def test_ha_get_history_emits_progress_with_ctx() -> None:
167170
):
168171
result = await history_tool(entity_ids="sensor.test", ctx=ctx)
169172

170-
assert result is fake_result
173+
assert result["data"] == fake_result
174+
assert "metadata" in result
171175
ctx.info.assert_awaited()
172176
# Three events: connect, query dispatch, completion (progress jumps 1 -> 3).
173177
assert ctx.report_progress.await_count == 3
@@ -526,7 +530,8 @@ async def test_ha_get_history_statistics_emits_progress() -> None:
526530
entity_ids="sensor.test", source="statistics", period="day", ctx=ctx
527531
)
528532

529-
assert result is fake_result
533+
assert result["data"] == fake_result
534+
assert "metadata" in result
530535
assert ctx.report_progress.await_count == 3
531536
messages = _progress_messages(ctx)
532537
assert "querying recorder (statistics)" in messages[1]

tests/src/unit/test_get_state_fields_projection.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,8 @@ def test_does_not_mutate_original(self):
6666
assert "last_changed" in original
6767
assert original["attributes"]["color_temp"] == 3500
6868

69-
def test_fields_csv_string_input(self):
70-
result = _project_entity(dict(_ENTITY_RECORD), "state,entity_id", None)
71-
assert set(result.keys()) == {"state", "entity_id"}
72-
73-
def test_fields_json_array_string_input(self):
74-
result = _project_entity(dict(_ENTITY_RECORD), '["state", "attributes"]', ["brightness"])
75-
assert set(result.keys()) == {"state", "attributes"}
76-
assert result["attributes"] == {"brightness": 200}
77-
78-
def test_attribute_keys_csv_string_input(self):
79-
result = _project_entity(dict(_ENTITY_RECORD), None, "brightness,color_temp")
80-
assert set(result["attributes"].keys()) == {"brightness", "color_temp"}
69+
def test_non_dict_record_returned_unchanged(self):
70+
# Defensive: error paths may pass None/non-dict; helper must not raise.
71+
assert _project_entity(None, ["state"], None) is None # type: ignore[arg-type]
8172

8273

tests/src/unit/test_overview_system_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class TestHaGetOverviewFieldsProjection:
101101
102102
Pins the contract from issue #1199: callers that only need one section
103103
(e.g. system_info) can request it via fields= and receive a response
104-
that omits all other top-level keys, reducing token usage by up to 94%.
104+
that omits all other top-level keys.
105105
"""
106106

107107
@pytest.fixture

0 commit comments

Comments
 (0)