Skip to content

Commit 8c5541b

Browse files
deploy changes
1 parent d85c976 commit 8c5541b

7 files changed

Lines changed: 199 additions & 3 deletions

File tree

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
name: Deploy to Open Repo
2+
3+
on:
4+
workflow_dispatch:
5+
push:
6+
branches:
7+
- main
8+
9+
permissions:
10+
contents: read
11+
12+
jobs:
13+
deploy:
14+
if: |
15+
github.event.repository.name == 'internal_asparagus' &&
16+
github.event.push.pusher.username != 'github-actions[bot]'
17+
name: Push built output to target repo via SSH
18+
runs-on: ubuntu-latest
19+
steps:
20+
- name: Checkout repository (no persisted credentials)
21+
uses: actions/checkout@v4
22+
with:
23+
ref: deploy
24+
fetch-depth: 0
25+
persist-credentials: false
26+
27+
- name: Configure git author
28+
run: |
29+
git config --global user.name "github-actions[bot]"
30+
git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
31+
32+
- name: Update deployment branch and remove this workflow so it doesn't get pushed
33+
run: |
34+
git merge --squash -X theirs origin/main --allow-unrelated-histories
35+
git restore --staged README.md
36+
git commit -m "deploy changes"
37+
38+
- name: Start ssh-agent and add deploy key
39+
uses: webfactory/ssh-agent@v0.9.1
40+
with:
41+
ssh-private-key: ${{ secrets.DEPLOY }}
42+
43+
- name: Ensure github.com is in known_hosts
44+
run: |
45+
mkdir -p ~/.ssh
46+
ssh-keyscan github.com >> ~/.ssh/known_hosts
47+
chmod 644 ~/.ssh/known_hosts
48+
49+
- name: Set SSH remote and push via SSH
50+
run: |
51+
echo "Pushing to SSH remote: git@github.com:Sllambias/asparagus.git -> branch main"
52+
53+
git remote remove origin || true
54+
git remote add ssh-origin git@github.com:Sllambias/asparagus.git
55+
56+
# Sanity checks
57+
git remote -v
58+
echo "SSH_AUTH_SOCK=$SSH_AUTH_SOCK"
59+
ssh -T -o StrictHostKeyChecking=no git@github.com || true
60+
61+
git push --force -u ssh-origin HEAD:main

asparagus/modules/lightning_modules/base_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
self.repeat_stem_weights = repeat_stem_weights
6161
assert 0 < cosine_period_ratio <= 1
6262

63-
self.save_hyperparameters(ignore=["model"])
63+
self.save_hyperparameters(ignore=["model", "train_transforms", "val_transforms", "test_transforms"])
6464
self.model = model
6565

6666
if weights is not None:

asparagus/modules/networks/resenc_unet.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,95 @@
1+
import logging
2+
import torch
3+
from gardening_tools.modules.networks.BaseNet import BaseNet
4+
from gardening_tools.modules.networks.components.blocks import ResidualBlock
5+
from gardening_tools.modules.networks.components.encoders import ResidualUNetEncoder
6+
from gardening_tools.modules.networks.components.heads import ClsRegHead
17
from gardening_tools.modules.networks.resunet import ResidualEncoderUNet
8+
from torch import nn
9+
from typing import List, Tuple, Type, Union
10+
11+
12+
class ResidualEncoderUNetCLSREG(BaseNet):
13+
def __init__(
14+
self,
15+
input_channels: int,
16+
output_channels: int,
17+
dimensions: str,
18+
kernel_size: int,
19+
stride: int,
20+
features_per_stage: list,
21+
n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]],
22+
conv_bias: bool = True,
23+
encoder_basic_block: Type[ResidualBlock] = ResidualBlock,
24+
decoder: nn.Module = ClsRegHead,
25+
norm_op_kwargs={"eps": 1e-05, "affine": True},
26+
dropout_op=None,
27+
dropout_op_kwargs=None,
28+
nonlin=torch.nn.LeakyReLU,
29+
nonlin_kwargs={"inplace": True},
30+
):
31+
super().__init__()
32+
33+
# Extract dropout rates from kwargs
34+
if dropout_op_kwargs is None:
35+
dropout_op_kwargs = {}
36+
37+
encoder_dropout_rate = dropout_op_kwargs.get("encoder_dropout_rate", 0.0)
38+
decoder_dropout_rate = dropout_op_kwargs.get("decoder_dropout_rate", 0.0)
39+
inplace = dropout_op_kwargs.get("inplace", True)
40+
41+
if dimensions == "2D":
42+
conv_op = nn.Conv2d
43+
dropout_op = nn.Dropout2d
44+
norm_op = nn.InstanceNorm2d
45+
pool_op = nn.MaxPool2d
46+
clsreg_pool_op = nn.AdaptiveAvgPool2d
47+
elif dimensions == "3D":
48+
conv_op = nn.Conv3d
49+
dropout_op = nn.Dropout3d
50+
norm_op = nn.InstanceNorm3d
51+
pool_op = nn.MaxPool3d
52+
clsreg_pool_op = nn.AdaptiveAvgPool3d
53+
else:
54+
logging.warning("Uuh, dimensions not in ['2D', '3D']")
55+
56+
self.num_classes = output_channels
57+
58+
self.stem_weight_name = "encoder.stem.conv1.conv.weight"
59+
60+
self.encoder = ResidualUNetEncoder(
61+
input_channels=input_channels,
62+
features_per_stage=features_per_stage,
63+
conv_op=conv_op,
64+
kernel_size=kernel_size,
65+
stride=stride,
66+
n_blocks_per_stage=n_blocks_per_stage,
67+
conv_bias=conv_bias,
68+
norm_op=norm_op,
69+
norm_op_kwargs=norm_op_kwargs,
70+
dropout_op=dropout_op,
71+
dropout_op_kwargs={"p": encoder_dropout_rate, "inplace": inplace},
72+
nonlin=nonlin,
73+
nonlin_kwargs=nonlin_kwargs,
74+
block=encoder_basic_block,
75+
pool_op=pool_op,
76+
)
77+
78+
self.decoder = decoder(
79+
pool_op=clsreg_pool_op,
80+
input_channels=features_per_stage[-1],
81+
output_channels=output_channels,
82+
dropout_rate=decoder_dropout_rate,
83+
)
84+
85+
def forward(self, x):
86+
skips = self.encoder(x)
87+
return self.decoder(skips)
88+
89+
def forward_with_features(self, x):
90+
skips = self.encoder(x)
91+
output = self.decoder(skips)
92+
return output, skips[-1]
293

