
Commit 2243130

add adamW optimizer
1 parent 74f9347 commit 2243130

File tree: 5 files changed (+6, -5)

config/adapt_config.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@
 adapt_parser.add_argument('--patch_size', nargs=3, type=int, default=(64, 64, 64),
                           help='Expected size for training (x y z)')
 # optimizer type, available: [sgd, adam]
-adapt_parser.add_argument('--optimizer', type=str, default="adam", help='available: [sgd, adam]')
+adapt_parser.add_argument('--optimizer', type=str, default="adam", help='available: [sgd, adam, adamw]')
 # loss metric type, available: [bce, dice, tver]
 adapt_parser.add_argument('--loss_metric', type=str, default="tver", help="available: [bce, dice, tver]")


config/angiboost_config.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
 # expected size of the training patch
 angiboost_parser.add_argument('--osz', nargs=3, type=int, default=(64, 64, 64),
                               help='Expected size of the training patch (x y z)')
-angiboost_parser.add_argument('--optimizer', type=str, default="adam", help="available: [sgd, adam]")
+angiboost_parser.add_argument('--optimizer', type=str, default="adam", help="available: [sgd, adam, adamw]")
 angiboost_parser.add_argument('--loss_metric', type=str, default="tver", help="available: [bce, dice, tver]")

 # Optimizer tuning

config/boost_config.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 # expected size of the training patch
 boost_parser.add_argument('--patch_size', nargs=3, type=int, default=(64, 64, 64),
                           help='Expected size of the training patch (x y z)')
-boost_parser.add_argument('--optimizer', type=str, default="adam", help="available: [sgd, adam]")
+boost_parser.add_argument('--optimizer', type=str, default="adam", help="available: [sgd, adam, adamw]")
 boost_parser.add_argument('--loss_metric', type=str, default="tver", help="available: [bce, dice, tver]")

 # Optimizer tuning

config/train_config.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
 # expected patch size for training
 train_parser.add_argument('--output_size', nargs=3, type=int, default=(64, 64, 64), help='Expected patch size for training (x y z)')
 # optimizer type, available: [sgd, adam]
-train_parser.add_argument('--optimizer', type=str, default="adam", help='available: [sgd, adam]')
+train_parser.add_argument('--optimizer', type=str, default="adam", help='available: [sgd, adam, adamw]')
 # loss metric type, available: [bce, dice, tver]
 train_parser.add_argument('--loss_metric', type=str, default="tver", help="available: [bce, dice, tver]")

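All four entry points gain the same value in their help strings; a minimal sketch of the flag in use (hypothetical invocation, assuming the train_parser setup shown above):

# e.g. python train.py --optimizer adamw
args = train_parser.parse_args(['--optimizer', 'adamw'])
print(args.optimizer)  # -> 'adamw'

Note that the accepted values are only documented in help=, not enforced via choices=; an unsupported name is rejected later by choose_optimizer's registry lookup (see below).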

library/loss_func.py

Lines changed: 2 additions & 1 deletion
@@ -150,7 +150,8 @@ def choose_optimizer(optim_name: str, model_params: Any, lr: float) -> Any:
     """
     optimizer_registry = {
         'sgd': lambda: torch.optim.SGD(model_params, lr),
-        'adam': lambda: torch.optim.Adam(model_params, lr)
+        'adam': lambda: torch.optim.Adam(model_params, lr),
+        'adamw': lambda: torch.optim.AdamW(model_params, lr)
     }

     if optim_name not in optimizer_registry:
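Because only lr is forwarded to each constructor, the new entry builds torch.optim.AdamW with PyTorch defaults, notably decoupled weight_decay=0.01 (plain Adam defaults to weight_decay=0.0). A minimal round-trip sketch (the toy model and import path are assumptions for illustration, not part of the commit):

import torch
from library.loss_func import choose_optimizer  # assumed import path

model = torch.nn.Linear(4, 1)  # toy model for illustration
opt = choose_optimizer('adamw', model.parameters(), lr=1e-3)
assert isinstance(opt, torch.optim.AdamW)
print(opt.defaults['weight_decay'])  # 0.01, PyTorch's AdamW default

Tuning the weight decay itself would require extending choose_optimizer's signature, since the registry lambdas close over model_params and lr only.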
