 else:
     LIGHTNING_INSTALLED = True
 
+try:
+    import lightning as L
+except ImportError:
+    LIGHTNING2_INSTALLED = False
+else:
+    LIGHTNING2_INSTALLED = True
+
 try:
     from transformers.trainer import Trainer as HFTrainer
 except ImportError:
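For context between hunks: the added block above is the standard try/except feature-detection idiom. `LIGHTNING2_INSTALLED` records whether the `lightning` (>= 2.0) top-level namespace is importable, and any later use of `L` has to consult that flag first, because `L` stays unbound when the import fails. A minimal sketch of that guard ordering (the helper function is illustrative, not part of the commit):

```python
import pytorch_lightning as pl

try:
    import lightning as L  # lightning >= 2.0 ships under this top-level name
except ImportError:
    LIGHTNING2_INSTALLED = False
else:
    LIGHTNING2_INSTALLED = True

def is_supported_trainer(obj) -> bool:
    # Check the flag first: `L` may only be evaluated if the import succeeded.
    return isinstance(obj, pl.Trainer) or (LIGHTNING2_INSTALLED and isinstance(obj, L.Trainer))
```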
@@ -161,7 +168,7 @@ def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
         """
         raise NotImplementedError
 
-    def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
+    def bind_model(self, model: Module | pl.LightningModule | L.LightningModule, param_names_map: Dict[str, str] | None = None):
         """
         Bind the model suitable for this ``Evaluator`` to use the evaluator's abilities of model modification,
         model training, and model evaluation.
@@ -312,25 +319,27 @@ class LightningEvaluator(Evaluator):
     If the test metric is needed by nni, please make sure to log the metric with key ``default`` in ``LightningModule.test_step()``.
     """
 
-    def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
+    def __init__(self, trainer: pl.Trainer | L.Trainer, data_module: pl.LightningDataModule,
                  dummy_input: Any | None = None):
         assert LIGHTNING_INSTALLED, 'pytorch_lightning is not installed.'
         err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
         err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
-        assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
+        lightning2_check = LIGHTNING2_INSTALLED and isinstance(trainer, L.Trainer)
+        assert (isinstance(trainer, pl.Trainer) or lightning2_check) and is_traceable(trainer), err_msg
         err_msg = err_msg_p.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
-        assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg
+        lightning2_check = LIGHTNING2_INSTALLED and isinstance(data_module, L.LightningDataModule)
+        assert (isinstance(data_module, pl.LightningDataModule) or lightning2_check) and is_traceable(data_module), err_msg
         self.trainer = trainer
         self.data_module = data_module
         self._dummy_input = dummy_input
 
-        self.model: pl.LightningModule | None = None
+        self.model: pl.LightningModule | L.LightningModule | None = None
         self._ori_model_attr = {}
         self._param_names_map: Dict[str, str] | None = None
 
         self._initialization_complete = False
 
-    def _init_optimizer_helpers(self, pure_model: pl.LightningModule):
+    def _init_optimizer_helpers(self, pure_model: pl.LightningModule | L.LightningModule):
         assert self._initialization_complete is False, 'Evaluator initialization is already complete.'
 
         self._optimizer_helpers = []
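Because ``__init__`` asserts ``is_traceable(...)`` on both arguments, the trainer and data module must be created through ``nni.trace``; and per the class docstring, the bound module should log a ``default`` metric in ``test_step``. A hedged usage sketch (the module body and the ``LightningDataModule`` stand-in are placeholders, not from this commit):

```python
import nni
import torch
import pytorch_lightning as pl

class MyModule(pl.LightningModule):
    def test_step(self, batch, batch_idx):
        metric = torch.tensor(0.0)  # placeholder: compute the metric you care about
        self.log('default', metric)  # the 'default' key is what NNI reads back

# Both arguments are built through nni.trace so the evaluator can
# re-instantiate fresh copies later via trace_copy().
trainer = nni.trace(pl.Trainer)(max_epochs=3)
data_module = nni.trace(pl.LightningDataModule)()  # stand-in; use your own data module
evaluator = LightningEvaluator(trainer, data_module)
```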
@@ -395,10 +404,14 @@ def _init_optimizer_helpers(self, pure_model: pl.LightningModule):
 
         self._initialization_complete = True
 
-    def bind_model(self, model: pl.LightningModule, param_names_map: Dict[str, str] | None = None):
+    def bind_model(
+        self,
+        model: pl.LightningModule | L.LightningModule,
+        param_names_map: Dict[str, str] | None = None
+    ):
         err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
         assert self._initialization_complete is True, err_msg
-        assert isinstance(model, pl.LightningModule)
+        assert isinstance(model, pl.LightningModule) or (LIGHTNING2_INSTALLED and isinstance(model, L.LightningModule))
         if self.model is not None:
             _logger.warning('Already bound a model, will unbind it before bind a new model.')
             self.unbind_model()
@@ -425,7 +438,7 @@ def unbind_model(self):
             _logger.warning('Did not bind any model, no need to unbind model.')
 
     def _patch_configure_optimizers(self):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or (LIGHTNING2_INSTALLED and isinstance(self.model, L.LightningModule))
 
         if self._opt_returned_dicts:
             def new_configure_optimizers(_):  # type: ignore
@@ -452,11 +465,11 @@ def new_configure_optimizers(_):
             self.model.configure_optimizers = types.MethodType(new_configure_optimizers, self.model)
 
     def _revert_configure_optimizers(self):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or (LIGHTNING2_INSTALLED and isinstance(self.model, L.LightningModule))
         self.model.configure_optimizers = self._ori_model_attr['configure_optimizers']
 
     def patch_loss(self, patch: Callable[[Tensor], Tensor]):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or (LIGHTNING2_INSTALLED and isinstance(self.model, L.LightningModule))
         old_training_step = self.model.training_step
 
         def patched_training_step(_, *args, **kwargs):
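The patching above leans on `types.MethodType` to swap an instance's method at runtime while keeping a handle to the original. A self-contained sketch of the same idiom (the toy class is illustrative):

```python
import types

class Toy:
    def training_step(self, batch):
        return batch * 2

toy = Toy()
old_step = toy.training_step  # bound method, captures the original behavior

def patched_training_step(_, batch):
    # run the captured original, then post-process, mirroring patch_loss above
    return old_step(batch) + 1

toy.training_step = types.MethodType(patched_training_step, toy)  # rebind on the instance
print(toy.training_step(3))  # 7
```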
@@ -470,19 +483,28 @@ def patched_training_step(_, *args, **kwargs):
         self.model.training_step = types.MethodType(patched_training_step, self.model)
 
     def revert_loss(self):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or (LIGHTNING2_INSTALLED and isinstance(self.model, L.LightningModule))
         self.model.training_step = self._ori_model_attr['training_step']
 
     def patch_optimizer_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or (LIGHTNING2_INSTALLED and isinstance(self.model, L.LightningModule))
 
         class OptimizerCallback(Callback):
-            def on_before_optimizer_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
-                                         optimizer: Optimizer, opt_idx: int) -> None:
+            def on_before_optimizer_step(
+                self,
+                trainer: pl.Trainer | L.Trainer,
+                pl_module: pl.LightningModule | L.LightningModule,
+                optimizer: Optimizer, opt_idx: int
+            ) -> None:
                 for task in before_step_tasks:
                     task()
 
-            def on_before_zero_grad(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer) -> None:
+            def on_before_zero_grad(
+                self,
+                trainer: pl.Trainer | L.Trainer,
+                pl_module: pl.LightningModule | L.LightningModule,
+                optimizer: Optimizer,
+            ) -> None:
                 for task in after_step_tasks:
                     task()
 
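These two hooks bracket the optimizer update: ``on_before_optimizer_step`` fires right before ``optimizer.step()``, and ``on_before_zero_grad`` fires right before gradients are cleared, i.e. after the most recent step, which is why the before/after task lists hang on exactly these callbacks. A minimal hedged sketch of the same pattern outside the evaluator (the print tasks are illustrative):

```python
from pytorch_lightning import Callback

before_step_tasks = [lambda: print('just before optimizer.step()')]
after_step_tasks = [lambda: print('after the step, before zero_grad')]

class TaskCallback(Callback):
    def on_before_optimizer_step(self, trainer, pl_module, optimizer, *args) -> None:
        for task in before_step_tasks:  # *args absorbs opt_idx on lightning < 2.0
            task()

    def on_before_zero_grad(self, trainer, pl_module, optimizer) -> None:
        for task in after_step_tasks:
            task()
```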
@@ -496,13 +518,13 @@ def patched_configure_callbacks(_):
         self.model.configure_callbacks = types.MethodType(patched_configure_callbacks, self.model)
 
     def revert_optimizer_step(self):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or (LIGHTNING2_INSTALLED and isinstance(self.model, L.LightningModule))
         self.model.configure_callbacks = self._ori_model_attr['configure_callbacks']
 
     def train(self, max_steps: int | None = None, max_epochs: int | None = None):
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or (LIGHTNING2_INSTALLED and isinstance(self.model, L.LightningModule))
         # reset trainer
-        trainer: pl.Trainer = self.trainer.trace_copy().get()  # type: ignore
+        trainer: pl.Trainer | L.Trainer = self.trainer.trace_copy().get()  # type: ignore
         # NOTE: lightning may dry run some steps at first for sanity check in Trainer.fit() by default,
         # if we want to record some information in the forward hook, we may get some additional information,
         # so we use Trainer.num_sanity_val_steps = 0 to disable the sanity check.
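As the NOTE explains, ``Trainer.fit()`` dry-runs a few validation batches before training, and forward hooks registered for data collection would fire on those too. A one-line hedged sketch of the suppression the comment describes (the assignment itself sits just past this truncated hunk, so this is an assumption about the surrounding code):

```python
# Skip lightning's pre-fit sanity-check batches so hooks only see real data.
trainer.num_sanity_val_steps = 0
```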
@@ -529,9 +551,9 @@ def evaluate(self) -> Tuple[float | None, List[Dict[str, float]]]:
         E.g., if ``Trainer.test()`` returned ``[{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}]``,
         NNI will take the final metric ``(0.8 + 0.6) / 2 = 0.7``.
         """
-        assert isinstance(self.model, pl.LightningModule)
+        assert isinstance(self.model, pl.LightningModule) or (LIGHTNING2_INSTALLED and isinstance(self.model, L.LightningModule))
         # reset trainer
-        trainer: pl.Trainer = self.trainer.trace_copy().get()  # type: ignore
+        trainer: pl.Trainer | L.Trainer = self.trainer.trace_copy().get()  # type: ignore
         original_results = trainer.test(self.model, self.data_module)
         # del trainer reference, we don't want to dump trainer when we dump the entire model.
         self.model.trainer = None
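The averaging rule in the docstring is plain arithmetic over the per-dataloader result dicts returned by ``Trainer.test()``; a quick sketch of what NNI computes:

```python
results = [{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}]
nni_metric = sum(r['default'] for r in results) / len(results)  # (0.8 + 0.6) / 2 = 0.7
```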
@@ -831,7 +853,7 @@ def __init__(self, trainer: HFTrainer, dummy_input: Any | None = None) -> None:
 
         self._initialization_complete = False
 
-    def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
+    def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule | L.LightningModule):
         assert self._initialization_complete is False, 'Evaluator initialization is already complete.'
 
         if self.traced_trainer.optimizer is not None and is_traceable(self.traced_trainer.optimizer):
@@ -862,7 +884,7 @@ def patched_get_optimizer_cls_and_kwargs(args) -> Tuple[Any, Any]:
 
         self._initialization_complete = True
 
-    def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
+    def bind_model(self, model: Module | pl.LightningModule | L.LightningModule, param_names_map: Dict[str, str] | None = None):
         err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
         assert self._initialization_complete is True, err_msg
         assert isinstance(model, Module)
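Every evaluator enforces the same two-phase protocol through these assertions; a hedged sketch of the required call order (the variable names are placeholders, and ``_init_optimizer_helpers`` is internal API shown only to make the ordering concrete):

```python
evaluator = TransformersEvaluator(traced_trainer)  # traced HF trainer, per __init__ above
evaluator._init_optimizer_helpers(pure_model)      # phase 1: must complete first
evaluator.bind_model(model)                        # phase 2: only legal after initialization
```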