Skip to content

Commit 43ab0e0

Browse files
committed
Merge branch 'main' into linear_probe_exposure
2 parents a6a37a0 + 11f22ea commit 43ab0e0

40 files changed

Lines changed: 1027 additions & 85 deletions

.github/workflows/ci.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,3 @@ jobs:
3535
with:
3636
args: format --check .
3737

38-
- name: Run tests
39-
if: ${{ always() }}
40-
run: pytest
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
name: Pipeline tests
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main]
8+
9+
env:
10+
ASPARAGUS_MODELS: "${{ github.workspace }}/models"
11+
#ASPARAGUS: "${{ github.workspace }}/tests/models"
12+
HYDRA_FULL_ERROR: 1
13+
14+
jobs:
15+
pipeline-tests:
16+
runs-on: ubuntu-latest
17+
timeout-minutes: 20
18+
19+
steps:
20+
- uses: actions/checkout@v4
21+
22+
- uses: actions/setup-python@v5
23+
with:
24+
python-version: "3.11"
25+
cache: pip
26+
27+
- name: Install CPU-only torch then package
28+
run: |
29+
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
30+
pip install -e ".[test]"
31+
32+
- name: Run pretrain test
33+
run: |
34+
pytest tests/test_pretrain.py -v --timeout=300
35+
36+
- name: Run finetune_seg test
37+
run: |
38+
pytest tests/test_finetune_seg.py -v --timeout=300
39+
40+
- name: Run train_reg test
41+
run: |
42+
pytest tests/test_train_reg.py -v --timeout=300
43+
44+
- name: Run test_cls test
45+
run: |
46+
pytest tests/test_test_cls.py -v --timeout=300
47+
48+
- name: Run linear probe test
49+
run: |
50+
pytest tests/test_linear_probe.py -v --timeout=300
51+

.github/workflows/stale.yaml

Lines changed: 0 additions & 26 deletions
This file was deleted.

asparagus/functional/reverse_preprocessing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ def reverse_preprocessing(array, image_properties):
88
pad_bbox = image_properties["pad_box"]
99
crop_bbox = image_properties["crop_box"]
1010

11-
shape = array.shape[2:]
12-
if len(shape) == 2:
11+
ndim = len(array.shape[2:])
12+
if ndim == 2:
1313
mode = "bilinear"
14-
elif len(shape) == 3:
14+
elif ndim == 3:
1515
mode = "trilinear"
1616

1717
if len(pad_bbox) > 0:
1818
array = unpad_array(array, pad_bbox)
19-
verify_shapes_are_equal(reference_shape=shape, target_shape=image_properties["shape_before_pad"])
19+
verify_shapes_are_equal(reference_shape=array.shape[2:], target_shape=image_properties["shape_before_pad"])
2020

2121
array = F.interpolate(array, size=image_properties["size_before_resample"], mode=mode)
2222

asparagus/modules/lightning_modules/linear_probe_module.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torchmetrics.classification import (
1212
MulticlassAUROC,
1313
MulticlassAveragePrecision,
14+
MulticlassF1Score,
1415
)
1516
from torchvision import transforms
1617
from typing import List, Optional
@@ -198,6 +199,7 @@ def configure_test_metrics(self):
198199
{
199200
"AUROC_macro": MulticlassAUROC(num_classes=self.num_classes, average="macro"),
200201
"AUPRC_macro": MulticlassAveragePrecision(num_classes=self.num_classes, average="macro"),
202+
"F1_macro": MulticlassF1Score(num_classes=self.num_classes, average="macro"),
201203
}
202204
)
203205

@@ -211,6 +213,7 @@ def configure_metrics(self, prefix: str):
211213
f"{prefix}/{head_name}/auprc_macro": MulticlassAveragePrecision(
212214
num_classes=self.num_classes, average="macro"
213215
),
216+
f"{prefix}/{head_name}/f1_macro": MulticlassF1Score(num_classes=self.num_classes, average="macro"),
214217
}
215218
)
216219
return metrics

asparagus/modules/lightning_modules/segmentation_module.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@
2323
volume_similarity,
2424
)
2525
from gardening_tools.functional.paths.write import save_json
26-
from gardening_tools.functional.transforms.cropping_and_padding import (
27-
fit_patch_size_to_image_size,
28-
)
2926
from gardening_tools.modules.losses.deep_supervision import DeepSupervisionLoss
3027
from gardening_tools.modules.losses.DiceCE import DiceCE
3128
from gardening_tools.modules.metrics import GeneralizedDiceScore
@@ -51,7 +48,6 @@ def __init__(
5148
val_transforms: Optional[transforms.Compose] = None,
5249
optimizer: str = "SGD",
5350
inference_patch_size: list = [],
54-
inference_mode: str = "3D",
5551
test_output_path: str = None,
5652
log_image_every_n_epochs: int = 50,
5753
weight_decay: float = 3e-5,
@@ -78,7 +74,6 @@ def __init__(
7874
load_decoder=load_decoder,
7975
repeat_stem_weights=repeat_stem_weights,
8076
)
81-
self.inference_mode = inference_mode
8277
self.inference_patch_size = inference_patch_size
8378
self.test_output_path = test_output_path
8479
self.num_classes = model.num_classes
@@ -161,7 +156,6 @@ def training_step(self, batch, batch_idx):
161156

162157
def validation_step(self, batch, batch_idx):
163158
x, y = batch["image"], batch["label"]
164-
165159
pred = self.model(x)
166160
loss = self.val_loss(pred, y)
167161
self.log(
@@ -221,7 +215,7 @@ def test_step(self, batch, batch_idx):
221215

222216
logits = self.model.sliding_window_predict(
223217
data=x,
224-
patch_size=fit_patch_size_to_image_size(self.inference_patch_size, list(x.shape[2:])),
218+
patch_size=self.inference_patch_size,
225219
overlap=0.5,
226220
)
227221

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
from asparagus.modules.transforms.clamp import Torch_ClampTarget as Torch_ClampTarget
2+
from asparagus.modules.transforms.crop import Torch_Crop as Torch_Crop
3+
from asparagus.modules.transforms.pad import Torch_Pad as Torch_Pad

0 commit comments

Comments
 (0)