Skip to content

Commit 5943450

Browse files
committed
Improve type annotations
1 parent 75f1d2a commit 5943450

File tree

3 files changed

+57
-21
lines changed

3 files changed

+57
-21
lines changed

lib/galaxy/managers/jobs.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
cast,
1313
Optional,
1414
TYPE_CHECKING,
15+
TypeVar,
1516
Union,
1617
)
1718

@@ -117,7 +118,11 @@
117118
from galaxy.work.context import WorkRequestContext
118119

119120
if TYPE_CHECKING:
120-
from sqlalchemy.sql.expression import Select
121+
from sqlalchemy.sql.expression import (
122+
ColumnElement,
123+
Label,
124+
Select,
125+
)
121126

122127
log = logging.getLogger(__name__)
123128

@@ -133,7 +138,7 @@ class JobLock(BaseModel):
133138
active: bool = Field(title="Job lock status", description="If active, jobs will not dispatch")
134139

135140

136-
def get_path_key(path_tuple):
141+
def get_path_key(path_tuple: tuple):
137142
path_key = ""
138143
tuple_elements = len(path_tuple)
139144
for i, p in enumerate(path_tuple):
@@ -153,12 +158,15 @@ def get_path_key(path_tuple):
153158

154159

155160
def safe_label_or_none(label: str) -> Optional[str]:
156-
if label and len(label) > 63:
161+
if len(label) > 63:
157162
return None
158163
return label
159164

160165

161-
def safe_aliased(model_class, name=None):
166+
T = TypeVar("T")
167+
168+
169+
def safe_aliased(model_class: type[T], name: str) -> type[T]:
162170
"""Create an aliased model class with a unique name."""
163171
return aliased(model_class, name=safe_label_or_none(name))
164172

@@ -476,11 +484,11 @@ def by_tool_input(
476484
job_state: Optional[JobStatesT] = (Job.states.OK,),
477485
history_id: Union[int, None] = None,
478486
require_name_match: bool = True,
479-
):
487+
) -> Union[Job, None]:
480488
"""Search for jobs producing same results using the 'inputs' part of a tool POST."""
481-
input_data = defaultdict(list)
489+
input_data: dict[Any, list[dict[str, Any]]] = defaultdict(list)
482490

483-
def populate_input_data_input_id(path, key, value):
491+
def populate_input_data_input_id(path: tuple, key, value) -> tuple[Any, Any]:
484492
"""Traverses expanded incoming using remap and collects input_ids and input_data."""
485493
if key == "id":
486494
path_key = get_path_key(path[:-2])
@@ -528,13 +536,13 @@ def __search(
528536
tool_id: str,
529537
tool_version: Optional[str],
530538
user: model.User,
531-
input_data,
539+
input_data: dict[Any, list[dict[str, Any]]],
532540
job_state: Optional[JobStatesT],
533541
param_dump: ToolStateDumpedToJsonInternalT,
534542
wildcard_param_dump=None,
535543
history_id: Union[int, None] = None,
536544
require_name_match: bool = True,
537-
):
545+
) -> Union[Job, None]:
538546
search_timer = ExecutionTimer()
539547

540548
def replace_dataset_ids(path, key, value):
@@ -554,15 +562,15 @@ def replace_dataset_ids(path, key, value):
554562

555563
stmt = select(model.Job.id.label("job_id"))
556564

557-
data_conditions: list = []
565+
data_conditions: list[ColumnElement[bool]] = []
558566

559567
# We now build the stmt filters that relate to the input datasets
560568
# that this job uses. We keep track of the requested dataset id in `requested_ids`,
561569
# the type (hda, hdca or lda) in `data_types`
562570
# and the ids that have been used in the job that has already been run in `used_ids`.
563571
requested_ids = []
564572
data_types = []
565-
used_ids: list = []
573+
used_ids: list[Label[int]] = []
566574
for k, input_list in input_data.items():
567575
# k will be matched against the JobParameter.name column. This can be prefixed depending on whether
568576
# the input is in a repeat, or not (section and conditional)
@@ -751,7 +759,7 @@ def _filter_jobs(
751759

752760
return stmt
753761

754-
def _exclude_jobs_with_deleted_outputs(self, stmt):
762+
def _exclude_jobs_with_deleted_outputs(self, stmt: "Select[tuple[int]]") -> "Select":
755763
subquery_alias = stmt.subquery("filtered_jobs_subquery")
756764
outer_select_columns = [subquery_alias.c[col.name] for col in stmt.selected_columns]
757765
outer_stmt = select(*outer_select_columns).select_from(subquery_alias)
@@ -796,14 +804,14 @@ def _exclude_jobs_with_deleted_outputs(self, stmt):
796804
def _build_stmt_for_hda(
797805
self,
798806
stmt: "Select[tuple[int]]",
799-
data_conditions: list,
800-
used_ids: list,
807+
data_conditions: list["ColumnElement[bool]"],
808+
used_ids: list["Label[int]"],
801809
k,
802810
v,
803811
identifier,
804812
value_index: int,
805813
require_name_match: bool = True,
806-
):
814+
) -> "Select[tuple[int]]":
807815
a = aliased(model.JobToInputDatasetAssociation)
808816
b = aliased(model.HistoryDatasetAssociation)
809817
c = aliased(model.HistoryDatasetAssociation)
@@ -859,7 +867,15 @@ def _build_stmt_for_hda(
859867
)
860868
return stmt
861869

