Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions helios/nn/latent_mim.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
self.target_encoder = deepcopy(self.encoder)
for p in self.target_encoder.parameters():
p.requires_grad = False
self.logit_scale = nn.Parameter(torch.tensor([1.0 / 0.07]).log())

def forward(
self, x: MaskedHeliosSample, patch_size: int
Expand Down
273 changes: 272 additions & 1 deletion helios/train/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,277 @@ def _expand_and_reciprocate(t: Tensor) -> Tensor:
LOSS_REGISTRY = ClassRegistry[Loss]()


@LOSS_REGISTRY.register("clip_patch_discrimination")
class ClipPatchDiscriminationLoss(Loss):
"""Loss function for configurable CLIP-like patch discrimination task.

Closer to the original loss from CLIP paper.
"""

name = "ClipPatchDisc"

def __init__(
    self,
    label_smoothing: float = 0,
    prediction_norm: float | None = None,
    target_norm: float | None = None,
    modality_loss: bool = True,
    symmetric: bool = True,
    batch_loss: bool = False,
    bandset_loss: bool = False,
    spatial_loss: bool = False,
    time_loss: bool = False,
    mean_of_modalities: bool = True,
    sum_of_modalities: bool = False,
    decode_only: bool = True,
    weight: float = 1.0,
):
    """Initialize patch discrimination loss.

    Args:
        label_smoothing: label smoothing in [0, 1]; 0 disables it.
        prediction_norm: if set, L2-normalize predictions and scale by this value.
        target_norm: if set, L2-normalize targets and scale by this value.
        modality_loss: calculate loss across each modality.
        symmetric: also calculate the transposed (target-vs-prediction) loss.
        batch_loss: calculate loss across batches.
        bandset_loss: calculate loss across bandsets.
        spatial_loss: calculate loss across spatial patches.
        time_loss: calculate loss across timesteps.
        mean_of_modalities: mean of means instead of mean of all losses.
        sum_of_modalities: sum of means instead of mean of all losses.
        decode_only: only compare to targets masked as decode (prevents cheating maybe?).
        weight: the weight to apply to this loss.
    """
    self.label_smoothing = label_smoothing
    self.weight = weight
    self.prediction_norm = prediction_norm
    self.target_norm = target_norm
    self.modality_loss = modality_loss
    self.symmetric = symmetric
    self.batch_loss = batch_loss
    self.bandset_loss = bandset_loss
    self.mean_of_modalities = mean_of_modalities
    self.sum_of_modalities = sum_of_modalities
    self.decode_only = decode_only
    self.spatial_loss = spatial_loss
    self.time_loss = time_loss
    # Default similarity-logit scale. `compute` may override this with a
    # (possibly learned) scale from the caller; initializing it here keeps
    # the `_calculate_*` helpers usable even before `compute` runs, instead
    # of raising AttributeError on the missing attribute.
    self.logit_scale: float | Tensor = 1.0

def _calculate_modality_loss(
    self, preds: Tensor, targs: Tensor, masks: Tensor
) -> Tensor:
    """Contrastive loss over all token positions within each sample."""
    queries = rearrange(preds, "b ... d -> b (...) d")
    keys = rearrange(targs, "b ... d -> b (...) d")
    # Pairwise similarity between every query and every key position.
    logits = self.logit_scale * torch.einsum("bxd,byd->bxy", queries, keys)
    if self.decode_only:
        # Push non-decode key positions to the representable minimum so
        # they contribute (effectively) nothing to the softmax.
        invalid_keys = (masks != MaskValue.DECODER.value).flatten(start_dim=1)
        logits = logits.masked_fill(
            invalid_keys.unsqueeze(1), torch.finfo(logits.dtype).min
        )
    positions = torch.arange(logits.shape[1], dtype=torch.long, device=logits.device)
    per_token = F.cross_entropy(
        logits.flatten(0, 1),
        positions.repeat(logits.shape[0]),
        reduction="none",
        label_smoothing=self.label_smoothing,
    )
    # Keep only terms whose query position was itself a decode token.
    return per_token[masks.flatten() == MaskValue.DECODER.value]

def _calculate_batch_loss(
    self, preds: Tensor, targs: Tensor, masks: Tensor
) -> Tensor:
    """Contrastive loss across the batch at each token position."""
    queries = rearrange(preds, "b ... d -> (...) b d")
    keys = rearrange(targs, "b ... d -> (...) b d")
    position_masks = rearrange(masks, "b ... -> (...) b")
    # For each token position, similarity between all pairs of batch items.
    logits = self.logit_scale * torch.einsum("bxd,byd->bxy", queries, keys)
    if self.decode_only:
        invalid_keys = position_masks != MaskValue.DECODER.value
        logits = logits.masked_fill(
            invalid_keys.unsqueeze(1), torch.finfo(logits.dtype).min
        )
    batch_idx = torch.arange(logits.shape[2], dtype=torch.long, device=logits.device)
    per_sample = F.cross_entropy(
        logits.flatten(0, 1),
        batch_idx.repeat(logits.shape[0]),
        reduction="none",
        label_smoothing=self.label_smoothing,
    )
    # Keep only terms at decode positions.
    return per_sample[position_masks.flatten() == MaskValue.DECODER.value]

def _calculate_bandset_loss(
    self, preds: Tensor, targs: Tensor, masks: Tensor
) -> Tensor:
    """Contrastive loss per (batch, bandset) group over remaining positions.

    Input tokens have shape (B, P_H, P_W, T, Band_Sets, D).
    """
    queries = rearrange(preds, "b ... bs d -> (b bs) (...) d")
    keys = rearrange(targs, "b ... bs d -> (b bs) (...) d")
    group_masks = rearrange(masks, "b ... bs -> (b bs) (...)")
    logits = self.logit_scale * torch.einsum("bxd,byd->bxy", queries, keys)
    if self.decode_only:
        # Suppress non-decode key positions before the softmax.
        invalid_keys = (group_masks != MaskValue.DECODER.value).flatten(start_dim=1)
        logits = logits.masked_fill(
            invalid_keys.unsqueeze(1), torch.finfo(logits.dtype).min
        )

    positions = torch.arange(logits.shape[1], dtype=torch.long, device=logits.device)
    per_token = F.cross_entropy(
        logits.flatten(0, 1),
        positions.repeat(logits.shape[0]),
        reduction="none",
        label_smoothing=self.label_smoothing,
    )
    return per_token[group_masks.flatten() == MaskValue.DECODER.value]

def _calculate_spatial_loss(
    self, preds: Tensor, targs: Tensor, masks: Tensor
) -> Tensor:
    """Contrastive loss per (batch, timestep) group across space and bandsets.

    Input tokens have shape (B, P_H, P_W, T, Band_Sets, D).
    """
    queries = rearrange(preds, "b ph pw t bs d -> (b t) (ph pw bs) d")
    keys = rearrange(targs, "b ph pw t bs d -> (b t) (ph pw bs) d")
    group_masks = rearrange(masks, "b ph pw t bs -> (b t) (ph pw bs)")
    logits = self.logit_scale * torch.einsum("bxd,byd->bxy", queries, keys)
    if self.decode_only:
        # Suppress non-decode key positions before the softmax.
        invalid_keys = (group_masks != MaskValue.DECODER.value).flatten(start_dim=1)
        logits = logits.masked_fill(
            invalid_keys.unsqueeze(1), torch.finfo(logits.dtype).min
        )
    positions = torch.arange(logits.shape[1], dtype=torch.long, device=logits.device)
    per_token = F.cross_entropy(
        logits.flatten(0, 1),
        positions.repeat(logits.shape[0]),
        reduction="none",
        label_smoothing=self.label_smoothing,
    )
    return per_token[group_masks.flatten() == MaskValue.DECODER.value]

def _calculate_time_loss(
    self, preds: Tensor, targs: Tensor, masks: Tensor
) -> Tensor:
    """Contrastive loss per spatial location across timesteps and bandsets.

    Input tokens have shape (B, P_H, P_W, T, Band_Sets, D).
    """
    queries = rearrange(preds, "b ph pw t bs d -> (b ph pw) (t bs) d")
    keys = rearrange(targs, "b ph pw t bs d -> (b ph pw) (t bs) d")
    group_masks = rearrange(masks, "b ph pw t bs -> (b ph pw) (t bs)")
    logits = self.logit_scale * torch.einsum("bxd,byd->bxy", queries, keys)
    if self.decode_only:
        # Suppress non-decode key positions before the softmax.
        invalid_keys = (group_masks != MaskValue.DECODER.value).flatten(start_dim=1)
        logits = logits.masked_fill(
            invalid_keys.unsqueeze(1), torch.finfo(logits.dtype).min
        )
    positions = torch.arange(logits.shape[1], dtype=torch.long, device=logits.device)
    per_token = F.cross_entropy(
        logits.flatten(0, 1),
        positions.repeat(logits.shape[0]),
        reduction="none",
        label_smoothing=self.label_smoothing,
    )
    return per_token[group_masks.flatten() == MaskValue.DECODER.value]

def compute(
    self,
    predictions: TokensAndMasks,
    targets: TokensAndMasks,
    logit_scale: Tensor | None = None,
    **kwargs: Any,
) -> Tensor:
    """Compute patch discrimination loss between predictions and targets.

    Args:
        predictions: Model predictions.
        targets: Ground truth targets.
        logit_scale: scalar multiplier for the similarity logits. When None,
            a scale of 1.0 is used so the helpers never multiply by None.
        **kwargs: Additional keyword arguments (unused).

    Returns:
        The computed loss value, scaled by ``self.weight``.
    """
    # The _calculate_* helpers read the scale from the instance; guard
    # against None so the loss stays usable without an external scale.
    self.logit_scale = logit_scale if logit_scale is not None else 1.0
    losses = []
    for modality_name in predictions.modalities:
        # e.g. sentinel2: sentinel 2 data of shape (B, P_H, P_W, T, Band_Sets, D)
        preds = getattr(predictions, modality_name)
        targs = getattr(targets, modality_name)
        masks = getattr(
            predictions, predictions.get_masked_modality_name(modality_name)
        )
        if self.target_norm is not None:
            targs = self.target_norm * F.normalize(targs, p=2, dim=-1)
        if self.prediction_norm is not None:
            preds = self.prediction_norm * F.normalize(preds, p=2, dim=-1)

        if self.modality_loss:
            losses.append(self._calculate_modality_loss(preds, targs, masks))
            if self.symmetric:
                losses.append(self._calculate_modality_loss(targs, preds, masks))

        if self.batch_loss:
            losses.append(self._calculate_batch_loss(preds, targs, masks))
            if self.symmetric:
                losses.append(self._calculate_batch_loss(targs, preds, masks))

        if self.bandset_loss:
            losses.append(self._calculate_bandset_loss(preds, targs, masks))
            if self.symmetric:
                losses.append(self._calculate_bandset_loss(targs, preds, masks))

        if self.time_loss:
            # Contrast across time only makes sense for multitemporal
            # modalities; fall back to the plain modality loss otherwise.
            if Modality.get(modality_name).is_multitemporal:
                losses.append(self._calculate_time_loss(preds, targs, masks))
                if self.symmetric:
                    losses.append(self._calculate_time_loss(targs, preds, masks))
            else:
                losses.append(self._calculate_modality_loss(preds, targs, masks))
                if self.symmetric:
                    losses.append(
                        self._calculate_modality_loss(targs, preds, masks)
                    )

        if self.spatial_loss:
            # Same fallback for modalities without a spatial/temporal grid.
            if Modality.get(modality_name).is_multitemporal:
                losses.append(self._calculate_spatial_loss(preds, targs, masks))
                if self.symmetric:
                    losses.append(self._calculate_spatial_loss(targs, preds, masks))
            else:
                losses.append(self._calculate_modality_loss(preds, targs, masks))
                if self.symmetric:
                    losses.append(
                        self._calculate_modality_loss(targs, preds, masks)
                    )

    if self.mean_of_modalities:
        total_loss = torch.stack(
            [
                # Float zero (not int) so stacking with float losses can't
                # fail on a dtype mismatch when a term is empty.
                loss.mean()
                if loss.numel() > 0
                else torch.tensor(0.0, device=loss.device)
                for loss in losses
            ]
        )
        total_loss = total_loss.mean() if total_loss.numel() > 0 else 0
    elif self.sum_of_modalities:
        total_loss = torch.stack(
            [loss.mean() for loss in losses if loss.numel() > 0]
        ).sum()
    else:
        total_loss = torch.cat(
            [loss.flatten() for loss in losses if loss.numel() > 0]
        )
        total_loss = total_loss.mean() if total_loss.numel() > 0 else 0

    return self.weight * total_loss


@LOSS_REGISTRY.register("all_discrimination")
class AllDiscriminationLoss(Loss):
"""Loss function for all discrimination task.
Expand Down Expand Up @@ -752,7 +1023,7 @@ def compute(
targets = F.normalize(targets, p=2, dim=-1)
logits = predictions @ targets.transpose(-2, -1)

logger.warning(logits.shape)
# logger.warning(logits.shape)

# Positive keys are the entries on the diagonal
labels = torch.arange(len(predictions), device=predictions.device)
Expand Down
11 changes: 8 additions & 3 deletions helios/train/train_module/contrastive_latentmim.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class ContrastiveLatentMIMTrainModuleConfig(HeliosTrainModuleConfig):
ema_decay: tuple[float, float] = (0.996, 1.0)
max_grad_norm: float = 1.0
contrastive_config: LossConfig | None = None
max_logit_scale: float = 4.6

def build(
self,
Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(
regularizer_config: LossConfig | None = None,
contrastive_config: LossConfig | None = None,
find_unused_parameters: bool = True,
max_logit_scale: float = 4.6,
):
"""Initialize the training module.

Expand All @@ -127,6 +129,7 @@ def __init__(
regularizer_config: An optional regularizer configuration for the model.
contrastive_config: An optional contrastive configration for the model.
find_unused_parameters: Whether to find unused parameters in the model, only used for DDP.
max_logit_scale: Maximum value for logit scale parameter in loss.
"""
super().__init__(
model=model,
Expand Down Expand Up @@ -161,10 +164,11 @@ def __init__(
self.mae_loss = mae_loss_config.build() if mae_loss_config is not None else None
if self.mae_loss is not None:
self.total_loss_name = f"{self.total_loss_name}+{self.mae_loss.name}"
self.max_logit_scale = max_logit_scale

def loss_fn(self, pred: Any, targets: Any) -> torch.Tensor:
def loss_fn(self, pred: Any, targets: Any, **kwargs: Any) -> torch.Tensor:
"""Compute the loss between the predicted and target tensors."""
return self.base_loss.compute(pred, targets)
return self.base_loss.compute(pred, targets, **kwargs)

def train_batch(
self, batch: tuple[int, HeliosSample], dry_run: bool = False
Expand Down Expand Up @@ -292,7 +296,8 @@ def model_forward(
token_exit_cfg=token_exit_cfg,
)
target_output, _, _ = unpack_encoder_output(output_dict)
loss = self.loss_fn(decoded, target_output)
logit_scale = self.model.logit_scale.clamp(max=self.max_logit_scale).exp()
loss = self.loss_fn(decoded, target_output, logit_scale=logit_scale)
if self.mae_loss is not None and reconstructed is not None:
loss += self.mae_loss.compute(reconstructed, batch)
return loss, latent, decoded, target_output, latent_projected_and_pooled
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ai2-olmo-core @ git+https://github.com/allenai/OLMo-core.git@abc12e50ba756c21e575452cfc6f150dafa9509e # Pin here until >2.1.0 is released.
ai2-olmo-core @ git+https://github.com/allenai/OLMo-core.git@386b0a8 # Pin here until >2.1.0 is released.
albumentations
cartopy
class-registry
Expand Down
Loading
Loading