Skip to content

Commit ede86ca

Browse files
deploy changes
1 parent b3e8a9c commit ede86ca

5 files changed

Lines changed: 57 additions & 15 deletions

File tree

asparagus/functional/lr_scheduling.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def sawtooth_warmup_cosine_decay_schedule(
4343
Phase 2: Both encoder and decoder warmup
4444
Phase 3: Cosine annealing for both
4545
"""
46+
assert max_epochs > 0 and steps_per_epoch > 0, "max_epochs and steps_per_epoch must be greater than 0"
4647
print(f"Using separate warmup: decoder for {decoder_warmup_epochs} epochs, then both for {warmup_epochs} epochs")
4748

4849
decoder_warmup_steps = int(decoder_warmup_epochs * steps_per_epoch)
@@ -72,16 +73,24 @@ def decoder_phase1_lambda(step):
7273
)
7374

7475

75-
def simple_warmup_cosine_decay_schedule(optimizer, warmup_epochs, steps_per_epoch, cosine_period_ratio, max_epochs):
76+
def simple_warmup_cosine_decay_schedule(
77+
optimizer, warmup_epochs, steps_per_epoch, cosine_period_ratio, max_epochs=-1, max_steps=-1
78+
):
7679
"""
7780
Phase 1: Warmup for both encoder and decoder
7881
Phase 2: Cosine annealing for both
7982
"""
80-
print(f"Using warmup for {warmup_epochs} epochs")
83+
assert warmup_epochs >= 0, "Warmup epochs must be greater than or equal to 0."
84+
assert cosine_period_ratio > 0, "Cosine period ratio must be greater than 0."
85+
assert steps_per_epoch > 0, "Steps per epoch must be greater than 0."
86+
assert max_epochs > 0 or max_steps > 0, "Either max_epochs or max_steps must be greater than 0."
8187

82-
total_warmup_steps = int(warmup_epochs * steps_per_epoch)
8388
# cosine_half_period is from max to min
84-
cosine_steps = int(cosine_period_ratio * (max_epochs * steps_per_epoch - total_warmup_steps))
89+
if max_epochs > 0:
90+
max_steps = max_epochs * steps_per_epoch
91+
92+
total_warmup_steps = int(warmup_epochs * steps_per_epoch)
93+
cosine_steps = int(cosine_period_ratio * (max_steps - total_warmup_steps))
8594

8695
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=cosine_steps)
8796
warmup_scheduler = LinearLR(
@@ -90,17 +99,25 @@ def simple_warmup_cosine_decay_schedule(optimizer, warmup_epochs, steps_per_epoc
9099
total_iters=total_warmup_steps,
91100
)
92101

102+
print(f"Using warmup for {warmup_epochs} epochs ({total_warmup_steps} steps)")
103+
print(f"Cosine decay for {cosine_steps} steps after warmup")
104+
assert total_warmup_steps > 0, "Warmup steps must be greater than 0 for warmup schedule."
105+
assert cosine_steps > 0, "Cosine steps must be greater than 0 for warmup cosine decay schedule."
106+
93107
return SequentialLR(
94108
optimizer,
95109
schedulers=[warmup_scheduler, cosine_scheduler],
96110
milestones=[total_warmup_steps],
97111
)
98112

99113

100-
def cosine_decay_schedule(optimizer, steps_per_epoch, cosine_period_ratio, max_epochs):
114+
def cosine_decay_schedule(optimizer, steps_per_epoch, cosine_period_ratio, max_epochs=-1, max_steps=-1):
101115
"""
102116
Phase 1: Cosine annealing for both encoder and decoder
103117
"""
104118
# cosine_half_period is from max to min
105-
cosine_steps = int(cosine_period_ratio * (max_epochs * steps_per_epoch))
119+
if max_epochs > 0:
120+
max_steps = max_epochs * steps_per_epoch
121+
cosine_steps = int(cosine_period_ratio * max_steps)
122+
assert cosine_steps > 0, "Cosine steps must be greater than 0 for cosine decay schedule."
106123
return CosineAnnealingLR(optimizer, T_max=cosine_steps)

asparagus/modules/lightning_modules/base_module.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,34 +110,42 @@ def configure_optimizers(self):
110110

111111
print(f"Using optimizer {optimizer.__class__.__name__} with learning rate {self.learning_rate}")
112112

113-
steps_per_epoch = self.trainer.estimated_stepping_batches // self.trainer.max_epochs
113+
# Calculate steps per epoch based on trainer configuration
114+
# if max_epochs is *not* set (i.e., set to -1), we are probably using max_steps
115+
# if max_epochs is set, we can calculate steps per epoch based on estimated_stepping_batches
116+
if self.trainer.max_epochs <= 0:
117+
optimizer_steps_per_epoch = self.trainer.limit_train_batches // self.trainer.accumulate_grad_batches
118+
else:
119+
optimizer_steps_per_epoch = self.trainer.estimated_stepping_batches // self.trainer.max_epochs
114120

115121
# Scheduler option 1: Three-phase schedule with separate decoder/joint warmup
116122
if self.decoder_warmup_epochs > 0:
117123
scheduler = sawtooth_warmup_cosine_decay_schedule(
118124
optimizer,
119125
self.decoder_warmup_epochs,
120126
self.warmup_epochs,
121-
steps_per_epoch,
127+
optimizer_steps_per_epoch,
122128
self.cosine_period_ratio,
123-
self.trainer.max_epochs,
129+
self.trainer.max_epochs, # may be -1, if using max_steps
124130
)
125131
# Scheduler option 2: Two-phase schedule with joint warmup
126132
elif self.warmup_epochs > 0:
127133
scheduler = simple_warmup_cosine_decay_schedule(
128134
optimizer,
129135
self.warmup_epochs,
130-
steps_per_epoch,
136+
optimizer_steps_per_epoch,
131137
self.cosine_period_ratio,
132-
self.trainer.max_epochs,
138+
self.trainer.max_epochs, # may be -1, if using max_steps
139+
self.trainer.max_steps, # may be -1, if using max_epochs
133140
)
134141
# Scheduler option 3: Just cosine annealing
135142
else:
136143
scheduler = cosine_decay_schedule(
137144
optimizer,
138-
steps_per_epoch,
145+
optimizer_steps_per_epoch,
139146
self.cosine_period_ratio,
140-
self.trainer.max_epochs,
147+
self.trainer.max_epochs, # may be -1, if using max_steps
148+
self.trainer.max_steps, # may be -1, if using max_epochs
141149
)
142150

143151
scheduler_config = {
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from asparagus.modules.transforms.clamp import Torch_ClampTarget as Torch_ClampTarget
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import torch
2+
3+
4+
class Torch_ClampTarget:
5+
def __init__(self, clamp: bool = False, min_value: float = 0.0, max_value: float = 1.0):
6+
self.clamp = clamp
7+
self.min_value = min_value
8+
self.max_value = max_value
9+
10+
def __call__(self, data_dict: dict) -> dict:
11+
if self.clamp and "label" in data_dict:
12+
data_dict["label"] = torch.clamp(data_dict["label"], min=self.min_value, max=self.max_value)
13+
return data_dict

asparagus/modules/transforms/presets/pretrain.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from asparagus.modules.transforms import Torch_ClampTarget
12
from gardening_tools.functional.transforms.spatial import get_max_rotated_size
23
from gardening_tools.modules.transforms.bias_field import Torch_BiasField
34
from gardening_tools.modules.transforms.blur import Torch_Blur
@@ -18,8 +19,9 @@ def CPU_val_transforms(patch_size):
1819
return transforms.Compose(
1920
[
2021
Torch_Normalize(normalize=True),
21-
Torch_CropPad(patch_size=patch_size, p_oversample_foreground=0.4),
22+
Torch_CropPad(patch_size=patch_size, p_oversample_foreground=0.0),
2223
Torch_CopyImageToLabel(copy=True),
24+
Torch_ClampTarget(clamp=True, min_value=-2.0, max_value=4.0),
2325
]
2426
)
2527

@@ -36,7 +38,7 @@ def CPU_train_transforms(patch_size):
3638
return transforms.Compose(
3739
[
3840
Torch_Normalize(normalize=True),
39-
Torch_CropPad(patch_size=pre_aug_patch_size, p_oversample_foreground=0.4),
41+
Torch_CropPad(patch_size=pre_aug_patch_size, p_oversample_foreground=0.0),
4042
Torch_Spatial(
4143
patch_size=patch_size,
4244
p_deform_all_channel=0.0,
@@ -47,6 +49,7 @@ def CPU_train_transforms(patch_size):
4749
skip_label=False,
4850
),
4951
Torch_CopyImageToLabel(copy=True),
52+
Torch_ClampTarget(clamp=True, min_value=-2.0, max_value=4.0),
5053
]
5154
)
5255

0 commit comments

Comments
 (0)