@@ -134,3 +134,40 @@ def _advance_batch(self, batch: Batch, next_inputs: Tensor, stride: int) -> Batc
134134 constant_scalars = batch .constant_scalars ,
135135 constant_fields = batch .constant_fields ,
136136 )
137+
138+
class EPDTrainProcessor(EncoderProcessorDecoder):
    """Encoder-Processor-Decoder model whose training step fits the processor.

    The forward/optimisation step encodes the batch once and computes the
    loss purely from the processor's mapping of the encoded inputs.
    """

    # NOTE(review): declared here but never assigned in this class —
    # presumably populated by the base class; verify against its __init__.
    train_processor: Processor

    def __init__(
        self,
        encoder_decoder: EncoderDecoder,
        processor: Processor,
        learning_rate: float = 1e-3,
        stride: int = 1,
        teacher_forcing_ratio: float = 0.5,
        max_rollout_steps: int = 10,
        loss_func: nn.Module | None = None,
        **kwargs: Any,
    ) -> None:
        """Delegate all configuration to ``EncoderProcessorDecoder``.

        The signature mirrors the base class exactly; this subclass only
        changes which stage ``training_step`` optimises.
        """
        base_config: dict[str, Any] = dict(
            encoder_decoder=encoder_decoder,
            processor=processor,
            learning_rate=learning_rate,
            stride=stride,
            teacher_forcing_ratio=teacher_forcing_ratio,
            max_rollout_steps=max_rollout_steps,
            loss_func=loss_func,
        )
        super().__init__(**base_config, **kwargs)

    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
        """Run a single training step against the processor's loss.

        Encodes ``batch``, maps the encoded inputs through the processor,
        and scores the result against the encoded target fields. Logs the
        loss to the progress bar and returns it for backpropagation.
        """
        encoded = self.encoder_decoder.encoder.encode_batch(batch)
        predicted = self.processor.map(encoded.encoded_inputs)
        loss = self.processor.loss(predicted, encoded.encoded_output_fields)
        # Leading dimension of the raw input fields is taken as the batch size.
        n_samples = batch.input_fields.shape[0]
        self.log("train_loss", loss, prog_bar=True, batch_size=n_samples)
        return loss
0 commit comments