We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 8f38339 commit 27f19baCopy full SHA for 27f19ba
dalle2_pytorch/trainer.py
@@ -181,7 +181,7 @@ def __init__(
181
eps = 1e-6,
182
max_grad_norm = None,
183
group_wd_params = True,
184
- warmup_steps = 1,
+ warmup_steps = None,
185
cosine_decay_max_steps = None,
186
**kwargs
187
):
@@ -357,7 +357,8 @@ def update(self):
357
358
# accelerator will occasionally skip optimizer steps in a "dynamic loss scaling strategy"
359
if not self.accelerator.optimizer_step_was_skipped:
360
- with self.warmup_scheduler.dampening():
+ sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
361
+ with sched_context():
362
self.scheduler.step()
363
364
if self.use_ema:
dalle2_pytorch/version.py
@@ -1 +1 @@
1
-__version__ = '1.8.1'
+__version__ = '1.8.2'
0 commit comments