Skip to content

Commit eb3edb6

Browse files
authored
Merge pull request #170 from lsst-dm/tickets/DM-48808/typing_fixes
DM-48808 : Typing Fixes
2 parents 71b4107 + 00a0436 commit eb3edb6

File tree

13 files changed

+87
-57
lines changed

13 files changed

+87
-57
lines changed

src/lsst/cmservice/common/panda.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def refresh_panda_token(url: str, data: dict[str, str]) -> str | None:
7171
os.environ["PANDA_AUTH_ID_TOKEN"] = config.panda.id_token
7272
# - update token expiry
7373
decoded_token = decode_id_token(config.panda.id_token)
74-
config.panda.token_expiry = float(decoded_token["exp"]) # type: ignore
74+
config.panda.token_expiry = float(decoded_token["exp"]) # type: ignore[assignment]
7575
if TYPE_CHECKING:
7676
# the validation machinery of the pyantic field handles conversion
7777
# from float to datetime.
@@ -126,7 +126,7 @@ def get_panda_token() -> str | None:
126126
try:
127127
if config.panda.token_expiry is None:
128128
decoded_token = decode_id_token(config.panda.id_token)
129-
config.panda.token_expiry = float(decoded_token["exp"]) # type: ignore
129+
config.panda.token_expiry = float(decoded_token["exp"]) # type: ignore[assignment]
130130
if TYPE_CHECKING:
131131
# the validation machinery of the pyantic field handles conversion
132132
# from float to datetime.

