1+ from typing import Any
2+
13import lightning as L
24import torch
35from torch import nn
46
57from auto_cast .decoders import Decoder
68from auto_cast .encoders import Encoder
7- from auto_cast .processors .base import Preprocessor , Processor
9+ from auto_cast .processors .base import Preprocessor
810from auto_cast .types import Batch , Tensor
911
1012
@@ -19,16 +21,22 @@ class EncoderDecoder(L.LightningModule):
1921 def __init__ (self ):
2022 pass
2123
22- def forward (self , x : Tensor ) -> Tensor :
23- return self .decoder (self .encoder (x ))
24+ def forward (self , * args : Any , ** kwargs : Any ) -> Any :
25+ return self .decoder (self .encoder (* args , ** kwargs ))
2426
2527 def training_step (self , batch : Batch , batch_idx : int ) -> Tensor : # noqa: ARG002
2628 x = self .preprocessor (batch )
2729 output = self (x )
2830 loss = self .loss_func (output , batch ["output_fields" ])
2931 return loss # noqa: RET504
3032
31- def encode_only (self , x : Batch ) -> Tensor :
33+ def validation_step (self , batch : Batch , batch_idx : int ) -> Tensor : ...
34+
35+ def test_step (self , batch : Batch , batch_idx : int ) -> Tensor : ...
36+
37+ def predict_step (self , batch : Batch , batch_idx : int ) -> Tensor : ...
38+
39+ def encode (self , x : Batch ) -> Tensor :
3240 x = self .preprocessor (x )
3341 return self .encoder (x )
3442
@@ -45,25 +53,7 @@ def forward(self, x: Tensor) -> Tensor:
4553 x = self .decoder (z )
4654 return x # noqa: RET504
4755
48- def reparametrize (self , mu , log_var ) :
56+ def reparametrize (self , mu : Tensor , log_var : Tensor ) -> Tensor :
4957 std = torch .exp (0.5 * log_var )
5058 eps = torch .randn_like (std )
5159 return mu + eps * std
52-
53-
54- class EncoderProcessorDecoder (L .LightningModule ):
55- """Encoder-Processor-Decoder Model."""
56-
57- encoder_decoder : EncoderDecoder
58- processor : Processor
59-
60- def __init__ (self ): ...
61-
62- def forward (self , x : Tensor ) -> Tensor :
63- return self .encoder_decoder .decoder (
64- self .processor (self .encoder_decoder .encoder (x ))
65- )
66-
67- def training_step (self , batch : Batch , batch_idx : int ) -> Tensor : ...
68-
69- def configure_optmizers (self ): ...
0 commit comments