1 file changed (+9, -4 lines changed)
@@ -11,14 +11,14 @@

 class CSKD(BaseClass):
     """
-    Implementation of assisted Knowledge distillation from the paper "Improved Knowledge
-    Distillation via Teacher Assistant" https://arxiv.org/pdf/1902.03393.pdf
+    Implementation of "Regularizing Class-wise Predictions via Self-knowledge Distillation"
+    https://arxiv.org/pdf/2003.13964.pdf

-    :param teacher_model (torch.nn.Module): Teacher model
+    :param teacher_model (torch.nn.Module): Teacher model -> Should be None
     :param student_model (torch.nn.Module): Student model
     :param train_loader (torch.utils.data.DataLoader): Dataloader for training
     :param val_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
-    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher
+    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher -> Should be None
     :param optimizer_student (torch.optim.*): Optimizer used for training student
     :param loss_fn (torch.nn.Module): Calculates loss during distillation
     :param temp (float): Temperature parameter for distillation
@@ -60,6 +60,11 @@ def __init__(
             logdir,
         )
         self.lamda = lamda
+        if teacher_model is not None or optimizer_teacher is not None:
+            print(
+                "Error!!! Teacher model and Teacher optimizer should be None for self-distillation, please refer to the documentation."
+            )
+        assert teacher_model == None

     def calculate_kd_loss(self, y_pred_pair_1, y_pred_pair_2):
         """
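For reference, the class-wise self-distillation term that calculate_kd_loss operates on takes two predictions for samples of the same class and pushes their softened distributions together. The snippet below is only a sketch of that idea from the linked paper, not this repo's exact implementation; the function name, the default temp/lamda values, and the omission of the cross-entropy term are assumptions.

import torch.nn.functional as F

def class_wise_self_kd_loss(y_pred_pair_1, y_pred_pair_2, temp=4.0, lamda=1.0):
    # Sketch of the class-wise regularizer from the CSKD paper (assumed form):
    # KL divergence between the softened prediction for one sample and the
    # detached, softened prediction for another sample of the same class.
    log_p = F.log_softmax(y_pred_pair_1 / temp, dim=1)
    q = F.softmax(y_pred_pair_2.detach() / temp, dim=1)  # fixed target, no gradient
    # temp ** 2 keeps gradient magnitudes comparable across temperatures
    return lamda * (temp ** 2) * F.kl_div(log_p, q, reduction="batchmean")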
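A minimal sketch of how the patched constructor would be called is shown below. Only the parameter names come from the docstring in this diff; the import path, the toy model and loaders, and the optimizer choice are placeholders (a real run would also need same-class sample pairing in the loader). teacher_model and optimizer_teacher are passed as None because the new guard asserts on anything else.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from KD_Lib.KD import CSKD  # assumed import path for this repo

# Toy student and loaders, stand-ins for a real setup
student = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
data = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
train_loader = DataLoader(data, batch_size=16)
val_loader = DataLoader(data, batch_size=16)

distiller = CSKD(
    teacher_model=None,       # must be None: CSKD is self-distillation
    student_model=student,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer_teacher=None,   # must be None as well
    optimizer_student=optim.SGD(student.parameters(), lr=0.01),
)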