Currently, our JointDistribution packs X and Y samples into a single array XY by concatenation and unpacks them by slicing.
This is suboptimal: for example, X and Y must share the same dtype, so mixing continuous and categorical variables requires manual casting.
Instead, we can use the JointDistribution classes from TFP on JAX, which return each variable as a separate component with its own dtype.