Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ ignore = [
"S603", # todo: `subprocess` call: check for execution of untrusted input
"S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
"S607", # todo: Starting a process with a partial executable path
"RET504", # todo:Unnecessary variable assignment before `return` statement
"RET503",
]
"tests/**" = [
"S101", # Use of `assert` detected
Expand All @@ -111,6 +109,7 @@ ignore = [
"S603", # todo: `subprocess` call: check for execution of untrusted input
"S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
"S607", # todo: Starting a process with a partial executable path
"PT004", # todo: Fixture `tmpdir_unittest_fixture` does not return anything, add leading underscore
"PT012", # todo: `pytest.raises()` block should contain a single simple statement
"PT019", # todo: Fixture `_` without value is injected as parameter, use `@pytest.mark.usefixtures` instead
]
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,7 @@ def _setup_dataloader(
dataloader = self._strategy.process_dataloader(dataloader)
device = self.device if move_to_device and not isinstance(self._strategy, XLAStrategy) else None
fabric_dataloader = _FabricDataLoader(dataloader=dataloader, device=device)
fabric_dataloader = cast(DataLoader, fabric_dataloader)
return fabric_dataloader
return cast(DataLoader, fabric_dataloader)

def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] = None, **kwargs: Any) -> None:
r"""Replaces ``loss.backward()`` in your training loop. Handles precision automatically for you.
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/fabric/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ def log_dir(self) -> str:
if isinstance(self.sub_dir, str):
log_dir = os.path.join(log_dir, self.sub_dir)
log_dir = os.path.expandvars(log_dir)
log_dir = os.path.expanduser(log_dir)
return log_dir
return os.path.expanduser(log_dir)

@property
def sub_dir(self) -> Optional[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
)
elif self.replace_layers in (None, True):
_convert_layers(module)
module = module.to(dtype=self.weights_dtype)
return module
return module.to(dtype=self.weights_dtype)

@override
def tensor_init_context(self) -> AbstractContextManager:
Expand Down
13 changes: 5 additions & 8 deletions src/lightning/fabric/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,19 +867,18 @@ def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
)


def _get_sharded_state_dict_context(module: Module) -> Generator[None, None, None]:
def _get_sharded_state_dict_context(module: Module) -> Generator:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType

state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
state_dict_type_context = FSDP.state_dict_type(
return FSDP.state_dict_type(
module=module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=state_dict_config,
optim_state_dict_config=optim_state_dict_config,
)
return state_dict_type_context # type: ignore[return-value]
) # type: ignore[return-value]


def _get_full_state_dict_context(
Expand All @@ -891,14 +890,12 @@ def _get_full_state_dict_context(

state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
state_dict_type_context = FSDP.state_dict_type(
return FSDP.state_dict_type(
module=module,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=state_dict_config,
optim_state_dict_config=optim_state_dict_config,
)

return state_dict_type_context # type: ignore[return-value]
) # type: ignore[return-value]


def _is_sharded_checkpoint(path: Path) -> bool:
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/fabric/strategies/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
decision,
reduce_op=ReduceOp.SUM, # type: ignore[arg-type]
)
decision = bool(decision == self.world_size) if all else bool(decision)
return decision
return bool(decision == self.world_size) if all else bool(decision)

@override
def teardown(self) -> None:
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/fabric/strategies/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
import torch_xla.core.xla_model as xm

tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
tensor = tensor.to(original_device)
return tensor
return tensor.to(original_device)

@override
def all_reduce(
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/fabric/strategies/xla_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
import torch_xla.core.xla_model as xm

tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
tensor = tensor.to(original_device)
return tensor
return tensor.to(original_device)

@override
def all_reduce(
Expand Down
1 change: 1 addition & 0 deletions src/lightning/fabric/utilities/throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) ->
rank_zero_warn(f"FLOPs not found for TPU {device_name!r} with {dtype}")
return None
return int(_TPU_FLOPS[chip])
return None


def _plugin_to_compute_dtype(plugin: "Precision") -> torch.dtype:
Expand Down
4 changes: 1 addition & 3 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,9 +747,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

# If using multiple devices, make sure all processes are unanimous on the decision.
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))

return should_update_best_and_save
return trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))

def _format_checkpoint_name(
self,
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/pytorch/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,5 +322,4 @@ def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info
# Retrieve information for each dataloader method
dataloader_info = extract_loader_info(datamodule_loader_methods)
# Format the information
dataloader_str = format_loader_info(dataloader_info)
return dataloader_str
return format_loader_info(dataloader_info)
6 changes: 2 additions & 4 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,7 @@ def _apply_batch_transfer_handler(
) -> Any:
device = device or self.device
batch = self._call_batch_hook("transfer_batch_to_device", batch, device, dataloader_idx)
batch = self._call_batch_hook("on_after_batch_transfer", batch, dataloader_idx)
return batch
return self._call_batch_hook("on_after_batch_transfer", batch, dataloader_idx)

def print(self, *args: Any, **kwargs: Any) -> None:
r"""Prints only from process 0. Use this in any distributed mode to log only once.
Expand Down Expand Up @@ -666,8 +665,7 @@ def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
f" You can try doing `self.log({name}, {value}.mean())`"
)
value = value.squeeze()
return value
return value.squeeze()

