Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def body_fn(wrapped_carry, x, prefix=None):
# we haven't promote shapes of values yet during `lax.scan`, so we do it here
site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])

# XXX: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
# Note: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
# we don't record 1-size dimensions in this field
time_dim = -min(
len(site["fn"].batch_shape), jnp.ndim(site["value"]) - site["fn"].event_dim
Expand Down
2 changes: 1 addition & 1 deletion numpyro/contrib/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _update_params(params, new_params, prior, prefix=""):
else:
d = prior
param_batch_shape = param_shape[: len(param_shape) - d.event_dim]
# XXX: here we set all dimensions of prior to event dimensions.
# Note: here we set all dimensions of prior to event dimensions.
new_params[name] = numpyro.sample(
flatten_name, d.expand(param_batch_shape).to_event()
)
Expand Down
4 changes: 2 additions & 2 deletions numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _get_codomain(bijector):
loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
if not_jax_tracer(concentration) and np.all(np.less(concentration, 0)):
return constraints.interval(loc, loc + scale / jnp.abs(concentration))
# XXX: here we suppose concentration > 0
# Note: here we suppose concentration > 0
# which is not true in general, but should cover enough usage cases
else:
return constraints.greater_than(loc)
Expand Down Expand Up @@ -278,7 +278,7 @@ def support(self):

@property
def is_discrete(self):
# XXX: this should cover most cases
# Note: this should cover most cases
return self.support is None

def tree_flatten(self):
Expand Down
2 changes: 1 addition & 1 deletion numpyro/contrib/tfp/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def init(
if is_prng_key(rng_key):
init_state = self._init_fn(init_params, rng_key)
else:
# XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
# note: it's safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
# nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
# wa_steps because those variables do not depend on traced args: init_params, rng_key.
init_state = vmap(self._init_fn)(init_params, rng_key)
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


def _log_beta_1(alpha, value):
# XXX: support sparse `value`
# Note: support sparse `value`
return gammaln(1 + value) + gammaln(alpha) - gammaln(value + alpha)


Expand Down
4 changes: 2 additions & 2 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def codomain(self) -> Constraint:
raise NotImplementedError

def __call__(self, x: NumLike) -> NumLike:
# XXX consider to clamp from below for stability if necessary
# Note: consider to clamp from below for stability if necessary
return jnp.exp(x)

def _inverse(self, y: NumLike) -> NumLike:
Expand Down Expand Up @@ -1318,7 +1318,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
raise NotImplementedError

def tree_flatten(self):
# XXX: what if unpack_fn is a parametrized callable pytree?
# Note: what if unpack_fn is a parametrized callable pytree?
return (), ((), {"unpack_fn": self.unpack_fn, "pack_fn": self.pack_fn})

def eq(self, other: object, static: bool = False) -> ArrayLike:
Expand Down
2 changes: 1 addition & 1 deletion numpyro/examples/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def _load_jsb_chorales() -> dict:
)
data = pickle.load(f)

# XXX: we might expose those in `load_dataset` keywords
# Note: we might expose those in `load_dataset` keywords
min_note = 21
note_range = 88
processed_dataset = {}
Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def __call__(self, *args, **kwargs):
def _unpack_and_constrain(self, latent_sample, params):
def unpack_single_latent(latent):
unpacked_samples = self._unpack_latent(latent)
# XXX: we need to add param here to be able to replay model
# Note: we need to add param here to be able to replay model
unpacked_samples.update(
{
k: v
Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,7 @@ def single_particle_elbo(rng_key: jax.Array) -> jax.Array:
if self.max_plate_nesting == float("inf"):
seeded_model = seed(model, model_seed)
seeded_guide = seed(guide, guide_seed)
# XXX: We can extract abstract latents here such that they
# Note: We can extract abstract latents here such that they
# can be reused in get_nonreparam_deps below.
self.max_plate_nesting = guess_max_plate_nesting(
seeded_model, seeded_guide, args, kwargs, param_map
Expand Down
6 changes: 3 additions & 3 deletions numpyro/infer/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def get_diagnostics_str(self, state):
return "acc. prob={:.2f}".format(state.inner_state.mean_accept_prob)

def init_inner_state(self, rng_key):
# XXX hack -- we don't know num_chains until we init the inner state
# Note: hack -- we don't know num_chains until we init the inner state
self._moves = [
move(self._num_chains) if move.__name__ == "make_de_move" else move
for move in self._moves
Expand Down Expand Up @@ -370,7 +370,7 @@ def de_move(rng_key, active, inactive):
pairs_key, gamma_key = random.split(rng_key)
n_active_chains, n_params = inactive.shape

# XXX: if we pass in n_params to parent scope we don't need to
# Note: if we pass in n_params to parent scope we don't need to
# recompute this each time
g = 2.38 / jnp.sqrt(2.0 * n_params) if not g0 else g0

Expand Down Expand Up @@ -535,7 +535,7 @@ def __init__(
def init_inner_state(self, rng_key):
self.batch_log_density = lambda x: self._batch_log_density(x)[:, jnp.newaxis]

# XXX hack -- we don't know num_chains until we init the inner state
# Note: hack -- we don't know num_chains until we init the inner state
self._moves = [
move(self._num_chains)
if move.__name__ == "make_differential_move"
Expand Down
4 changes: 2 additions & 2 deletions numpyro/infer/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ def init(
if self._model is not None:
z = init_params[0] if isinstance(init_params, ParamInfo) else init_params
if isinstance(dense_mass, bool):
# XXX: by default, the order variables are sorted by their names,
# Note: by default, the order variables are sorted by their names,
# this is to be compatible with older numpyro versions
# and to match autoguide scale parameter and jax flatten utils
dense_mass = [tuple(sorted(z))] if dense_mass else []
Expand Down Expand Up @@ -790,7 +790,7 @@ def init(
if is_prng_key(rng_key):
init_state = hmc_init_fn(init_params, rng_key)
else:
# XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
# Note it's safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
# nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
# wa_steps because those variables do not depend on traced args: init_params, rng_key.
init_state = vmap(hmc_init_fn)(init_params, rng_key)
Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support

z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
# Here we loop over the support of z_flat[idx] to get z_new
# XXX: we can't vmap potential_fn over all proposals and sample from the conditional
# Note: we can't vmap potential_fn over all proposals and sample from the conditional
# categorical distribution because support_size is a traced value, i.e. its value
# might change across different discrete variables;
# so here we will loop over all proposals and use an online scheme to sample from
Expand Down
6 changes: 3 additions & 3 deletions numpyro/infer/hmc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from numpyro.util import cond, identity, while_loop

AdaptWindow = namedtuple("AdaptWindow", ["start", "end"])
# XXX: we need to store rng_key here in case we use find_reasonable_step_size functionality
# Note: we need to store rng_key here in case we use find_reasonable_step_size functionality
HMCAdaptState = namedtuple(
"HMCAdaptState",
[
Expand Down Expand Up @@ -214,7 +214,7 @@ def final_fn(state, regularize=False):
return cov, cov_inv_sqrt, tril_inv

mean, m2, n = state
# XXX it is not necessary to check for the case n=1
# Note: it is not necessary to check for the case n=1
cov = m2 / (n - 1)
if regularize:
# Regularization from Stan
Expand Down Expand Up @@ -953,7 +953,7 @@ def _leaf_idx_to_ckpt_idxs(n):
# turning condition;
# however, we can check the turning condition of the subtree 0 -> 5, which
# likely satisfies turning condition because its trajectory 3/4 of a circle.
# XXX: make sure that detailed balance is satisfied if we follow this direction
# Note: make sure that detailed balance is satisfied if we follow this direction
idx_min = idx_max - num_subtrees + 1
return idx_min, idx_max

Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def init_to_uniform(site=None, radius=2):
)
return site["value"]

# XXX: we import here to avoid circular import
# Note: we import here to avoid circular import
from numpyro.infer.util import helpful_support_errors

rng_key = site["kwargs"].get("rng_key")
Expand Down
4 changes: 2 additions & 2 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def collect_and_postprocess(x):
return collect_and_postprocess


# XXX: Is there a better hash key that we can use?
# Note: Is there a better hash key that we can use?
def _hashable(x):
# NOTE: When the arguments are JITed, ShapedArray is hashable.
if isinstance(x, (np.ndarray, jnp.ndarray)):
Expand Down Expand Up @@ -778,7 +778,7 @@ def print_summary(self, prob=0.9, exclude_deterministic=True):
sites = self._states[self._sample_field]
if isinstance(sites, dict) and exclude_deterministic:
state_sample_field = attrgetter(self._sample_field)(self._last_state)
# XXX: there might be the case that state.z is not a dictionary but
# Note: there might be the case that state.z is not a dictionary but
# its postprocessed value `sites` is a dictionary.
# TODO: in general, when both `sites` and `state.z` are dictionaries,
# they can have different key names, not necessary due to deterministic
Expand Down
6 changes: 3 additions & 3 deletions numpyro/infer/sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def init_kernel(
else jnp.sqrt(inverse_mass_matrix)
)
if adapt_state_size is None:
# XXX: heuristic choice
# Note: heuristic choice
adapt_state_size = 2 * z_flat.shape[-1]
else:
assert adapt_state_size > 1, "adapt_state_size should be greater than 1."
Expand Down Expand Up @@ -182,7 +182,7 @@ def sample_kernel(sa_state, model_args=(), model_kwargs=None):
pe_fn = potential_fn_gen(*model_args, **model_kwargs)
zs, pes, loc, scale = sa_state.adapt_state
# we recompute loc/scale after each iteration to avoid precision loss
# XXX: consider to expose a setting to do this job periodically
# Note: consider to expose a setting to do this job periodically
# to save some computations
loc = jnp.mean(zs, 0)
if scale.ndim == 2:
Expand Down Expand Up @@ -235,7 +235,7 @@ def sample_kernel(sa_state, model_args=(), model_kwargs=None):
sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
)

# XXX: we make a modification of SA sampler in [1]
# Note: we make a modification of SA sampler in [1]
# in [1], each MCMC state contains N points `zs`
# here we do resampling to pick randomly a point from those N points
k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def body_fn(svi_state, _):
else:
svi_state, losses = lax.scan(body_fn, svi_state, None, length=num_steps)

# XXX: we also return the last svi_state for further inspection of both
# Note: we also return the last svi_state for further inspection of both
# optimizer's state and mutable state.
return SVIRunResult(self.get_params(svi_state), svi_state, losses)

Expand Down
8 changes: 4 additions & 4 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def body_fn(state):
key, subkey = random.split(key)

if radius is None or prototype_params is None:
# XXX: we don't want to apply enum to draw latent samples
# Note: we don't want to apply enum to draw latent samples
model_ = model
if enum:
from numpyro.contrib.funsor import enum as enum_handler
Expand Down Expand Up @@ -461,7 +461,7 @@ def _find_valid_params(rng_key, exit_early=False):
if device_get(is_valid):
return (init_params, pe, z_grad), is_valid

# XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
# Note: this requires compiling the model, so for multi-chain, we trace the model 2-times
# even if the init_state is a valid result
_, _, (init_params, pe, z_grad), is_valid = while_loop(
cond_fn, body_fn, init_state
Expand Down Expand Up @@ -516,7 +516,7 @@ def _get_model_transforms(model, model_args=(), model_kwargs=None):
support = v["fn"].support
with helpful_support_errors(v, raise_warnings=True):
inv_transforms[k] = biject_to(support)
# XXX: the following code filters out most situations with dynamic supports
# Note: the following code filters out most situations with dynamic supports
args = ()
if isinstance(support, constraints._GreaterThan):
args = ("lower_bound",)
Expand Down Expand Up @@ -582,7 +582,7 @@ def get_potential_fn(
_partial_args_kwargs, partial(potential_energy, model, enum=enum)
)
if replay_model:
# XXX: we seed to sample discrete sites (but not collect them)
# Note: we seed to sample discrete sites (but not collect them)
model_ = seed(model.fn, 0) if enum else model
postprocess_fn = partial(
_partial_args_kwargs,
Expand Down
4 changes: 2 additions & 2 deletions numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def eval_provenance(fn, **kwargs):
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fn, **fn_info), in_tree)
# Abstract eval to get output pytree
avals = _safe_map(shaped_abstractify, args)
# XXX: we split out the process of abstract evaluation and provenance tracking
# Note: we split out the process of abstract evaluation and provenance tracking
# for simplicity. In principle, they can be merged so that we only need to walk
# through the equations once.

Expand Down Expand Up @@ -108,7 +108,7 @@ def write(v, p):
track_deps_rules = {}


# XXX: Currently, we use default rule for scan_p, cond_p, while_p, remat_p
# Note: Currently, we use default rule for scan_p, cond_p, while_p, remat_p
def _default_track_deps_rules(eqn, provenance_inputs):
provenance_outputs = frozenset().union(*provenance_inputs)
return [provenance_outputs] * len(eqn.outvars)
Expand Down
2 changes: 1 addition & 1 deletion numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def __init__(
self.subsample_size = self._indices.shape[0]
super(plate, self).__init__()

# XXX: different from Pyro, this method returns dim and indices
# Note: different from Pyro, this method returns dim and indices
@staticmethod
def _subsample(name, size, subsample_size, dim): # noqa: ANN001, ANN205
msg = {
Expand Down
2 changes: 1 addition & 1 deletion numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def soft_vmap(
xs = jax.tree.map(
lambda x: jnp.reshape(x, prepend_shape + jnp.shape(x)[batch_ndims:]), xs
)
# XXX: probably for the default behavior with chunk_size=None,
# Note: probably for the default behavior with chunk_size=None,
# it is better to catch OOM error and reduce chunk_size by half until OOM disappears.
chunk_size = batch_size if chunk_size is None else min(batch_size, chunk_size)
if chunk_size > 1:
Expand Down
4 changes: 2 additions & 2 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def model(data):
rng_keys[i], model, init_strategy=init_strategy, model_args=(count_data,)
)
for name, p in init_params[0].items():
# XXX: the result is equal if we disable fast-math-mode
# Note: the result is equal if we disable fast-math-mode
assert_allclose(p[i], init_params_i[0][name], atol=1e-6)


Expand Down Expand Up @@ -494,7 +494,7 @@ def model(data):
rng_keys[i], model, init_strategy=init_strategy, model_args=(data,)
)
for name, p in init_params[0].items():
# XXX: the result is equal if we disable fast-math-mode
# Note: the result is equal if we disable fast-math-mode
assert_allclose(p[i], init_params_i[0][name], atol=1e-6)


Expand Down
2 changes: 1 addition & 1 deletion test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2264,7 +2264,7 @@ def test_mean_var(jax_dist, sp_dist, params):
k = random.key(0)
samples = d_jax.sample(k, sample_shape=(n,)).astype(np.float32)
# check with suitable scipy implementation if available
# XXX: VonMises is already tested below
# note: VonMises is already tested below
if (
sp_dist
and not _is_batched_multivariate(d_jax)
Expand Down
Loading