@@ -105,7 +105,7 @@ def save_state(model, optimizer: torch.optim.AdamW, scheduler: torch.optim.lr_sc
105105 If you want to export the model after training set final_export=True and leave is_checkpoint=False.
106106
107107 Args:
108- model (Base_RelationExtraction ): BertModel_RelationExtraction | LlamaModel_RelationExtraction
108+     model (BaseModel_RelationExtraction): BertModel_RelationExtraction | LlamaModel_RelationExtraction, etc.
109109 optimizer (torch.optim.AdamW, optional): Defaults to None.
110110 scheduler (torch.optim.lr_scheduler.MultiStepLR, optional): Defaults to None.
111111 epoch (int): Defaults to None.
@@ -136,11 +136,11 @@ def save_state(model, optimizer: torch.optim.AdamW, scheduler: torch.optim.lr_sc
136136 }, os .path .join (path , file_name ))
137137
138138
139- def load_state (model , optimizer , scheduler , path = "./" , model_name = "BERT" , file_prefix = "train" , load_best = False , config : ConfigRelCAT = ConfigRelCAT ()) -> Tuple [int , int ]:
139+ def load_state (model , optimizer , scheduler , path : str = "./" , model_name : str = "BERT" , file_prefix : str = "train" , load_best : bool = False , relcat_config : ConfigRelCAT = ConfigRelCAT ()) -> Tuple [int , int ]:
140140 """ Used by RelCAT.load() and RelCAT.train()
141141
142142 Args:
143- model (Base_RelationExtraction ): BertModel_RelationExtraction | LlamaModel_RelationExtraction , it has to be initialized before calling this method via (Bert/Llama)Model_RelationExtraction(...)
143+     model (BaseModel_RelationExtraction): an initialized BaseModel_RelationExtraction subclass; it has to be initialized before calling this method via (Bert/Llama)Model_RelationExtraction(...)
144144 optimizer (_type_): optimizer
145145 scheduler (_type_): scheduler
146146 path (str, optional): Defaults to "./".
@@ -153,7 +153,7 @@ def load_state(model, optimizer, scheduler, path="./", model_name="BERT", file_p
153153 Tuple (int, int): last epoch and f1 score.
154154 """
155155
156- device : torch .device = torch .device (config .general .device )
156+ device : torch .device = torch .device (relcat_config .general .device )
157157
158158 model_name = model_name .replace ("/" , "_" )
159159 logging .info ("Attempting to load RelCAT model on device: " + str (device ))
@@ -178,13 +178,13 @@ def load_state(model, optimizer, scheduler, path="./", model_name="BERT", file_p
178178
179179 if optimizer is None :
180180 parameters = filter (lambda p : p .requires_grad , model .parameters ())
181- optimizer = torch .optim .AdamW (params = parameters , lr = config .train .lr , weight_decay = config .train .adam_weight_decay ,
182- betas = config .train .adam_betas , eps = config .train .adam_epsilon )
181+ optimizer = torch .optim .AdamW (params = parameters , lr = relcat_config .train .lr , weight_decay = relcat_config .train .adam_weight_decay ,
182+ betas = relcat_config .train .adam_betas , eps = relcat_config .train .adam_epsilon )
183183
184184 if scheduler is None :
185185 scheduler = torch .optim .lr_scheduler .MultiStepLR (optimizer ,
186- milestones = config .train .multistep_milestones ,
187- gamma = config .train .multistep_lr_gamma )
186+ milestones = relcat_config .train .multistep_milestones ,
187+ gamma = relcat_config .train .multistep_lr_gamma )
188188 optimizer .load_state_dict (checkpoint ['optimizer' ])
189189 scheduler .load_state_dict (checkpoint ['scheduler' ])
190190 logging .info ("Loaded model and optimizer." )
0 commit comments