We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ba21414 commit 0a79308Copy full SHA for 0a79308
sbi/utils/simulation_utils.py
@@ -79,7 +79,7 @@ def simulate_for_sbi(
79
# The batch size will be an approximation, since np.array_split does
80
# not take as argument the size of the batch but their total.
81
num_batches = num_simulations // simulation_batch_size
82
- batches = np.array_split(theta.numpy(), num_batches, axis=0)
+ batches = np.array_split(theta.cpu().numpy(), num_batches, axis=0)
83
batch_seeds = np.random.randint(low=0, high=1_000_000, size=(len(batches),))
84
85
# define seeded simulator.
0 commit comments