Skip to content

Commit e756bfb

Browse files
committed
set run name manually
1 parent 03f3f6b commit e756bfb

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

scripts/templates/train_etnn.sh.j2

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Training script for {{ experiment_name }}
44

55
# DEFINE EXP ARGUMENTS
6+
EXP_NAME={{ experiment_name }}
67
LIFTERS=({{ lifters | join(' ') }})
78
DIM={{ dim }}
89
VISIBLE_DIMS=({{ visible_dims | join(' ') }})
@@ -76,6 +77,7 @@ do
7677
--splits "$SPLITS" \
7778
--normalize_invariants \
7879
--clip_gradient \
80+
--run_name "${EXP_NAME} ${TARGET_NAME}" \
7981
--checkpoint_dir "$CHECKPOINT_DIR" &
8082

8183
# Wait for 30 minutes before moving to the next iteration

src/main_qm9.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ def main(args):
118118
resume="must",
119119
)
120120
else:
121-
wandb.init(project="QM9-clean-experiments")
121+
run_name = args.run_name
122+
wandb.init(project="QM9-clean-experiments", name=run_name)
122123
run_id = wandb.run.id
123-
run_name = wandb.run.name
124124

125125
for epoch in tqdm(range(start_epoch, args.epochs)):
126126
epoch_start_time, epoch_mae_train, epoch_mae_val = time.time(), 0, 0
@@ -290,6 +290,9 @@ def main(args):
290290
parser.add_argument("--num_samples", type=int, default=None, help="num samples to to train on")
291291
parser.add_argument("--splits", type=str, default="egnn", help="split type")
292292

293+
# wandb arguments
294+
parser.add_argument("--run_name", type=str, default=None, help="run name")
295+
293296
parsed_args = parser.parse_args()
294297
parsed_args = parser_utils.add_common_derived_arguments(parsed_args)
295298
parsed_args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

0 commit comments

Comments
 (0)