Skip to content

Commit 241fc0a

Browse files
committed
fix: remove type ignore comments for precision handling in connector and accelerator_connector
1 parent 1c037ba commit 241fc0a

4 files changed

Lines changed: 6 additions & 8 deletions

File tree

src/lightning/fabric/connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def _check_and_init_precision(self) -> Precision:
469469
f" Choose a different precision among: {', '.join(mp_precision_supported)}."
470470
)
471471
if self._precision_input in ("16-true", "bf16-true"):
472-
return HalfPrecision(self._precision_input) # type: ignore
472+
return HalfPrecision(self._precision_input)
473473
if self._precision_input == "32-true":
474474
return Precision()
475475
if self._precision_input == "64-true":
@@ -493,7 +493,7 @@ def _check_and_init_precision(self) -> Precision:
493493
else "Using bfloat16 Automatic Mixed Precision (AMP)"
494494
)
495495
device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
496-
return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]
496+
return MixedPrecision(precision=self._precision_input, device=device)
497497

498498
raise RuntimeError("No precision set")
499499

src/lightning/fabric/plugins/collectives/torch_collective.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def destroy_group(cls, group: CollectibleGroup) -> None:
199199
# can be called by all processes in the default group, group will be `object()` if they are not part of the
200200
# current group
201201
if group in dist.distributed_c10d._pg_map:
202-
dist.destroy_process_group(group) # type: ignore[arg-type]
202+
dist.destroy_process_group(group)
203203

204204
@classmethod
205205
@override

src/lightning/fabric/utilities/registry.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ def _load_external_callbacks(group: str) -> list[Any]:
3636
A list of all callbacks collected from external factories.
3737
3838
"""
39-
factories = (
40-
entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type]
41-
)
39+
factories = entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {})
4240

4341
external_callbacks: list[Any] = []
4442
for factory in factories:

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def _check_and_init_precision(self) -> Precision:
465465
if isinstance(self.strategy, FSDPStrategy):
466466
return FSDPPrecision(self._precision_flag) # type: ignore[arg-type]
467467
if self._precision_flag in ("16-true", "bf16-true"):
468-
return HalfPrecision(self._precision_flag) # type: ignore
468+
return HalfPrecision(self._precision_flag)
469469
if self._precision_flag == "32-true":
470470
return Precision()
471471
if self._precision_flag == "64-true":
@@ -487,7 +487,7 @@ def _check_and_init_precision(self) -> Precision:
487487
f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
488488
)
489489
device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
490-
return MixedPrecision(self._precision_flag, device) # type: ignore[arg-type]
490+
return MixedPrecision(self._precision_flag, device)
491491

492492
raise RuntimeError("No precision set")
493493

0 commit comments

Comments
 (0)