def all_gather(
self, data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, sync_grads: bool = False
Expand Down
1 change: 1 addition & 0 deletions src/lightning/pytorch/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def step(self, closure: Callable[[], float]) -> float: ...
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
if closure is not None:
return closure()
return None

@override
def zero_grad(self, set_to_none: Optional[bool] = True) -> None:
Expand Down
12 changes: 4 additions & 8 deletions src/lightning/pytorch/demos/boring_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,20 +253,16 @@ def setup(self, stage: str) -> None:
]

def train_dataloader(self) -> Iterable[DataLoader]:
combined_train = apply_to_collection(self.train_datasets, Dataset, lambda x: DataLoader(x))
return combined_train
return apply_to_collection(self.train_datasets, Dataset, lambda x: DataLoader(x))

def val_dataloader(self) -> DataLoader:
combined_val = apply_to_collection(self.val_datasets, Dataset, lambda x: DataLoader(x))
return combined_val
return apply_to_collection(self.val_datasets, Dataset, lambda x: DataLoader(x))

def test_dataloader(self) -> DataLoader:
combined_test = apply_to_collection(self.test_datasets, Dataset, lambda x: DataLoader(x))
return combined_test
return apply_to_collection(self.test_datasets, Dataset, lambda x: DataLoader(x))

def predict_dataloader(self) -> DataLoader:
combined_predict = apply_to_collection(self.predict_datasets, Dataset, lambda x: DataLoader(x))
return combined_predict
return apply_to_collection(self.predict_datasets, Dataset, lambda x: DataLoader(x))


class ManualOptimBoringModel(BoringModel):
Expand Down
12 changes: 4 additions & 8 deletions src/lightning/pytorch/demos/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def __init__(
def generate_square_subsequent_mask(self, size: int) -> Tensor:
"""Generate a square mask for the sequence to prevent future tokens from being seen."""
mask = torch.triu(torch.ones(size, size), diagonal=1)
mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)
return mask
return mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)

def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor:
_, t = inputs.shape
Expand All @@ -78,8 +77,7 @@ def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None)
output = self.transformer(src, target, tgt_mask=mask)
output = self.decoder(output)
output = F.log_softmax(output, dim=-1)
output = output.view(-1, self.vocab_size)
return output
return output.view(-1, self.vocab_size)


class PositionalEncoding(nn.Module):
Expand All @@ -106,8 +104,7 @@ def _init_pos_encoding(self, device: torch.device) -> Tensor:
div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
return pe
return pe.unsqueeze(0)


class WikiText2(Dataset):
Expand Down Expand Up @@ -200,8 +197,7 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
inputs, target = batch
output = self(inputs, target)
loss = torch.nn.functional.nll_loss(output, target.view(-1))
return loss
return torch.nn.functional.nll_loss(output, target.view(-1))

def configure_optimizers(self) -> torch.optim.Optimizer:
return torch.optim.SGD(self.model.parameters(), lr=0.1)
Expand Down
1 change: 1 addition & 0 deletions src/lightning/pytorch/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def version(self) -> Optional[str]:
# Don't create an experiment if we don't have one
if self._experiment is not None:
return self._experiment.get_key()
return None

