Multiple chain draws for simple Multivariate Bernoulli inference #596

Open
@amirdib

Hi,

I want to perform simple inference for a multivariate Bernoulli (dimension D) with multiple chains. The code below works and correctly infers the parameter values with a single chain, but raises an error with multiple chains.
I suspect that I defined my model incorrectly. I could not find any simple example of Bernoulli inference.

The error returned is:
ValueError: Dimension must be 3 but is 2 for 'mcmc_sample_chain/simple_step_size_adaptation___init__/_bootstrap_results/mh_bootstrap_results/hmc_kernel_bootstrap_results/maybe_call_fn_and_grads/value_and_gradients/mcmc_sample_chain_simple_step_size_adaptation___init____bootstrap_results_mh_bootstrap_results_hmc_kernel_bootstrap_results_maybe_call_fn_and_grads_value_and_gradients_Samplemcmc_sample_chain_simple_step_size_adaptation___init____bootstrap_results_mh_bootstrap_results_hmc_kernel_bootstrap_results_maybe_call_fn_and_grads_value_and_gradients_Independentmcmc_sample_chain_simple_step_size_adaptation___init____bootstrap_results_mh_bootstrap_results_hmc_kernel_bootstrap_results_maybe_call_fn_and_grads_value_and_gradients_Bernoulli/log_prob/transpose' (op: 'Transpose') with input shapes: [1,5000,2], [2].

Here is a simple example with D = 2 and N = 5000 (the number of samples in the training set).

import numpy as np 
import tensorflow as tf
import tensorflow_probability as tfp
import functools
tfd = tfp.distributions

# ---------- DATA Generator ------------#

def generate_bernouilli(N,p):
    return np.array([np.random.binomial(size=N, n=1, p = probability) for probability in p ]).T

D = 2
N = 5000
p = np.sort(np.random.random(D))

observations = generate_bernouilli(N,p)

# ---------- Model ------------#

def make_likelihood(theta):
    one_y = tfd.Independent(
        distribution = tfd.Bernoulli(probs=theta),
        reinterpreted_batch_ndims=1)
    y = tfd.Sample(one_y,
          sample_shape=(N,))
    return y
    
def joint_log_prob(observations, theta):
    return (tf.reduce_sum(make_likelihood(theta).log_prob(observations)))
    
posterior_log_prob = functools.partial(joint_log_prob, observations)


# ---------- MCMC sampling  ------------#

num_results = int(10e3)
num_burnin_steps = int(1e3)
n_chains = 5

adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=posterior_log_prob,
        num_leapfrog_steps=3,
        step_size=1.),
    target_accept_prob=tf.constant(.8),
    num_adaptation_steps=int(num_burnin_steps * 0.8))


@tf.function
def run_chain():
    # Run the chain (with burn-in).
    samples, is_accepted = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=tf.ones([n_chains,2])/10,
        kernel=adaptive_hmc,
        trace_fn=lambda _, pkr: pkr.inner_results.is_accepted)

    is_accepted = tf.reduce_mean(tf.cast(is_accepted, dtype=tf.float32))
    return samples, is_accepted


# ---------- Run  ------------#
with tf.device('/CPU:0'):
    samples, is_accepted = run_chain()

The code works perfectly if we replace current_state with current_state=tf.ones([2])/10 (thus removing the independent chain sampling).
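
To make the shape difference between the two cases concrete, here is a minimal NumPy sketch (not TFP code) of a Bernoulli log-likelihood that accepts an optional leading chain dimension. It assumes that, for a state of shape [n_chains, D], the target log-prob is expected to return one value per chain, i.e. shape [n_chains]. The helper name `batched_bernoulli_log_prob` and the probabilities 0.3 and 0.7 are hypothetical and chosen only for illustration.

```python
import numpy as np

def batched_bernoulli_log_prob(observations, theta):
    """Log-likelihood of Bernoulli data for one or several chains.

    observations: array of 0/1 values, shape [N, D].
    theta: probabilities, shape [D] (one chain) or [n_chains, D].
    Returns a scalar for shape [D], or an array of shape [n_chains].
    """
    observations = np.asarray(observations, dtype=float)
    theta = np.asarray(theta, dtype=float)
    # Insert an axis so theta broadcasts against the N observations:
    # [..., 1, D] against [N, D] gives [..., N, D].
    theta = theta[..., np.newaxis, :]
    log_p = (observations * np.log(theta)
             + (1.0 - observations) * np.log1p(-theta))
    # Sum over the N observations and the D dimensions only,
    # keeping any leading chain dimension.
    return log_p.sum(axis=(-2, -1))

rng = np.random.default_rng(0)
obs = rng.binomial(n=1, p=[0.3, 0.7], size=(5000, 2))  # illustrative data

single = batched_bernoulli_log_prob(obs, np.ones(2) / 10)
multi = batched_bernoulli_log_prob(obs, np.ones((5, 2)) / 10)
print(np.shape(single), multi.shape)
```

The point of the sketch is the explicit extra axis: a state of shape [5, 2] does not broadcast against observations of shape [5000, 2] until it is reshaped to [5, 1, 2], and the reduction has to leave the chain axis in place.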

I have a few questions and would be very grateful for any help:

  • Is my model correctly implemented?
  • Is there a way to debug this type of error in TensorFlow? The Python debugger is not of much help.
  • Is there a description of the relations between the different shapes (event_shape, batch_shape, sample_shape) that must hold for the HMC implementation to work?
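
On the debugging question, one check that can be sketched independently of TFP is to call the log-prob function directly on the initial state, outside the sampler, and compare the shape of the result with the number of chains. The helper `check_log_prob_shape` and the two toy targets below are hypothetical and use NumPy only. In TensorFlow 2 the same kind of call could presumably be made eagerly on `posterior_log_prob`, before anything is wrapped in `tf.function`, so that a shape error surfaces with an ordinary Python traceback.

```python
import numpy as np

def check_log_prob_shape(log_prob_fn, initial_state):
    """Call log_prob_fn once and verify it returns one value per chain.

    Assumes the leading axis of initial_state indexes chains.
    """
    initial_state = np.asarray(initial_state)
    n_chains = initial_state.shape[0]
    result = np.asarray(log_prob_fn(initial_state))
    if result.shape != (n_chains,):
        raise ValueError(
            f"log_prob_fn returned shape {result.shape}, "
            f"expected ({n_chains},) for state shape {initial_state.shape}")
    return result

# Toy targets (standard normal log-density, up to a constant).
def collapses_chains(x):
    return -0.5 * np.sum(x ** 2)            # reduces every axis: scalar

def keeps_chains(x):
    return -0.5 * np.sum(x ** 2, axis=-1)   # reduces the event axis only

state = np.ones((5, 2)) / 10
ok = check_log_prob_shape(keeps_chains, state)

try:
    check_log_prob_shape(collapses_chains, state)
    failed = False
except ValueError as err:
    failed = True
    message = str(err)
```

In the code above, `tf.reduce_sum` is called without an `axis` argument, so it would collapse a chain dimension in the same way as `collapses_chains` does here.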

Thanks in advance !
