Skip to content

Commit d903d23

Browse files
DriessenAAlice Driessenjannisborn
authored
Small fixes checkpoint loading and ott-jax version bump (#31)
Co-authored-by: Alice Driessen <alicedriessen@Alices-MacBook-Air.local> Co-authored-by: Jannis Born <jannis.born@gmx.de>
1 parent b264624 commit d903d23

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

cmonge/trainers/ot_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ott.neural.methods.monge_gap import MongeGapEstimator
1616
from ott.neural.methods.neuraldual import W2NeuralDual
1717
from ott.neural.networks.icnn import ICNN
18-
from ott.neural.networks.potentials import MLP
18+
from ott.neural.networks.potentials import PotentialMLP
1919

2020
from cmonge.datasets.single_loader import AbstractDataModule
2121
from cmonge.evaluate import (
@@ -199,6 +199,7 @@ def setup(
199199
fitting_loss: Dict[str, Any],
200200
regularizer: Dict[str, Any],
201201
optim: Dict[str, Any],
202+
checkpointing_path: Optional[str] = None, # For compatibility with base class
202203
) -> None:
203204
"""Initializes models and optimizers."""
204205
self.metrics["params"] = {
@@ -218,7 +219,7 @@ def setup(
218219
regularizer = partial(regularizer_fn, **regularizer.kwargs)
219220

220221
# setup neural network model
221-
model = MLP(dim_hidden=dim_hidden, is_potential=False, act_fn=nn.gelu)
222+
model = PotentialMLP(dim_hidden=dim_hidden, is_potential=False, act_fn=nn.gelu)
222223

223224
# setup optimizer and scheduler
224225
opt_fn = optim_factory[optim.name]
@@ -313,7 +314,7 @@ def setup(
313314
dim_hidden=dim_hidden,
314315
gaussian_map_samples=(samples_source, samples_target),
315316
)
316-
neural_g = MLP(dim_hidden=dim_hidden)
317+
neural_g = PotentialMLP(dim_hidden=dim_hidden)
317318

318319
lr_schedule = optax.cosine_decay_schedule(
319320
init_value=lr, decay_steps=num_train_iters, alpha=1e-2

cmonge/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def monge_get_source_target_transport(
2020
source=True,
2121
transport=True,
2222
batch_size: int = None,
23+
num_contexts: int = 2,
2324
):
2425
if batch_size is None:
2526
batch_size = datamodule.batch_size
@@ -86,7 +87,10 @@ def monge_get_source_target_transport(
8687
)
8788

8889
if transport:
89-
trans = trainer.transport(source_expr, num_contexts=2)
90+
if num_contexts == 0:
91+
trans = trainer.transport(source_expr)
92+
else:
93+
trans = trainer.transport(source_expr, num_contexts=num_contexts)
9094
trans = datamodule.decoder(trans)
9195
trans_meta = cond_meta.copy()
9296
trans_meta["dtype"] = "transport"

0 commit comments

Comments
 (0)