diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6f7157c..96fa029 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,11 +9,11 @@ repos: - id: debug-statements - id: check-merge-conflict -- repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black - args: [--line-length=100] +#- repo: https://github.com/psf/black +# rev: 23.3.0 +# hooks: +# - id: black +# args: [--line-length=100] #- repo: https://github.com/pycqa/flake8 # rev: 6.0.0 diff --git a/train/optim.py b/train/optim.py index 785ce22..2bb140c 100644 --- a/train/optim.py +++ b/train/optim.py @@ -20,6 +20,7 @@ def step(self): if grad.shape[0] == param.data.shape[0] else grad.sum(axis=0) ) + except BaseException: raise ValueError( "Cannot align grad shape " diff --git a/train/trainer.py b/train/trainer.py index ad31d3a..e726668 100644 --- a/train/trainer.py +++ b/train/trainer.py @@ -11,6 +11,7 @@ from monitor.tracker import TrainingTracker + class Trainer: """ Trainer class that handles model training, evaluation, and monitoring.