Skip to content

Commit aeb0118

Browse files
committed
Add separate EPD that trains on processor
1 parent 93c537e commit aeb0118

1 file changed

Lines changed: 37 additions & 0 deletions

File tree

src/auto_cast/models/encoder_processor_decoder.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,40 @@ def _advance_batch(self, batch: Batch, next_inputs: Tensor, stride: int) -> Batc
134134
constant_scalars=batch.constant_scalars,
135135
constant_fields=batch.constant_fields,
136136
)
137+
138+
139+
class EPDTrainProcessor(EncoderProcessorDecoder):
    """Encoder-Processor-Decoder model whose training loss is computed on the
    processor alone.

    Instead of decoding back to physical space, ``training_step`` encodes the
    batch, runs the processor in latent space, and takes the processor's own
    loss against the encoded target fields.
    """

    # NOTE(review): declared but never assigned within this class — confirm it
    # is populated elsewhere (e.g. by the parent class or a trainer hook).
    train_processor: Processor

    def __init__(
        self,
        encoder_decoder: EncoderDecoder,
        processor: Processor,
        learning_rate: float = 1e-3,
        stride: int = 1,
        teacher_forcing_ratio: float = 0.5,
        max_rollout_steps: int = 10,
        loss_func: nn.Module | None = None,
        **kwargs: Any,
    ) -> None:
        """Forward all configuration unchanged to ``EncoderProcessorDecoder``."""
        super().__init__(
            encoder_decoder=encoder_decoder,
            processor=processor,
            learning_rate=learning_rate,
            stride=stride,
            teacher_forcing_ratio=teacher_forcing_ratio,
            max_rollout_steps=max_rollout_steps,
            loss_func=loss_func,
            **kwargs,
        )

    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
        """Encode the batch, run the processor in latent space, and return its loss.

        The loss is logged as ``train_loss`` with the batch size taken from the
        leading dimension of ``batch.input_fields``.
        """
        encoded = self.encoder_decoder.encoder.encode_batch(batch)
        prediction = self.processor.map(encoded.encoded_inputs)
        processor_loss = self.processor.loss(
            prediction, encoded.encoded_output_fields
        )
        n_samples = batch.input_fields.shape[0]
        self.log("train_loss", processor_loss, prog_bar=True, batch_size=n_samples)
        return processor_loss

0 commit comments

Comments
 (0)