Skip to content

Commit 5216328

Browse files
authored
Merge pull request #33 from AllenInstitute/feature/upgrade-mypy-configs
upgrade mypy and change 'no_site_packages = false' in config
2 parents 27de389 + 6c5660b commit 5216328

File tree

21 files changed

+361
-202
lines changed

21 files changed

+361
-202
lines changed

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ requires-python = ">=3.9"
2121
dependencies = [
2222
"boto3~=1.35",
2323
"aibs-informatics-core~=0.1",
24+
"typing-extensions~=4.15; python_version < '3.11'",
2425
]
2526

2627
# For dev dependencies: https://peps.python.org/pep-0735/
@@ -32,8 +33,10 @@ dev = [
3233
]
3334
lint = [
3435
"boto3-stubs[athena,apigateway,batch,ecr,ecs,efs,essential,fsx,logs,secretsmanager,ses,sns,ssm,sts,stepfunctions]",
35-
"mypy~=1.13.0",
36+
"mypy~=1.18.0",
3637
"ruff~=0.9",
38+
"types-pytz",
39+
"types-requests",
3740
]
3841
release = [
3942
"build",
@@ -172,7 +175,7 @@ incremental = false
172175
# https://mypy.readthedocs.io/en/stable/config_file.html#import-discovery
173176
ignore_missing_imports = true
174177
follow_imports = "silent"
175-
no_site_packages = true
178+
no_site_packages = false
176179

177180
# Untyped definitions and calls
178181
# https://mypy.readthedocs.io/en/stable/config_file.html#untyped-definitions-and-calls

src/aibs_informatics_aws_utils/athena.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,34 @@
11
import logging
2+
import sys
23
import time
34
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple
45

6+
if sys.version_info >= (3, 11):
7+
# For Python 3.11+
8+
from typing import Unpack
9+
else: # pragma: no cover
10+
# For Python < 3.11
11+
from typing_extensions import Unpack
12+
13+
514
from botocore.exceptions import ClientError
615

716
from aibs_informatics_aws_utils.core import AWSService
817
from aibs_informatics_aws_utils.exceptions import AWSError
918

1019
if TYPE_CHECKING: # pragma: no cover
1120
from mypy_boto3_athena.type_defs import (
12-
GetQueryExecutionInputRequestTypeDef,
1321
GetQueryExecutionOutputTypeDef,
1422
QueryExecutionStatusTypeDef,
1523
QueryExecutionTypeDef,
16-
StartQueryExecutionInputRequestTypeDef,
24+
StartQueryExecutionInputTypeDef,
1725
StartQueryExecutionOutputTypeDef,
1826
)
1927
else:
20-
GetQueryExecutionInputRequestTypeDef = dict
2128
GetQueryExecutionOutputTypeDef = dict
2229
QueryExecutionStatusTypeDef = dict
2330
QueryExecutionTypeDef = dict
24-
StartQueryExecutionInputRequestTypeDef = dict
31+
StartQueryExecutionInputTypeDef = dict
2532
StartQueryExecutionOutputTypeDef = dict
2633

2734

@@ -36,11 +43,11 @@ def start_query_execution(
3643
query_string: str,
3744
work_group: Optional[str] = None,
3845
execution_parameters: Optional[List[str]] = None,
39-
**kwargs,
46+
**kwargs: Unpack[StartQueryExecutionInputTypeDef],
4047
) -> StartQueryExecutionOutputTypeDef:
4148
athena = get_athena_client()
4249

43-
request = StartQueryExecutionInputRequestTypeDef(QueryString=query_string)
50+
request = StartQueryExecutionInputTypeDef(QueryString=query_string)
4451
if work_group:
4552
request["WorkGroup"] = work_group
4653
if execution_parameters:
@@ -73,8 +80,8 @@ def query_waiter(
7380
logger.info(f"Query Execution Status: {stats}")
7481
status = stats["QueryExecution"].get("Status", {})
7582
state = status.get("State")
76-
if state and state in ["SUCCEEDED", "FAILED", "CANCELLED"]:
77-
return state, status
83+
if state in ["SUCCEEDED", "FAILED", "CANCELLED", "TIMEOUT"]:
84+
return state, status # type: ignore[return-value]
7885
time.sleep(0.2) # 200ms
7986
# Exit if the time waiting exceed the timeout seconds
8087
if time.time() > start + timeout:

src/aibs_informatics_aws_utils/batch.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from dataclasses import dataclass, field
2-
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union
2+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Union
33

44
from aibs_informatics_core.env import ENV_BASE_KEY_ALIAS, EnvBase, get_env_base
55
from aibs_informatics_core.models.aws.batch import JobName, ResourceRequirements
66
from aibs_informatics_core.utils.decorators import retry
77
from aibs_informatics_core.utils.hashing import sha256_hexdigest
88
from aibs_informatics_core.utils.logging import get_logger
9-
from aibs_informatics_core.utils.tools.dicttools import convert_key_case, remove_null_values
9+
from aibs_informatics_core.utils.tools.dicttools import convert_key_case
1010
from aibs_informatics_core.utils.tools.strtools import pascalcase
1111
from botocore.exceptions import ClientError
1212

@@ -24,8 +24,8 @@
2424
HostTypeDef,
2525
JobDefinitionTypeDef,
2626
KeyValuePairTypeDef,
27+
LinuxParametersTypeDef,
2728
MountPointTypeDef,
28-
RegisterJobDefinitionRequestRequestTypeDef,
2929
RegisterJobDefinitionResponseTypeDef,
3030
ResourceRequirementTypeDef,
3131
RetryStrategyTypeDef,
@@ -42,8 +42,8 @@
4242
JobDefinitionTypeDef = dict
4343
DescribeJobsResponseTypeDef = dict
4444
KeyValuePairTypeDef = dict
45+
LinuxParametersTypeDef = dict
4546
MountPointTypeDef = dict
46-
RegisterJobDefinitionRequestRequestTypeDef = dict
4747
RegisterJobDefinitionResponseTypeDef = dict
4848
ResourceRequirementTypeDef = dict
4949
RetryStrategyTypeDef = dict
@@ -61,27 +61,27 @@ def to_volume(
6161
name: Optional[str],
6262
efs_volume_configuration: Optional[EFSVolumeConfigurationTypeDef],
6363
) -> VolumeTypeDef:
64-
return remove_null_values(
65-
VolumeTypeDef(
66-
host=HostTypeDef(sourcePath=source_path) if source_path else None,
67-
name=name,
68-
efsVolumeConfiguration=efs_volume_configuration,
69-
)
70-
)
64+
volume_dict = VolumeTypeDef()
65+
if source_path:
66+
volume_dict["host"] = HostTypeDef(sourcePath=source_path)
67+
if name:
68+
volume_dict["name"] = name
69+
if efs_volume_configuration:
70+
volume_dict["efsVolumeConfiguration"] = efs_volume_configuration
71+
return volume_dict
7172

7273

7374
def to_mount_point(
7475
container_path: Optional[str],
7576
read_only: bool,
7677
source_volume: Optional[str],
7778
) -> MountPointTypeDef:
78-
return remove_null_values(
79-
MountPointTypeDef(
80-
containerPath=container_path,
81-
readOnly=read_only,
82-
sourceVolume=source_volume,
83-
)
84-
)
79+
mount_point_dict = MountPointTypeDef(readOnly=read_only)
80+
if container_path:
81+
mount_point_dict["containerPath"] = container_path
82+
if source_volume:
83+
mount_point_dict["sourceVolume"] = source_volume
84+
return mount_point_dict
8585

8686

8787
def to_key_value_pairs(
@@ -104,7 +104,7 @@ def to_key_value_pairs(
104104
for k, v in environment.items()
105105
if not remove_null_values or v is not None
106106
],
107-
key=lambda _: _.get("name"),
107+
key=lambda _: _.get("name", ""),
108108
)
109109

110110

@@ -115,16 +115,23 @@ def to_resource_requirements(
115115
) -> List[ResourceRequirementTypeDef]:
116116
"""Converts Batch resource requirement parameters into a list of ResourceRequirement objects
117117
118+
The returned list only includes dictionary entries for resources that specify
119+
an explicit value. Anything unset will be dropped.
120+
118121
Args:
119-
gpu (Optional[int], optional): number of . Defaults to None.
120-
memory (Optional[int], optional): _description_. Defaults to None.
121-
vcpus (Optional[int], optional): _description_. Defaults to None.
122+
gpu (Optional[int], optional): number of GPUs to use. Defaults to None.
123+
memory (Optional[int], optional): amount of memory in MiB. Defaults to None.
124+
vcpus (Optional[int], optional): Number of VCPUs to use. Defaults to None.
122125
123126
Returns:
124127
List[ResourceRequirementTypeDef]: list of resource requirements
125128
"""
126129

127-
pairs = [("GPU", gpu), ("MEMORY", memory), ("VCPU", vcpus)]
130+
pairs: list[tuple[Literal["GPU", "MEMORY", "VCPU"], Optional[int]]] = [
131+
("GPU", gpu),
132+
("MEMORY", memory),
133+
("VCPU", vcpus),
134+
]
128135
return [ResourceRequirementTypeDef(type=t, value=str(v)) for t, v in pairs if v is not None]
129136

130137

@@ -176,7 +183,7 @@ def register_job_definition(
176183
tags: Optional[Mapping[str, str]] = None,
177184
propagate_tags: bool = False,
178185
region: Optional[str] = None,
179-
) -> JobDefinitionTypeDef:
186+
) -> JobDefinitionTypeDef | RegisterJobDefinitionResponseTypeDef:
180187
batch = get_batch_client(region=region)
181188

182189
# First we check to make sure that we aren't crearting unnecessary revisions
@@ -210,7 +217,7 @@ def register_job_definition(
210217
logger.info(
211218
f"Registering job definition with following properties: {register_job_definition_kwargs}"
212219
)
213-
response = batch.register_job_definition(**register_job_definition_kwargs)
220+
response = batch.register_job_definition(**register_job_definition_kwargs) # type: ignore[arg-type]
214221
return response
215222

216223

@@ -263,7 +270,7 @@ class BatchJobBuilder:
263270
mount_points: List[MountPointTypeDef] = field(default_factory=list)
264271
volumes: List[VolumeTypeDef] = field(default_factory=list)
265272
privileged: bool = field(default=False)
266-
linux_parameters: Optional[Dict[str, Any]] = field(default=None)
273+
linux_parameters: Optional[LinuxParametersTypeDef] = field(default=None)
267274
env_base: EnvBase = field(default_factory=EnvBase.from_env)
268275

269276
def __post_init__(self):
@@ -295,7 +302,7 @@ def container_overrides(self) -> ContainerOverridesTypeDef:
295302

296303
@property
297304
def container_overrides__sfn(self) -> Dict[str, Any]:
298-
return convert_key_case(self.container_overrides, pascalcase)
305+
return convert_key_case(self.container_overrides, pascalcase) # type: ignore[arg-type]
299306

300307
def _normalized_resource_requirements(self) -> List[ResourceRequirementTypeDef]:
301308
if isinstance(self.resource_requirements, list):

src/aibs_informatics_aws_utils/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def get_region(region: Optional[str] = None) -> str:
145145

146146
def get_account_id() -> str:
147147
"""Will get the account id from the current credentials/identity"""
148-
return get_caller_identity().get("Account")
148+
return get_caller_identity()["Account"]
149149

150150

151151
def get_user_id() -> UserId:

src/aibs_informatics_aws_utils/data_sync/operations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def sync_s3_to_s3(
256256
result = DataSyncResult()
257257
if self.config.include_detailed_response:
258258
path_stats = get_s3_path_stats(destination_path)
259-
result.files_transferred = path_stats.object_count
259+
result.files_transferred = path_stats.object_count or 0
260260
result.bytes_transferred = path_stats.size_bytes
261261
return result
262262

src/aibs_informatics_aws_utils/dynamodb/conditions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def deserialize_condition(
255255
def _deserialize_condition(ce: ConditionBaseExpression) -> ConditionBase:
256256
ce_key = (ce.format, ce.operator)
257257
condition_base_cls = cls._CONDITION_BASE_CLASS_LOOKUP[ce_key]
258-
ce_values = []
258+
ce_values: list[AttributeBase | ConditionBase] = []
259259
for ce_value in ce.values:
260260
if isinstance(ce_value, ConditionBaseExpression):
261261
ce_values.append(_deserialize_condition(ce_value))

src/aibs_informatics_aws_utils/dynamodb/functions.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,21 @@
1414

1515
if TYPE_CHECKING: # pragma: no cover
1616
from mypy_boto3_dynamodb.type_defs import (
17-
BatchGetItemInputRequestTypeDef,
18-
GetItemInputRequestTypeDef,
17+
BatchGetItemInputTypeDef,
18+
GetItemInputTypeDef,
1919
KeysAndAttributesTypeDef,
20-
QueryInputRequestTypeDef,
21-
ScanInputRequestTypeDef,
20+
QueryInputTableQueryTypeDef,
21+
QueryInputTypeDef,
22+
ScanInputTypeDef,
2223
)
2324
else:
24-
BatchGetItemInputRequestTypeDef = dict
25+
BatchGetItemInputTypeDef = dict
2526
GetItemInputRequestTypeDef = dict
27+
GetItemInputTypeDef = dict
2628
KeysAndAttributesTypeDef = dict
27-
QueryInputRequestTypeDef = dict
28-
ScanInputRequestTypeDef = dict
29+
QueryInputTableQueryTypeDef = dict
30+
QueryInputTypeDef = dict
31+
ScanInputTypeDef = dict
2932

3033

3134
logger = get_logger(__name__)
@@ -105,7 +108,7 @@ def table_get_items(
105108
"Keys": serialized_keys,
106109
}
107110
}
108-
props: BatchGetItemInputRequestTypeDef = {
111+
props: BatchGetItemInputTypeDef = {
109112
"RequestItems": request_items,
110113
"ReturnConsumedCapacity": "NONE",
111114
}
@@ -213,7 +216,7 @@ def table_query(
213216
key_expr_component.expression_attribute_values__serialized
214217
)
215218

216-
db_request: QueryInputRequestTypeDef = {
219+
db_request: QueryInputTypeDef = {
217220
"TableName": table.name,
218221
"KeyConditionExpression": key_expr_component.condition_expression,
219222
}
@@ -287,7 +290,7 @@ def table_scan(
287290
db = get_dynamodb_client(region=region)
288291
table = table_as_resource(table_name)
289292

290-
db_request: ScanInputRequestTypeDef = {"TableName": table.name}
293+
db_request: ScanInputTypeDef = {"TableName": table.name}
291294

292295
# Handle when filter_expression is provided
293296
if filter_expression is not None:

src/aibs_informatics_aws_utils/dynamodb/table.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,12 +204,12 @@ def build_optimized_condition_expression_set(
204204
new_condition = Key(k).eq(v)
205205
if (
206206
k in candidate_conditions
207-
and candidate_conditions[k]._values[1:] != new_condition._values[1:] # type: ignore[union-attr]
207+
and candidate_conditions[k]._values[1:] != new_condition._values[1:] # type: ignore[attr-defined,union-attr]
208208
):
209209
raise DBQueryException(f"Multiple values provided for attribute {k}!")
210210
candidate_conditions[k] = Key(k).eq(v)
211-
elif len(_._values) and isinstance(_._values[0], (Key, Attr)): # type: ignore[union-attr]
212-
attr_name = cast(str, _._values[0].name) # type: ignore[union-attr]
211+
elif len(_._values) and isinstance(_._values[0], (Key, Attr)): # type: ignore[attr-defined,union-attr]
212+
attr_name = cast(str, _._values[0].name) # type: ignore[attr-defined,union-attr]
213213
if attr_name not in index_all_key_names or not isinstance(
214214
_, SupportedKeyComparisonTypes
215215
):
@@ -728,7 +728,7 @@ def delete(
728728
try:
729729
deleted_attributes = table_delete_item(
730730
table_name=self.table_name,
731-
key=key,
731+
key=cast(DynamoDBKey, key),
732732
return_values="ALL_OLD", # type: ignore[arg-type] # expected type more general than specified here
733733
)
734734

0 commit comments

Comments
 (0)