16 changes: 14 additions & 2 deletions rslearn/train/data_module.py
@@ -79,7 +79,9 @@ def __init__(
task: the task to train on
path: the dataset path
path_options: additional options for path to pass to fsspec.
batch_size: the batch size
batch_size: the total batch size across all GPUs. In multi-GPU
training, this is divided by world_size to get the per-GPU
batch size.

Collaborator review comment:
Currently we often set batch_size based on available GPU memory. I don't think the existing option should be changed in behavior; if desired you could deprecate the existing one and add per_gpu_batch_size and global_batch_size options to replace it, and then it should raise an error if neither is set or if both are set.
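For illustration, a minimal sketch of the option handling this comment suggests, assuming hypothetical per_gpu_batch_size and global_batch_size arguments (the names come from the comment; the validation logic below is an assumption, not code from this PR):

def resolve_per_gpu_batch_size(
    per_gpu_batch_size: int | None,
    global_batch_size: int | None,
    world_size: int,
) -> int:
    """Return the per-GPU batch size, requiring exactly one option to be set."""
    if (per_gpu_batch_size is None) == (global_batch_size is None):
        # Raise if neither or both options are set, as the comment suggests.
        raise ValueError(
            "exactly one of per_gpu_batch_size and global_batch_size must be set"
        )
    if per_gpu_batch_size is not None:
        return per_gpu_batch_size
    if global_batch_size % world_size != 0:
        raise ValueError(
            f"global_batch_size ({global_batch_size}) must be divisible by "
            f"world_size ({world_size})"
        )
    return global_batch_size // world_size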
num_workers: number of data loader worker processes, or 0 to use main
process only
init_workers: number of workers used to initialize the dataset, e.g. for
@@ -215,9 +217,19 @@ def _get_dataloader(
):
num_workers = min(num_workers, len(dataset.get_dataset_examples()))

# Compute per-GPU batch size from total batch size.
per_gpu_batch_size = self.batch_size
if self.trainer is not None and self.trainer.world_size > 1:
if self.batch_size % self.trainer.world_size != 0:
raise ValueError(
f"batch_size ({self.batch_size}) must be divisible by "
f"world_size ({self.trainer.world_size})"
)
per_gpu_batch_size = self.batch_size // self.trainer.world_size

kwargs: dict[str, Any] = dict(
dataset=dataset,
batch_size=self.batch_size,
batch_size=per_gpu_batch_size,
num_workers=num_workers,
collate_fn=collate_fn,
persistent_workers=persistent_workers,
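To illustrate the new behavior, a standalone sketch that mirrors the divisibility check above (not the module's code; shown only for the arithmetic):

def split_batch_size(batch_size: int, world_size: int) -> int:
    # Mirrors the check added in _get_dataloader: the total batch size must
    # divide evenly across processes.
    if world_size > 1 and batch_size % world_size != 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be divisible by world_size ({world_size})"
        )
    return batch_size // world_size if world_size > 1 else batch_size

assert split_batch_size(32, 1) == 32  # single process: unchanged
assert split_batch_size(32, 4) == 8   # 4 GPUs: 8 examples per GPU, 32 total per step
# split_batch_size(32, 5) would raise ValueError, since 32 is not divisible by 5.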
7 changes: 5 additions & 2 deletions rslearn/train/lightning_module.py
@@ -209,11 +209,14 @@ def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
return d

def on_train_epoch_start(self) -> None:
"""If we are in a multi-dataset distributed strategy, set the epoch."""
"""Set the epoch on the distributed sampler so shuffling varies each epoch."""
try:
self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)
except AttributeError:
# Fail silently for single-dataset case, which is okay
pass
try:
self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)
except AttributeError:
pass

def _log_non_scalar_metric(self, name: str, value: NonScalarMetricOutput) -> None:
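For context on why set_epoch is now attempted on both the batch sampler and the sampler: PyTorch's DistributedSampler reuses the same shuffled order every epoch unless set_epoch is called before iterating. A minimal sketch of that standard pattern (not code from this PR):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))
# rank/num_replicas are passed explicitly so this runs without init_process_group.
sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(3):
    # Without this call, every epoch sees the same shuffled order.
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # training step would go here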
24 changes: 24 additions & 0 deletions rslearn/train/tasks/segmentation.py
@@ -313,6 +313,7 @@ def __init__(
weights: list[float] | None = None,
dice_loss: bool = False,
temperature: float = 1.0,
smooth_sigma: float = 0.0,
):
"""Initialize a new SegmentationTask.

@@ -321,6 +322,9 @@ def __init__(
dice_loss: whether to add dice loss to cross entropy
temperature: temperature scaling for softmax, does not affect the loss,
only the predictor outputs
smooth_sigma: if > 0, apply a fixed Gaussian blur to logits before
computing loss and outputs. The filter is non-learned but
differentiable, so gradients flow through it to the model.
"""
super().__init__()
if weights is not None:
@@ -329,6 +333,22 @@ def __init__(
self.weights = None
self.dice_loss = dice_loss
self.temperature = temperature
self.smooth_sigma = smooth_sigma

def _gaussian_smooth(self, logits: torch.Tensor) -> torch.Tensor:
"""Apply depthwise Gaussian blur to logits. Differentiable, no learned params."""
sigma = self.smooth_sigma
radius = int(3 * sigma + 0.5)
size = 2 * radius + 1
x = torch.arange(size, device=logits.device, dtype=logits.dtype) - radius
g1d = torch.exp(-(x**2) / (2 * sigma**2))
g2d = g1d[:, None] * g1d[None, :]
g2d = g2d / g2d.sum()
channels = logits.shape[1]
kernel = g2d.unsqueeze(0).unsqueeze(0).expand(channels, 1, size, size)
# Use replicate padding to avoid border artifacts from zero padding
padded = torch.nn.functional.pad(logits, [radius] * 4, mode="replicate")
return torch.nn.functional.conv2d(padded, kernel, groups=channels)

def forward(
self,
@@ -357,6 +377,10 @@ def forward(
)

logits = intermediates.feature_maps[0]

if self.smooth_sigma > 0:
logits = self._gaussian_smooth(logits)

outputs = torch.nn.functional.softmax(logits / self.temperature, dim=1)

losses = {}
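A self-contained sketch of the same depthwise Gaussian smoothing idea, handy for checking shapes and differentiability outside the task class (it mirrors _gaussian_smooth but is not the module's code; the [batch, classes, H, W] logits shape is assumed):

import torch
import torch.nn.functional as F

def gaussian_smooth(logits: torch.Tensor, sigma: float) -> torch.Tensor:
    """Depthwise Gaussian blur over [B, C, H, W] logits; differentiable, no learned params."""
    radius = int(3 * sigma + 0.5)
    size = 2 * radius + 1
    x = torch.arange(size, device=logits.device, dtype=logits.dtype) - radius
    g1d = torch.exp(-(x**2) / (2 * sigma**2))
    g2d = g1d[:, None] * g1d[None, :]
    g2d = g2d / g2d.sum()  # normalize so the kernel sums to 1
    channels = logits.shape[1]
    kernel = g2d.unsqueeze(0).unsqueeze(0).expand(channels, 1, size, size)
    # Replicate padding, as in the PR, avoids pulling border pixels toward zero.
    padded = F.pad(logits, [radius] * 4, mode="replicate")
    return F.conv2d(padded, kernel, groups=channels)

logits = torch.randn(2, 5, 64, 64, requires_grad=True)
smoothed = gaussian_smooth(logits, sigma=1.5)
assert smoothed.shape == logits.shape  # padding keeps H and W unchanged
smoothed.sum().backward()              # gradients flow back through the fixed filter
assert logits.grad is not None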