
Commit 27c21e5

adapt pytorch lightning 2.0 AKA lightning

1 parent 928575b

2 files changed: +85 -44 lines

Diff for: nni/compression/pytorch/utils/evaluator.py (+45, -23)
@@ -21,6 +21,13 @@
 else:
     LIGHTNING_INSTALLED = True

+try:
+    import lightning as L
+except ImportError:
+    LIGHTNING2_INSTALLED = False
+else:
+    LIGHTNING2_INSTALLED = True
+
 try:
     from transformers.trainer import Trainer as HFTrainer
 except ImportError:
@@ -161,7 +168,7 @@ def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
         """
         raise NotImplementedError

-    def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
+    def bind_model(self, model: Module | pl.LightningModule | L.LightningModule, param_names_map: Dict[str, str] | None = None):
         """
         Bind the model suitable for this ``Evaluator`` to use the evaluator's abilities of model modification,
         model training, and model evaluation.
@@ -312,25 +319,27 @@ class LightningEvaluator(Evaluator):
     If the the test metric is needed by nni, please make sure log metric with key ``default`` in ``LightningModule.test_step()``.
     """

-    def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
+    def __init__(self, trainer: pl.Trainer | L.Trainer, data_module: pl.LightningDataModule,
                  dummy_input: Any | None = None):
         assert LIGHTNING_INSTALLED, 'pytorch_lightning is not installed.'
         err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
         err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
-        assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
+        lighting2_check = LIGHTNING2_INSTALLED and isinstance(trainer, L.Trainer)
+        assert (isinstance(trainer, pl.Trainer) or lighting2_check) and is_traceable(trainer), err_msg
         err_msg = err_msg_p.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
-        assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg
+        lighting2_check = LIGHTNING2_INSTALLED and isinstance(data_module, L.LightningDataModule)
+        assert (isinstance(data_module, pl.LightningDataModule) or lighting2_check) and is_traceable(data_module), err_msg
         self.trainer = trainer
         self.data_module = data_module
         self._dummy_input = dummy_input

-        self.model: pl.LightningModule | None = None
+        self.model: pl.LightningModule | L.LightningModule | None = None
         self._ori_model_attr = {}
         self._param_names_map: Dict[str, str] | None = None

         self._initialization_complete = False

-    def _init_optimizer_helpers(self, pure_model: pl.LightningModule):
+    def _init_optimizer_helpers(self, pure_model: pl.LightningModule | L.LightningModule):
         assert self._initialization_complete is False, 'Evaluator initialization is already complete.'

         self._optimizer_helpers = []
@@ -395,10 +404,14 @@ def _init_optimizer_helpers(self, pure_model: pl.LightningModule):

         self._initialization_complete = True

-    def bind_model(self, model: pl.LightningModule, param_names_map: Dict[str, str] | None = None):
+    def bind_model(
+        self,
+        model: pl.LightningModule | L.LightningModule,
+        param_names_map: Dict[str, str] | None = None
+    ):
         err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
         assert self._initialization_complete is True, err_msg
-        assert isinstance(model, pl.LightningModule)
+        assert isinstance(model, pl.LightningModule) or isinstance(model, L.LightningModule)
         if self.model is not None:
             _logger.warning('Already bound a model, will unbind it before bind a new model.')
             self.unbind_model()
@@ -425,7 +438,7 @@ def unbind_model(self):
             _logger.warning('Did not bind any model, no need to unbind model.')

     def _patch_configure_optimizers(self):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)

         if self._opt_returned_dicts:
             def new_configure_optimizers(_):  # type: ignore
@@ -452,11 +465,11 @@ def new_configure_optimizers(_):
         self.model.configure_optimizers = types.MethodType(new_configure_optimizers, self.model)

     def _revert_configure_optimizers(self):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         self.model.configure_optimizers = self._ori_model_attr['configure_optimizers']

     def patch_loss(self, patch: Callable[[Tensor], Tensor]):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         old_training_step = self.model.training_step

         def patched_training_step(_, *args, **kwargs):
@@ -470,19 +483,28 @@ def patched_training_step(_, *args, **kwargs):
         self.model.training_step = types.MethodType(patched_training_step, self.model)

     def revert_loss(self):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         self.model.training_step = self._ori_model_attr['training_step']

     def patch_optimizer_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)

         class OptimizerCallback(Callback):
-            def on_before_optimizer_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
-                                         optimizer: Optimizer, opt_idx: int) -> None:
+            def on_before_optimizer_step(
+                self,
+                trainer: pl.Trainer | L.Trainer,
+                pl_module: pl.LightningModule | L.LightningModule,
+                optimizer: Optimizer, opt_idx: int
+            ) -> None:
                 for task in before_step_tasks:
                     task()

-            def on_before_zero_grad(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer) -> None:
+            def on_before_zero_grad(
+                self,
+                trainer: pl.Trainer | L.Trainer,
+                pl_module: pl.LightningModule | L.LightningModule,
+                optimizer: Optimizer,
+            ) -> None:
                 for task in after_step_tasks:
                     task()

@@ -496,13 +518,13 @@ def patched_configure_callbacks(_):
         self.model.configure_callbacks = types.MethodType(patched_configure_callbacks, self.model)

     def revert_optimizer_step(self):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         self.model.configure_callbacks = self._ori_model_attr['configure_callbacks']

     def train(self, max_steps: int | None = None, max_epochs: int | None = None):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         # reset trainer
-        trainer: pl.Trainer = self.trainer.trace_copy().get()  # type: ignore
+        trainer: pl.Trainer | L.Trainer = self.trainer.trace_copy().get()  # type: ignore
         # NOTE: lightning may dry run some steps at first for sanity check in Trainer.fit() by default,
         # If we want to record some information in the forward hook, we may get some additional information,
         # so using Trainer.num_sanity_val_steps = 0 disable sanity check.
@@ -529,9 +551,9 @@ def evaluate(self) -> Tuple[float | None, List[Dict[str, float]]]:
         E.g., if ``Trainer.test()`` returned ``[{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}]``,
         NNI will take the final metric ``(0.8 + 0.6) / 2 = 0.7``.
         """
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         # reset trainer
-        trainer: pl.Trainer = self.trainer.trace_copy().get()  # type: ignore
+        trainer: pl.Trainer | L.Trainer = self.trainer.trace_copy().get()  # type: ignore
         original_results = trainer.test(self.model, self.data_module)
         # del trainer reference, we don't want to dump trainer when we dump the entire model.
         self.model.trainer = None
@@ -831,7 +853,7 @@ def __init__(self, trainer: HFTrainer, dummy_input: Any | None = None) -> None:

         self._initialization_complete = False

-    def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
+    def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule | L.LightningModule):
         assert self._initialization_complete is False, 'Evaluator initialization is already complete.'

         if self.traced_trainer.optimizer is not None and is_traceable(self.traced_trainer.optimizer):
@@ -862,7 +884,7 @@ def patched_get_optimizer_cls_and_kwargs(args) -> Tuple[Any, Any]:

         self._initialization_complete = True

-    def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
+    def bind_model(self, model: Module | pl.LightningModule | L.LightningModule, param_names_map: Dict[str, str] | None = None):
         err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
         assert self._initialization_complete is True, err_msg
         assert isinstance(model, Module)
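
As a usage illustration (not part of the commit): after this change, a lightning 2.0 ``Trainer`` created via ``import lightning as L`` can be passed to ``LightningEvaluator``, provided it is wrapped with ``nni.trace`` as the assertions above require. This is a minimal sketch; the data module class is a hypothetical placeholder, and the import path is taken from this file's location, so the public NNI import path may differ.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import nni
    import lightning as L
    from nni.compression.pytorch.utils.evaluator import LightningEvaluator


    class RandomDataModule(L.LightningDataModule):
        # Hypothetical data module, only here to make the sketch self-contained.
        def _loader(self):
            data = TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,)))
            return DataLoader(data, batch_size=4)

        def train_dataloader(self):
            return self._loader()

        def test_dataloader(self):
            return self._loader()


    # nni.trace records the constructor arguments, so the evaluator can later
    # re-create the trainer and data module with trace_copy() as seen in train()/evaluate().
    trainer = nni.trace(L.Trainer)(max_epochs=1, num_sanity_val_steps=0)
    data_module = nni.trace(RandomDataModule)()

    evaluator = LightningEvaluator(trainer, data_module)
    # A pl.LightningModule or L.LightningModule is bound later via bind_model(...).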

Diff for: nni/contrib/compression/utils/evaluator.py (+40, -21)
@@ -22,6 +22,13 @@
 else:
     LIGHTNING_INSTALLED = True

+try:
+    import lightning as L
+except ImportError:
+    LIGHTNING2_INSTALLED = False
+else:
+    LIGHTNING2_INSTALLED = True
+
 try:
     from transformers.trainer import Trainer as HFTrainer
 except ImportError:
@@ -149,7 +156,7 @@ class Evaluator:
     _initialization_complete: bool
     _hook: List[Hook]

-    def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
+    def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule | L.LightningModule):
         """
         This is an internal API, ``pure_model`` means the model is the original model passed in by the user,
         it should not be the modified model (wrapped, hooked, or patched by NNI).
@@ -164,7 +171,7 @@ def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
         """
         raise NotImplementedError

-    def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
+    def bind_model(self, model: Module | pl.LightningModule | L.LightningModule, param_names_map: Dict[str, str] | None = None):
         """
         Bind the model suitable for this ``Evaluator`` to use the evaluator's abilities of model modification,
         model training, and model evaluation.
@@ -186,8 +193,12 @@ def unbind_model(self):
         """
         raise NotImplementedError

-    def _optimizer_add_param_group(self, model: Union[torch.nn.Module, pl.LightningModule],
-                                   module_name_param_dict: Dict[str, List[Tensor]], optimizers: Optimizer | List[Optimizer]):
+    def _optimizer_add_param_group(
+        self,
+        model: Union[torch.nn.Module, pl.LightningModule, L.LightningModule],
+        module_name_param_dict: Dict[str, List[Tensor]],
+        optimizers: Optimizer | List[Optimizer]
+    ):
         # used in the bind_model process
         def find_param_group(param_groups: List[Dict], module_name: str):
             for i, param_group in enumerate(param_groups):
@@ -367,25 +378,33 @@ class LightningEvaluator(Evaluator):
     If the the test metric is needed by nni, please make sure log metric with key ``default`` in ``LightningModule.test_step()``.
     """

-    def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
+    def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule | L.LightningDataModule,
                  dummy_input: Any | None = None):
         assert LIGHTNING_INSTALLED, 'pytorch_lightning is not installed.'
         err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
-        err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
-        assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
-        err_msg = err_msg_p.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
+        err_msg = err_msg_p.format(
+            'pytorch_lightning.Trainer or lightning.Trainer',
+            'pytorch_lightning.Trainer or lightning.Trainer',
+        )
+        lighting2_check = LIGHTNING2_INSTALLED and isinstance(trainer, L.Trainer)
+        assert (isinstance(trainer, pl.Trainer) or lighting2_check) and is_traceable(trainer), err_msg
+        err_msg = err_msg_p.format(
+            'pytorch_lightning.LightningDataModule or lightning.LightningDataModule',
+            'pytorch_lightning.LightningDataModule or lightning.LightningDataModule',
+        )
+        lighting2_check = LIGHTNING2_INSTALLED and isinstance(data_module, L.LightningDataModule)
         assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg
-        self.trainer = trainer
-        self.data_module = data_module
+        self.trainer: pl.Trainer | L.Trainer = trainer
+        self.data_module: pl.LightningDataModule | L.LightningDataModule = data_module
         self._dummy_input = dummy_input

-        self.model: pl.LightningModule | None = None
+        self.model: pl.LightningModule | L.LightningModule | None = None
         self._ori_model_attr = {}
         self._param_names_map: Dict[str, str] | None = None

         self._initialization_complete = False

-    def _init_optimizer_helpers(self, pure_model: pl.LightningModule):
+    def _init_optimizer_helpers(self, pure_model: pl.LightningModule | L.LightningModule):
         assert self._initialization_complete is False, 'Evaluator initialization is already complete.'

         self._optimizer_helpers = []
@@ -450,7 +469,7 @@ def _init_optimizer_helpers(self, pure_model: pl.LightningModule):

         self._initialization_complete = True

-    def bind_model(self, model: pl.LightningModule, param_names_map: Dict[str, str] | None = None):
+    def bind_model(self, model: pl.LightningModule | L.LightningModule, param_names_map: Dict[str, str] | None = None):
         err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
         assert self._initialization_complete is True, err_msg
         assert isinstance(model, pl.LightningModule)
@@ -514,7 +533,7 @@ def new_configure_optimizers(_):
         self.model.configure_optimizers = types.MethodType(new_configure_optimizers, self.model)

     def _patch_configure_optimizers(self):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         if self._opt_returned_dicts:
             def new_configure_optimizers(_):  # type: ignore
                 optimizers = [opt_helper.call(self.model, self._param_names_map) for opt_helper in self._optimizer_helpers]  # type: ignore
@@ -559,11 +578,11 @@ def patched_training_step(_, *args, **kwargs):
         self.model.training_step = types.MethodType(patched_training_step, self.model)

     def revert_loss(self):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         self.model.training_step = self._ori_model_attr['training_step']

     def patch_optimizer_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         old_configure_optimizers = self.model.configure_optimizers

         def patched_step_factory(old_step):
@@ -599,13 +618,13 @@ def new_configure_optimizers(_):
         self.model.configure_optimizers = types.MethodType(new_configure_optimizers, self.model)

     def revert_optimizer_step(self):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         self.model.configure_callbacks = self._ori_model_attr['configure_callbacks']

     def train(self, max_steps: int | None = None, max_epochs: int | None = None):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         # reset trainer
-        trainer: pl.Trainer = self.trainer.trace_copy().get()  # type: ignore
+        trainer: pl.Trainer | L.Trainer = self.trainer.trace_copy().get()  # type: ignore
         # NOTE: lightning may dry run some steps at first for sanity check in Trainer.fit() by default,
         # If we want to record some information in the forward hook, we may get some additional information,
         # so using Trainer.num_sanity_val_steps = 0 disable sanity check.
@@ -632,9 +651,9 @@ def evaluate(self) -> Tuple[float | None, List[Dict[str, float]]]:
         E.g., if ``Trainer.test()`` returned ``[{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}]``,
         NNI will take the final metric ``(0.8 + 0.6) / 2 = 0.7``.
         """
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
         # reset trainer
-        trainer: pl.Trainer = self.trainer.trace_copy().get()  # type: ignore
+        trainer: pl.Trainer | L.Trainer = self.trainer.trace_copy().get()  # type: ignore
         original_results = trainer.test(self.model, self.data_module)
         # del trainer reference, we don't want to dump trainer when we dump the entire model.
         self.model.trainer = None
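
The ``evaluate()`` docstring above describes how the returned metrics are aggregated: ``Trainer.test()`` yields one dict per test dataloader, and NNI averages the values logged under the ``default`` key. Below is a small sketch reproducing the docstring's arithmetic; it is an illustration of that aggregation, not the exact NNI implementation.

    # Results in the shape returned by Trainer.test(), copied from the docstring example.
    original_results = [{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}]

    # Average the 'default' entries; (0.8 + 0.6) / 2 = 0.7, as the docstring states.
    defaults = [result['default'] for result in original_results if 'default' in result]
    final_metric = sum(defaults) / len(defaults) if defaults else None

    assert final_metric is not None and abs(final_metric - 0.7) < 1e-9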