def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/pytorch/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ def log_dir(self) -> str:
if isinstance(self.sub_dir, str):
log_dir = os.path.join(log_dir, self.sub_dir)
log_dir = os.path.expandvars(log_dir)
log_dir = os.path.expanduser(log_dir)
return log_dir
return os.path.expanduser(log_dir)

@property
@override
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/pytorch/loggers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict)
checkpoints = sorted(
(Path(p).stat().st_mtime, p, s, tag) for p, (s, tag) in checkpoints.items() if Path(p).is_file()
)
checkpoints = [c for c in checkpoints if c[1] not in logged_model_time or logged_model_time[c[1]] < c[0]]
return checkpoints
return [c for c in checkpoints if c[1] not in logged_model_time or logged_model_time[c[1]] < c[0]]


def _log_hyperparams(trainer: "pl.Trainer") -> None:
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/pytorch/profilers/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,12 +376,11 @@ def _total_steps(self) -> Union[int, float]:
)
return num_val_batches + num_sanity_val_batches
if self._schedule.is_testing:
num_test_batches = (
return (
sum(trainer.num_test_batches)
if isinstance(trainer.num_test_batches, list)
else trainer.num_test_batches
)
return num_test_batches
if self._schedule.is_predicting:
return sum(trainer.num_predict_batches)
raise NotImplementedError("Unsupported schedule")
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,7 @@ def save_checkpoint(
return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
else:
raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")
return None

@override
def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
Expand Down Expand Up @@ -626,8 +627,7 @@ def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] =
optim.load_state_dict(flattened_osd)

# Load metadata (anything not a module or optimizer)
metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only)
return metadata
return torch.load(path / _METADATA_FILENAME, weights_only=weights_only)

if _is_full_checkpoint(path):
checkpoint = _lazy_load(path)
Expand Down
1 change: 1 addition & 0 deletions src/lightning/pytorch/strategies/model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def save_checkpoint(
if _is_sharded_checkpoint(path):
shutil.rmtree(path)
return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
return None

@override
def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/pytorch/strategies/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
decision,
reduce_op=ReduceOp.SUM, # type: ignore[arg-type]
)
decision = bool(decision == self.world_size) if all else bool(decision)
return decision
return bool(decision == self.world_size) if all else bool(decision)

@contextmanager
def block_backward_sync(self) -> Generator:
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/pytorch/strategies/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
import torch_xla.core.xla_model as xm

tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
tensor = tensor.to(original_device)
return tensor
return tensor.to(original_device)

@override
def teardown(self) -> None:
Expand Down
6 changes: 2 additions & 4 deletions src/lightning/pytorch/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,8 +1334,7 @@ def training_step(self, batch, batch_idx):
else:
dirpath = self.default_root_dir

dirpath = self.strategy.broadcast(dirpath)
return dirpath
return self.strategy.broadcast(dirpath)

@property
def is_global_zero(self) -> bool:
Expand Down Expand Up @@ -1788,5 +1787,4 @@ def configure_optimizers(self):
assert self.max_epochs is not None
max_estimated_steps = math.ceil(total_batches / self.accumulate_grad_batches) * max(self.max_epochs, 1)

max_estimated_steps = min(max_estimated_steps, self.max_steps) if self.max_steps != -1 else max_estimated_steps
return max_estimated_steps
return min(max_estimated_steps, self.max_steps) if self.max_steps != -1 else max_estimated_steps
3 changes: 1 addition & 2 deletions src/lightning/pytorch/utilities/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ def _determine_model_folder(model_name: str, default_root_dir: str) -> str:
# download the latest checkpoint from the model registry
model_name = model_name.replace("/", "_")
model_name = model_name.replace(":", "_")
local_model_dir = os.path.join(default_root_dir, model_name)
return local_model_dir
return os.path.join(default_root_dir, model_name)


def find_model_local_ckpt_path(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,13 @@ def total_flops(self) -> int:
@property
def flop_counts(self) -> dict[str, dict[Any, int]]:
flop_counts = self._flop_counter.get_flop_counts()
ret = {
return {
name: flop_counts.get(
f"{type(self._model).__name__}.{name}",
{},
)
for name in self.layer_names
}
return ret

def summarize(self) -> dict[str, LayerSummary]:
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
Expand Down
Loading