Skip to content

Commit 3f64e42

Browse files
committed
fix: Move mvscale import to be conditional
1 parent 9736963 commit 3f64e42

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

python/nutpie/transform_adapter.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def make_transform_adapter(
1818
import flowjax
1919
import flowjax.train
2020
import flowjax.flows
21-
from flowjax.bijections import mvscale
2221
import optax
2322
import traceback
2423
from paramax import Parameterize, unwrap
@@ -164,6 +163,8 @@ def make_layer(key, is_last=False):
164163
flow = flowjax.flows._add_default_permute(coupling, n_dim, key_permute)
165164

166165
if scale_layer:
166+
from flowjax.bijections import mvscale
167+
167168
bijections = list(flow.bijections)
168169
bijections.append(mvscale.MvScale4(jnp.ones(n_dim) * 1e-5))
169170
# bijections.append(mvscale.MvScale3(jnp.ones(n_dim) * 1e-5))

0 commit comments

Comments
 (0)