Skip to content

Commit 6ec04e5

Browse files
committed
style: Fix ruff issues
1 parent 7b4dd2e commit 6ec04e5

File tree

3 files changed

+12
-11
lines changed

3 files changed

+12
-11
lines changed

python/nutpie/compile_pymc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def _compile_pymc_model_jax(
380380
def logp_fn_jax_grad(x, *shared):
381381
return jax.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x)
382382

383-
static_argnums = list(range(1, len(logp_shared_names) + 1))
383+
# static_argnums = list(range(1, len(logp_shared_names) + 1))
384384
logp_fn_jax_grad = jax.jit(
385385
logp_fn_jax_grad,
386386
# static_argnums=static_argnums,

python/nutpie/compile_stan.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from functools import partial
21
import json
32
import tempfile
43
from dataclasses import dataclass, replace
4+
from functools import partial
55
from importlib.util import find_spec
66
from pathlib import Path
7-
from typing import Any, Optional, Callable
7+
from typing import Any, Optional
88

99
import numpy as np
1010
import pandas as pd

python/nutpie/transform_adapter.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@ def make_transform_adapter(
1212
untransformed_dim=None,
1313
zero_init=True,
1414
):
15-
import jax
15+
import traceback
16+
from functools import partial
17+
1618
import equinox as eqx
17-
import jax.numpy as jnp
1819
import flowjax
19-
import flowjax.train
2020
import flowjax.flows
21+
import flowjax.train
22+
import jax
23+
import jax.numpy as jnp
24+
import numpy as np
2125
import optax
22-
import traceback
2326
from paramax import Parameterize, unwrap
24-
from functools import partial
25-
26-
import numpy as np
2727

2828
class FisherLoss:
2929
@eqx.filter_jit
@@ -363,7 +363,7 @@ def update(self, seed, positions, gradients):
363363
else:
364364
base = self._bijection
365365

366-
# make_flow might still only return a single trafo if the for 1d problems
366+
# make_flow might still only return a single trafo for 1d problems
367367
if len(base.bijections) == 1:
368368
self._bijection = base
369369
return
@@ -436,6 +436,7 @@ def update(self, seed, positions, gradients):
436436
except Exception as e:
437437
print("update error:", e)
438438
print(traceback.format_exc())
439+
raise
439440

440441
def init_from_transformed_position(self, transformed_position):
441442
try:

0 commit comments

Comments
 (0)