Skip to content

Commit 627b22a

Browse files
Extra arguments for instantiating convolutional layers
Signed-off-by: João Lucas de Sousa Almeida <[email protected]> Testing extra kwargs Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent abd5f12 commit 627b22a

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

simulai/models/_pytorch_models/_autoencoder.py

+2
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,7 @@ def __init__(
10981098
scale: float = 1e-3,
10991099
devices: Union[str, list] = "cpu",
11001100
name: str = None,
1101+
**kwargs,
11011102
) -> None:
11021103
"""
11031104
Constructor method.
@@ -1172,6 +1173,7 @@ def __init__(
11721173
shallow=shallow,
11731174
use_batch_norm=use_batch_norm,
11741175
name=self.name,
1176+
**kwargs
11751177
)
11761178

11771179
self.encoder = encoder.to(self.device)

simulai/templates/_templates.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,7 @@ def cnn_autoencoder_auto(
673673
use_batch_norm: bool = False,
674674
shallow: bool = False,
675675
name: str = None,
676+
**kwargs,
676677
) -> Tuple[NetworkTemplate, ...]:
677678

678679
"""
@@ -737,7 +738,7 @@ def cnn_autoencoder_auto(
737738

738739
autogen_cnn = NetworkInstanceGen(
739740
architecture="cnn", dim=case, use_batch_norm=use_batch_norm,
740-
kernel_size=kernel_size,
741+
kernel_size=kernel_size, **kwargs,
741742
)
742743

743744
autogen_dense = NetworkInstanceGen(architecture="dense", shallow=shallow)
@@ -798,6 +799,7 @@ def autoencoder_auto(
798799
use_batch_norm: bool = False,
799800
case: str = None,
800801
name: str = None,
802+
**kwargs,
801803
) -> Tuple[Union[NetworkTemplate, None], ...]:
802804

803805
"""
@@ -864,6 +866,7 @@ def autoencoder_auto(
864866
shallow=shallow,
865867
use_batch_norm=use_batch_norm,
866868
name=name,
869+
**kwargs,
867870
)
868871

869872
return encoder, decoder, bottleneck_encoder, bottleneck_decoder

tests/network/test_template_gen.py

+2
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def test_autoencoder_kernel_size_shallow(self) -> None:
372372
architecture="cnn",
373373
case="2d",
374374
shallow=True,
375+
padding_mode='replicate',
375376
)
376377

377378
estimated_data = autoencoder.eval(input_data=input_data)
@@ -392,6 +393,7 @@ def test_autoencoder_multiscaleautoencoder(self) -> None:
392393
case="2d",
393394
shallow=True,
394395
name="model",
396+
padding_mode='replicate',
395397
)
396398

397399
estimated_data = autoencoder.reconstruction_forward(input_data=input_data)

0 commit comments

Comments
 (0)