394

495
# Encoder 29M parameters
@@ -46,6 +137,26 @@ def resenc_unet_b(
46137
)
47138

48139

140+
# Encoder 90M parameters
141+
# Full model - 90.3 M Total params
142+
def resenc_unet_b_clsreg(
143+
input_channels: int = 1,
144+
output_channels: int = 1,
145+
dimensions: str = "3D",
146+
dropout_op_kwargs: dict = None,
147+
):
148+
return ResidualEncoderUNetCLSREG(
149+
dimensions=dimensions,
150+
input_channels=input_channels,
151+
output_channels=output_channels,
152+
features_per_stage=(32, 64, 128, 256, 320, 320),
153+
stride=2,
154+
kernel_size=3,
155+
n_blocks_per_stage=(1, 3, 4, 6, 6, 6),
156+
dropout_op_kwargs=dropout_op_kwargs,
157+
)
158+
159+
49160
# Encoder 345M parameters
50161
# Full model 391M parameters
51162
def resenc_unet_l(

asparagus/pipeline/auto_configuration/versioning.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def pathing(cfg, train=True):
2626
pretrained_ckpt = os.path.join(model_folder, "checkpoints", cfg.load_checkpoint_name)
2727
assert cfg.checkpoint_path is None, "You cannot provide both a checkpoint path and a checkpoint run id"
2828
elif cfg.checkpoint_path is not None and cfg.checkpoint_path != "":
29+
model_folder = None
2930
pretrained_ckpt = cfg.checkpoint_path
3031
else:
3132
model_folder, pretrained_ckpt = None, None

configs/model/core/resenc_unet.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ _seg_net:
1111
_cls_net:
1212
_target_: asparagus.modules.networks.resenc_unet.${model.cls_net}
1313
dimensions: ${model.dimensions}
14+
dropout_op_kwargs:
15+
encoder_dropout_rate: ${model.encoder_dropout_rate}
16+
decoder_dropout_rate: ${model.decoder_dropout_rate}
17+
inplace: true
1418

1519
_plugin_seg_net:
1620
_target_: asparagus.modules.networks.resenc_unet.${model.plugin_seg_net}

configs/model/resenc_unet_b.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ dimensions: 3D
66
pretrain_net: resenc_unet_b
77

88
seg_net: resenc_unet_b
9-
cls_net:
9+
cls_net:
1010

11-
plugin_seg_net:
11+
plugin_seg_net:
1212
plugin_seg_train_n_last_params: 10
1313

14+
# Dropout rates (0.0 = no dropout)
15+
encoder_dropout_rate: 0.0
16+
decoder_dropout_rate: 0.0
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
defaults:
2+
- core/resenc_unet@
3+
4+
dimensions: 3D
5+
6+
pretrain_net:
7+
8+
seg_net:
9+
cls_net: resenc_unet_b_clsreg
10+
11+
plugin_seg_net:
12+
plugin_seg_train_n_last_params:
13+
14+
# Dropout rates (0.0 = no dropout)
15+
encoder_dropout_rate: 0.0
16+
decoder_dropout_rate: 0.2

0 commit comments

Comments
 (0)