Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/typing.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
torch==2.9.1
mypy==1.19.1
mypy==1.20.0

types-Markdown
types-PyYAML
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def _check_and_init_precision(self) -> Precision:
f" Choose a different precision among: {', '.join(mp_precision_supported)}."
)
if self._precision_input in ("16-true", "bf16-true"):
return HalfPrecision(self._precision_input) # type: ignore
return HalfPrecision(self._precision_input)
if self._precision_input == "32-true":
return Precision()
if self._precision_input == "64-true":
Expand All @@ -493,7 +493,7 @@ def _check_and_init_precision(self) -> Precision:
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)
device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]
return MixedPrecision(precision=self._precision_input, device=device)

raise RuntimeError("No precision set")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def destroy_group(cls, group: CollectibleGroup) -> None:
# can be called by all processes in the default group, group will be `object()` if they are not part of the
# current group
if group in dist.distributed_c10d._pg_map:
dist.destroy_process_group(group) # type: ignore[arg-type]
dist.destroy_process_group(group)

@classmethod
@override
Expand Down
6 changes: 1 addition & 5 deletions src/lightning/fabric/utilities/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

from lightning_utilities import is_overridden

from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_10_0

_log = logging.getLogger(__name__)


Expand All @@ -36,9 +34,7 @@ def _load_external_callbacks(group: str) -> list[Any]:
A list of all callbacks collected from external factories.

"""
factories = (
entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type]
)
factories = entry_points(group=group)

external_callbacks: list[Any] = []
for factory in factories:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def _check_and_init_precision(self) -> Precision:
if isinstance(self.strategy, FSDPStrategy):
return FSDPPrecision(self._precision_flag) # type: ignore[arg-type]
if self._precision_flag in ("16-true", "bf16-true"):
return HalfPrecision(self._precision_flag) # type: ignore
return HalfPrecision(self._precision_flag)
if self._precision_flag == "32-true":
return Precision()
if self._precision_flag == "64-true":
Expand All @@ -487,7 +487,7 @@ def _check_and_init_precision(self) -> Precision:
f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
)
device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
return MixedPrecision(self._precision_flag, device) # type: ignore[arg-type]
return MixedPrecision(self._precision_flag, device)

raise RuntimeError("No precision set")

Expand Down
Loading