1515from ott .neural .methods .monge_gap import MongeGapEstimator
1616from ott .neural .methods .neuraldual import W2NeuralDual
1717from ott .neural .networks .icnn import ICNN
18- from ott .neural .networks .potentials import MLP
18+ from ott .neural .networks .potentials import PotentialMLP
1919
2020from cmonge .datasets .single_loader import AbstractDataModule
2121from cmonge .evaluate import (
@@ -199,6 +199,7 @@ def setup(
199199 fitting_loss : Dict [str , Any ],
200200 regularizer : Dict [str , Any ],
201201 optim : Dict [str , Any ],
202+ checkpointing_path : Optional [str ] = None , # For compatibility with base class
202203 ) -> None :
203204 """Initializes models and optimizers."""
204205 self .metrics ["params" ] = {
@@ -218,7 +219,7 @@ def setup(
218219 regularizer = partial (regularizer_fn , ** regularizer .kwargs )
219220
220221 # setup neural network model
221- model = MLP (dim_hidden = dim_hidden , is_potential = False , act_fn = nn .gelu )
222+ model = PotentialMLP (dim_hidden = dim_hidden , is_potential = False , act_fn = nn .gelu )
222223
223224 # setup optimizer and scheduler
224225 opt_fn = optim_factory [optim .name ]
@@ -313,7 +314,7 @@ def setup(
313314 dim_hidden = dim_hidden ,
314315 gaussian_map_samples = (samples_source , samples_target ),
315316 )
316- neural_g = MLP (dim_hidden = dim_hidden )
317+ neural_g = PotentialMLP (dim_hidden = dim_hidden )
317318
318319 lr_schedule = optax .cosine_decay_schedule (
319320 init_value = lr , decay_steps = num_train_iters , alpha = 1e-2
0 commit comments