862-
def _build_stmt_for_ldda(self, stmt, data_conditions, used_ids, k, v, value_index):
870+
def _build_stmt_for_ldda(
871+
self,
872+
stmt: "Select[tuple[int]]",
873+
data_conditions: list["ColumnElement[bool]"],
874+
used_ids: list["Label[int]"],
875+
k,
876+
v,
877+
value_index: int,
878+
) -> "Select[tuple[int]]":
863879
a = aliased(model.JobToInputLibraryDatasetAssociation)
864880
label = safe_label_or_none(f"{k}_{value_index}")
865881
labeled_col = a.ldda_id.label(label)
@@ -876,8 +892,15 @@ def agg_expression(self, column):
876892
return func.array_agg(column, order_by=column)
877893

878894
def _build_stmt_for_hdca(
879-
self, stmt, data_conditions, used_ids, k, v, user_id, value_index, require_name_match=True
880-
):
895+
self,
896+
stmt: "Select[tuple[int]]",
897+
data_conditions: list["ColumnElement[bool]"],
898+
used_ids: list["Label[int]"],
899+
k,
900+
v,
901+
user_id: int,
902+
value_index: int,
903+
) -> "Select[tuple[int]]":
881904
# Strategy for efficiently finding equivalent HDCAs:
882905
# 1. Determine the structural depth of the target HDCA by its collection_type.
883906
# 2. For the target HDCA (identified by 'v'):
@@ -1112,7 +1135,16 @@ def _build_stmt_for_hdca(
11121135
data_conditions.append(a.name == k)
11131136
return stmt
11141137

1115-
def _build_stmt_for_dce(self, stmt, data_conditions, used_ids, k, v, user_id, value_index):
1138+
def _build_stmt_for_dce(
1139+
self,
1140+
stmt: "Select[tuple[int]]",
1141+
data_conditions: list["ColumnElement[bool]"],
1142+
used_ids: list["Label[int]"],
1143+
k,
1144+
v,
1145+
user_id: int,
1146+
value_index: int,
1147+
) -> "Select[tuple[int]]":
11161148
dce_root_target = self.sa_session.get_one(model.DatasetCollectionElement, v)
11171149

11181150
# Determine if the target DCE points to an HDA or a child collection

lib/galaxy_test/api/test_tool_execute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,7 @@ def _run_deferred(
755755

756756

757757
@requires_tool_id("cat|cat1")
758-
def test_deferred_with_cached_input(required_tool: RequiredTool, target_history: TargetHistory):
758+
def test_deferred_with_cached_input(required_tool: RequiredTool, target_history: TargetHistory) -> None:
759759
# Basic deferred dataset
760760
_run_deferred(required_tool, target_history)
761761
# Should just work because input is deferred

lib/galaxy_test/base/api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import (
44
Any,
55
Optional,
6+
TYPE_CHECKING,
67
)
78
from urllib.parse import (
89
urlencode,
@@ -30,6 +31,9 @@
3031
)
3132
from .interactor import TestCaseGalaxyInteractor as BaseInteractor
3233

34+
if TYPE_CHECKING:
35+
from requests import Response
36+
3337
CONFIG_PREFIXES = ["GALAXY_TEST_CONFIG_", "GALAXY_CONFIG_OVERRIDE_", "GALAXY_CONFIG_"]
3438
CELERY_BROKER = get_from_env("CELERY_BROKER", CONFIG_PREFIXES, "memory://")
3539
CELERY_BACKEND = get_from_env("CELERY_BACKEND", CONFIG_PREFIXES, "rpc://localhost")
@@ -198,7 +202,7 @@ def _patch(self, *args, **kwds):
198202
def _assert_status_code_is_ok(self, response):
199203
assert_status_code_is_ok(response)
200204

201-
def _assert_status_code_is(self, response, expected_status_code):
205+
def _assert_status_code_is(self, response: "Response", expected_status_code: int) -> None:
202206
assert_status_code_is(response, expected_status_code)
203207

204208
def _assert_has_keys(self, response, *keys):

0 commit comments

Comments
 (0)