Skip to content

Commit e538fd8

Browse files
committed
remove silly "hidden" methods and add configure model to extract feature dim from dummy output
1 parent 7b27084 commit e538fd8

2 files changed

Lines changed: 30 additions & 35 deletions

File tree

asparagus/modules/lightning_modules/linear_probe_module.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def __init__(
5757
self.test_output_path = test_output_path
5858
self.optimizer_momentum = optimizer_momentum
5959
self.optimizer_weight_decay = optimizer_weight_decay
60+
self.learning_rates = learning_rates
61+
self.dimensions = dimensions
6062

6163
for param in self.model.parameters():
6264
param.requires_grad = False
@@ -67,20 +69,6 @@ def __init__(
6769
else:
6870
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
6971

70-
try:
71-
# Asparagus by default expects decoder head here
72-
feature_dim = self.model.decoder.fc.in_features
73-
except AttributeError as e:
74-
# Our MAE model has different layer name
75-
logging.warning(f"`self.model.decoder.fc.in_features` raised {e}, falling back to `self.model.head.in_features`.")
76-
feature_dim = self.model.head.in_features
77-
78-
self.heads = nn.ModuleDict()
79-
for lr in learning_rates:
80-
head_name = self._lr_to_linear_head_name(lr)
81-
head = self._make_head(feature_dim, num_classes)
82-
self.heads[head_name] = head
83-
8472
self.loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor(loss_weight) if loss_weight else None)
8573

8674
self.train_metrics = self.configure_metrics("train")
@@ -89,25 +77,37 @@ def __init__(
8977
# Test metrics (only for best head)
9078
self.test_metrics = self.configure_test_metrics()
9179

80+
self.heads = None
9281
self.best_head_lr = None
9382
self.ignore_index_in_metrics = -1
9483

95-
def _make_head(self, feature_dim: int, num_classes: int) -> nn.Module:
84+
def configure_model(self):
85+
if self.heads is None:
86+
tmp_arr = torch.zeros((1, 1, 32, 32, 32)) if self.dimensions == "3D" else torch.zeros((1, 1, 32, 32))
87+
feature_dim = self.get_features(tmp_arr).view(-1).size(0)
88+
self.heads = nn.ModuleDict()
89+
90+
for lr in self.learning_rates:
91+
head_name = self.lr_to_linear_head_name(lr)
92+
head = self.make_head(feature_dim, self.num_classes)
93+
self.heads[head_name] = head
94+
95+
def make_head(self, feature_dim: int, num_classes: int) -> nn.Module:
9696
head = nn.Linear(feature_dim, num_classes)
9797
nn.init.normal_(head.weight, mean=0.0, std=0.01)
9898
nn.init.zeros_(head.bias)
9999
return head
100100

101101
@staticmethod
102-
def _lr_to_linear_head_name(lr: float) -> str:
102+
def lr_to_linear_head_name(lr: float) -> str:
103103
return f"lr_{lr:.0e}".replace(".", "_").replace("+", "").replace("-", "m")
104104

105105
def train(self, mode=True):
106106
super().train(mode)
107107
self.model.eval()
108108
return self
109109

110-
def _get_features(self, x: torch.Tensor) -> torch.Tensor:
110+
def get_features(self, x: torch.Tensor) -> torch.Tensor:
111111
with torch.no_grad():
112112
skips = self.model._encode(x)
113113

@@ -117,13 +117,12 @@ def _get_features(self, x: torch.Tensor) -> torch.Tensor:
117117
return torch.flatten(features, 1)
118118

119119
def on_before_batch_transfer(self, batch, dataloader_idx):
120-
batch["CLSREG_label"] = batch["CLSREG_label"].squeeze(-1).long()
120+
batch["CLSREG_label"] = batch["CLSREG_label"].view(-1).long()
121121
return batch
122122

123123
def training_step(self, batch, batch_idx):
124124
x, y = batch["image"], batch["CLSREG_label"]
125-
features = self._get_features(x)
126-
125+
features = self.get_features(x)
127126
total_loss = 0.0
128127
for head_name, head in self.heads.items():
129128
logits = head(features)
@@ -143,16 +142,15 @@ def training_step(self, batch, batch_idx):
143142
@torch.no_grad()
144143
def on_train_epoch_end(self):
145144
for lr in self.learning_rates:
146-
head_name = self._lr_to_linear_head_name(lr)
145+
head_name = self.lr_to_linear_head_name(lr)
147146
metrics = self.train_metrics[head_name].compute()
148147
formatted = format_multilabel_metrics(metrics, ignore_index=self.ignore_index_in_metrics)
149148
self.log_dict(formatted, sync_dist=True)
150149
self.train_metrics[head_name].reset()
151150

152151
def validation_step(self, batch, batch_idx):
153152
x, y = batch["image"], batch["CLSREG_label"]
154-
features = self._get_features(x)
155-
153+
features = self.get_features(x)
156154
for head_name, head in self.heads.items():
157155
logits = head(features)
158156
loss = self.loss_fn(logits, y)
@@ -170,7 +168,7 @@ def validation_step(self, batch, batch_idx):
170168
def on_validation_epoch_end(self):
171169
current_aurocs = {}
172170
for lr in self.learning_rates:
173-
head_name = self._lr_to_linear_head_name(lr)
171+
head_name = self.lr_to_linear_head_name(lr)
174172
metrics = self.val_metrics[head_name].compute()
175173
formatted = format_multilabel_metrics(metrics, ignore_index=self.ignore_index_in_metrics)
176174
self.log_dict(formatted, sync_dist=True)
@@ -206,7 +204,7 @@ def configure_test_metrics(self):
206204
def configure_metrics(self, prefix: str):
207205
metrics = nn.ModuleDict()
208206
for lr in self.learning_rates:
209-
head_name = self._lr_to_linear_head_name(lr)
207+
head_name = self.lr_to_linear_head_name(lr)
210208
metrics[head_name] = MetricCollection(
211209
{
212210
f"{prefix}/{head_name}/auroc_macro": MulticlassAUROC(num_classes=self.num_classes, average="macro"),
@@ -219,15 +217,15 @@ def configure_metrics(self, prefix: str):
219217
return metrics
220218

221219
def on_test_epoch_start(self):
222-
logging.info(f"Testing with head: {self._lr_to_linear_head_name(self.best_head_lr)} (lr={self.best_head_lr})")
220+
logging.info(f"Testing with head: {self.lr_to_linear_head_name(self.best_head_lr)} (lr={self.best_head_lr})")
223221
self.results = {}
224222
self.logits = []
225223
self.labels = []
226224

227225
def test_step(self, batch, batch_idx):
228226
x = batch["image"]
229-
features = self._get_features(x)
230-
logits = self.heads[self._lr_to_linear_head_name(self.best_head_lr)](features)
227+
features = self.get_features(x)
228+
logits = self.heads[self.lr_to_linear_head_name(self.best_head_lr)](features)
231229

232230
label = batch["CLSREG_label"]
233231
self.results[batch["file_path"]] = {
@@ -239,15 +237,13 @@ def test_step(self, batch, batch_idx):
239237

240238
def on_test_epoch_end(self):
241239
logits_tensor = torch.stack(self.logits).float()
242-
labels_tensor = torch.stack(self.labels)
243-
240+
labels_tensor = torch.stack(self.labels).view(-1)
244241
avg_results = self.test_metrics(logits_tensor, labels_tensor)
245242
avg_results = {key: value.cpu().numpy().tolist() for key, value in avg_results.items()}
246-
247243
self.results["metrics"] = avg_results
248-
self.results["best_head"] = self._lr_to_linear_head_name(self.best_head_lr)
244+
self.results["best_head"] = self.lr_to_linear_head_name(self.best_head_lr)
249245
self.results["best_head_lr"] = self.best_head_lr
250246
os.makedirs(os.path.split(self.test_output_path)[0], exist_ok=True)
251247
save_json(self.results, self.test_output_path)
252-
logging.info(f"Test using best head: {self._lr_to_linear_head_name(self.best_head_lr)} (lr={self.best_head_lr})")
248+
logging.info(f"Test using best head: {self.lr_to_linear_head_name(self.best_head_lr)} (lr={self.best_head_lr})")
253249
logging.info(f"Aggregated test results for {len(self.results)} files: {avg_results}")

tests/test_linear_probe.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ def test_linear_probe_validate_fit_test(cls_probe_files, tmp_path, make_trainer)
2828
kernel_size=3,
2929
n_blocks_per_stage=(1, 1),
3030
)
31-
3231
data_module = ClsRegDataModule(
33-
batch_size=2,
32+
batch_size=1,
3433
num_workers=2, # val_dataloader uses num_workers//2; needs >=2
3534
train_split=cls_probe_files["train"],
3635
val_split=cls_probe_files["val"],

0 commit comments

Comments
 (0)