src/lsst/cmservice/db/row.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,24 @@ async def get_rows(
8383
parent_class = kwargs.get("parent_class")
8484

8585
q = select(cls)
86-
# FIXME: Being a mixin leads to loose typing here.
87-
# Is there a better way?
8886
if hasattr(cls, "parent_id"):
87+
# FIXME All of these tests assert that parent_class is not None
88+
# otherwise it would raise an AttributeError; the constraint
89+
# might as well be satisfied by the first id matching without
90+
# also evaluating additional options. If the gimmick is that
91+
# the method can be invoked with one of parent_id, _name, or
92+
# the parent class, it is invalidated by the fact that
93+
# parent_class is used in all three cases, so defining any of
94+
# the others adds no value.
95+
if TYPE_CHECKING:
96+
assert parent_class is not None
8997
if parent_class is not None:
90-
q = q.where(parent_class.id == cls.parent_id) # type: ignore
98+
parent_id_ = getattr(cls, "parent_id")
99+
q = q.where(parent_class.id == parent_id_)
91100
if parent_name is not None:
92-
q = q.where(parent_class.fullname == parent_name) # type: ignore
101+
q = q.where(parent_class.fullname == parent_name)
93102
if parent_id is not None:
94-
q = q.where(parent_class.id == parent_id) # type: ignore
103+
q = q.where(parent_class.id == parent_id)
95104
q = q.offset(skip).limit(limit)
96105
results = await session.scalars(q)
97106
return results.all()

src/lsst/cmservice/handlers/jobs.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ async def _write_script(
6868
parent: ElementMixin,
6969
**kwargs: Any,
7070
) -> StatusEnum:
71+
if TYPE_CHECKING:
72+
assert isinstance(parent, Job)
7173
# Database operations
7274
await session.refresh(parent, attribute_names=["c_"])
7375
data_dict = await script.data_dict(session)
@@ -90,7 +92,7 @@ async def _write_script(
9092
# yaml template, NOT the yaml template itself!
9193
workflow_config: dict[str, Any] = {}
9294
workflow_config["project"] = "DEFAULT"
93-
workflow_config["campaign"] = parent.c_.name # type: ignore
95+
workflow_config["campaign"] = parent.c_.name
9496
workflow_config["pipeline_yaml"] = pipeline_yaml
9597
workflow_config["lsst_version"] = lsst_version
9698
workflow_config["lsst_distrib_dir"] = lsst_distrib_dir
@@ -170,7 +172,7 @@ async def _write_script(
170172
in_collection = input_colls
171173

172174
payload = {
173-
"name": parent.c_.name, # type: ignore
175+
"name": parent.c_.name,
174176
"butler_config": butler_repo,
175177
"output_run_collection": run_coll,
176178
"input_collection": in_collection,
@@ -216,7 +218,9 @@ async def _check_slurm_job(
216218
if fake_status is not None:
217219
wms_job_id = "fake_job"
218220
else: # pragma: no cover
219-
bps_dict = await parse_bps_stdout(script.log_url) # type: ignore
221+
if TYPE_CHECKING:
222+
assert script.log_url is not None
223+
bps_dict = await parse_bps_stdout(script.log_url)
220224
wms_job_id = self.get_job_id(bps_dict)
221225
await parent.update_values(session, wms_job_id=wms_job_id)
222226
return slurm_status
@@ -243,7 +247,9 @@ async def _check_htcondor_job(
243247
if fake_status is not None:
244248
wms_job_id = "fake_job"
245249
else: # pragma: no cover
246-
bps_dict = await parse_bps_stdout(script.log_url) # type: ignore
250+
if TYPE_CHECKING:
251+
assert script.log_url is not None
252+
bps_dict = await parse_bps_stdout(script.log_url)
247253
wms_job_id = self.get_job_id(bps_dict)
248254
await parent.update_values(session, wms_job_id=wms_job_id)
249255
return htcondor_status

src/lsst/cmservice/handlers/scripts.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
import textwrap
5-
from typing import Any
5+
from typing import TYPE_CHECKING, Any
66

77
from anyio import Path
88
from sqlalchemy.ext.asyncio import async_scoped_session
@@ -470,6 +470,8 @@ async def _write_script(
470470
**kwargs: Any,
471471
) -> StatusEnum:
472472
test_type_and_raise(parent, Step, "PrepareStepScriptHandler._write_script parent")
473+
if TYPE_CHECKING:
474+
assert isinstance(parent, Step)
473475

474476
resolved_cols = await script.resolve_collections(session)
475477
data_dict = await script.data_dict(session)
@@ -482,7 +484,7 @@ async def _write_script(
482484

483485
prereq_colls: list[str] = []
484486

485-
all_prereqs = await parent.get_all_prereqs(session) # type: ignore
487+
all_prereqs = await parent.get_all_prereqs(session)
486488
for prereq_step in all_prereqs:
487489
prereq_step_colls = await prereq_step.resolve_collections(session)
488490
prereq_colls.append(prereq_step_colls["step_public_output"])

src/lsst/cmservice/web_app/pages/campaigns.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Sequence
2+
from typing import TYPE_CHECKING
23

34
from sqlalchemy import select
45
from sqlalchemy.ext.asyncio import async_scoped_session
@@ -9,6 +10,8 @@
910

1011

1112
async def get_campaign_details(session: async_scoped_session, campaign: Campaign) -> dict:
13+
if TYPE_CHECKING:
14+
assert isinstance(campaign.data, dict)
1215
collections = await campaign.resolve_collections(session, throw_overrides=False)
1316
groups = await get_campaign_groups(session, campaign)
1417
no_groups_completed = len([group for group in groups if group.status == StatusEnum.accepted])
@@ -26,7 +29,7 @@ async def get_campaign_details(session: async_scoped_session, campaign: Campaign
2629
"name": campaign.name,
2730
"production_name": campaign.fullname.split("/")[0],
2831
"fullname": campaign.fullname,
29-
"lsst_version": campaign.data["lsst_version"], # type: ignore
32+
"lsst_version": campaign.data["lsst_version"],
3033
"source": collections.get("campaign_source", ""),
3134
"status": map_status(campaign.status),
3235
"groups_completed": f"{no_groups_completed} of {len(groups)} groups completed",

tests/cli/test_others.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,21 @@ async def test_others_cli(uvicorn: UvicornProcess, api_version: str) -> None:
2222
assert len(check) == 0
2323

2424
result = runner.invoke(client_top, "product_set list --output yaml")
25-
check = check_and_parse_result(result, list[models.ProductSet]) # type: ignore
26-
assert len(check) == 0
25+
products = check_and_parse_result(result, list[models.ProductSet])
26+
assert len(products) == 0
2727

2828
result = runner.invoke(client_top, "script_dependency list --output yaml")
29-
check = check_and_parse_result(result, list[models.Dependency]) # type: ignore
30-
assert len(check) == 0
29+
dependencies = check_and_parse_result(result, list[models.Dependency])
30+
assert len(dependencies) == 0
3131

3232
result = runner.invoke(client_top, "script_error list --output yaml")
33-
check = check_and_parse_result(result, list[models.ScriptError]) # type: ignore
34-
assert len(check) == 0
33+
scripts = check_and_parse_result(result, list[models.ScriptError])
34+
assert len(scripts) == 0
3535

3636
result = runner.invoke(client_top, "task_set list --output yaml")
37-
check = check_and_parse_result(result, list[models.TaskSet]) # type: ignore
38-
assert len(check) == 0
37+
tasks = check_and_parse_result(result, list[models.TaskSet])
38+
assert len(tasks) == 0
3939

4040
result = runner.invoke(client_top, "wms_task_report list --output yaml")
41-
check = check_and_parse_result(result, list[models.WmsTaskReport]) # type: ignore
42-
assert len(check) == 0
41+
reports = check_and_parse_result(result, list[models.WmsTaskReport])
42+
assert len(reports) == 0

tests/cli/util_functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from lsst.cmservice.common.enums import LevelEnum, StatusEnum
1111

1212
T = TypeVar("T")
13+
E = TypeVar("E", models.Group, models.Campaign, models.Step, models.Job)
1314

1415

1516
def check_and_parse_result(
@@ -412,15 +413,15 @@ def check_scripts(
412413
def check_get_methods(
413414
runner: CliRunner,
414415
client_top: BaseCommand,
415-
entry: models.ElementMixin,
416+
entry: E,
416417
entry_class_name: str,
417418
entry_class: TypeAlias = models.ElementMixin,
418419
) -> None:
419420
result = runner.invoke(client_top, f"{entry_class_name} get all --output yaml --row_id {entry.id}")
420421
check_get = check_and_parse_result(result, entry_class)
421422

422423
assert check_get.id == entry.id, "pulled row should be identical"
423-
assert check_get.level == entry.level, "pulled row db_id should be identical" # type: ignore
424+
assert check_get.level == entry.level, "pulled row db_id should be identical"
424425

425426
result = runner.invoke(client_top, f"{entry_class_name} get by_name --output yaml --name {entry.name}")
426427
check_get = check_and_parse_result(result, entry_class)

tests/common/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_config_datetime() -> None:
7474
assert config.panda.token_expiry is None
7575

7676
# test validation and coercion on assignment
77-
config.panda.token_expiry = 1740147265 # type: ignore
77+
config.panda.token_expiry = 1740147265 # type: ignore[assignment]
7878
assert isinstance(config.panda.token_expiry, datetime)
7979
assert config.panda.token_expiry.tzinfo is UTC
8080

tests/db/test_group.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ async def test_group_db(engine: AsyncEngine) -> None:
6666
await db.Group.delete_row(session, -99)
6767

6868
# run group specific method tests
69-
check = await entry.get_campaign(session)
70-
assert check.name == f"camp0_{uuid_int}", "should return same name as camp0"
69+
campaign = await entry.get_campaign(session)
70+
assert campaign.name == f"camp0_{uuid_int}", "should return same name as camp0"
7171

72-
check = await entry.children(session) # type: ignore
73-
assert len(list(check)) == 1, "length of children should be 1" # type: ignore
72+
children = await entry.children(session)
73+
assert len(list(children)) == 1, "length of children should be 1"
7474

7575
# check update methods
7676
await check_update_methods(session, entry, db.Group)

tests/db/test_job.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import importlib
22
import os
33
import uuid
4+
from typing import TYPE_CHECKING
45

56
import pytest
67
import structlog
@@ -62,6 +63,8 @@ async def test_job_db(engine: AsyncEngine) -> None:
6263
entry = check_getall[0] # defining single unit for later
6364

6465
parent = await entry.get_parent(session)
66+
if TYPE_CHECKING:
67+
assert isinstance(parent, db.Group)
6568

6669
await db.Job.update_row(session, entry.id, status=StatusEnum.running)
6770
sleep_time = await parent.estimate_sleep_time(session)
@@ -77,11 +80,11 @@ async def test_job_db(engine: AsyncEngine) -> None:
7780
campaign = await entry.get_campaign(session)
7881
assert campaign.name == f"camp0_{uuid_int}", "should return same name as camp0"
7982

80-
check = await entry.get_siblings(session)
81-
assert len(list(check)) == 0, "length of siblings should be 0"
83+
siblings = await entry.get_siblings(session)
84+
assert len(list(siblings)) == 0, "length of siblings should be 0"
8285

83-
check = await entry.get_errors(session) # type: ignore
84-
assert len(check) == 0, "length of errors should be 0"
86+
errors_ = await entry.get_errors(session)
87+
assert len(errors_) == 0, "length of errors should be 0"
8588

8689
sleep_time = await campaign.estimate_sleep_time(session)
8790
assert sleep_time == 10, "Wrong sleep time"
@@ -94,42 +97,42 @@ async def test_job_db(engine: AsyncEngine) -> None:
9497

9598
# check on the rescue job
9699
with pytest.raises(errors.CMTooFewAcceptedJobsError):
97-
await parent.rescue_job(session) # type: ignore
100+
await parent.rescue_job(session)
98101

99102
await db.Job.update_row(session, entry.id, status=StatusEnum.rescuable)
100-
job2 = await parent.rescue_job(session) # type: ignore
103+
job2 = await parent.rescue_job(session)
101104

102105
with pytest.raises(errors.CMBadStateTransitionError):
103-
await parent.mark_job_rescued(session) # type: ignore
106+
await parent.mark_job_rescued(session)
104107

105108
await db.Job.update_row(session, entry.id, status=StatusEnum.rescuable)
106109
with pytest.raises(errors.CMBadStateTransitionError):
107-
await parent.mark_job_rescued(session) # type: ignore
110+
await parent.mark_job_rescued(session)
108111

109112
await db.Job.update_row(session, entry.id, status=StatusEnum.rescuable)
110113
await db.Job.update_row(session, job2.id, status=StatusEnum.accepted)
111114

112-
rescued = await parent.mark_job_rescued(session) # type: ignore
115+
rescued = await parent.mark_job_rescued(session)
113116
assert len(rescued) == 1, "Wrong number of rescued jobs"
114117

115118
await db.Job.update_row(session, entry.id, status=StatusEnum.accepted)
116119
await db.Job.update_row(session, job2.id, status=StatusEnum.accepted)
117120
with pytest.raises(errors.CMTooManyActiveScriptsError):
118-
await parent.mark_job_rescued(session) # type: ignore
121+
await parent.mark_job_rescued(session)
119122

120123
await db.Job.update_row(session, entry.id, status=StatusEnum.rescuable)
121124
await db.Job.update_row(session, job2.id, status=StatusEnum.rescuable)
122125

123-
job3 = await parent.rescue_job(session) # type: ignore
126+
job3 = await parent.rescue_job(session)
124127

125128
await db.Job.update_row(session, entry.id, status=StatusEnum.rescued)
126129
await db.Job.update_row(session, job2.id, status=StatusEnum.failed, superseded=True)
127130
await db.Job.update_row(session, job3.id, status=StatusEnum.rescuable)
128131

129132
with pytest.raises(errors.CMTooFewAcceptedJobsError):
130-
await parent.mark_job_rescued(session) # type: ignore
133+
await parent.mark_job_rescued(session)
131134

132-
job4 = await parent.rescue_job(session) # type: ignore
135+
job4 = await parent.rescue_job(session)
133136
await db.Job.update_row(session, job4.id, status=StatusEnum.accepted)
134137

135138
rescued = await interface.mark_job_rescued(session, parent.fullname)

0 commit comments

Comments
 (0)