Skip to content

Commit babd782

Browse files
committed
Add sharding support
1 parent 56790d9 commit babd782

3 files changed

Lines changed: 196 additions & 24 deletions

File tree

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyquil/simulation/_simulator.py

Lines changed: 97 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@
3838

3939
import jax
4040
import jax.numpy as jnp
41+
import numpy as np
4142
import quax as qx
4243
from jax import Array
44+
from jax.sharding import Mesh, NamedSharding, PartitionSpec
4345

4446
from pyquil.api import MemoryMap
4547
from pyquil.noise._channels import get_custom_gates_from_program
@@ -288,7 +290,7 @@ class TrajectorySimulator(ProgramSimulator):
288290
outcomes.
289291
"""
290292

291-
__slots__ = ("_kraus_truncation_threshold",)
293+
__slots__ = ("_kraus_truncation_threshold", "_devices")
292294

293295
def __init__(
294296
self,
@@ -298,9 +300,11 @@ def __init__(
298300
noise_model: NoiseModelLike | None = None,
299301
max_subsystem_size: int = 0,
300302
kraus_truncation_threshold: float = 1e-6,
303+
devices: list[jax.Device] | None = None,
301304
) -> None:
302305
super().__init__(program, qubits, noise_model=noise_model, max_subsystem_size=max_subsystem_size)
303306
self._kraus_truncation_threshold = kraus_truncation_threshold
307+
self._devices = devices if devices is not None else jax.devices()
304308

305309
def adapt(self, compressed: list[ResolvedOp]) -> list[TrajectoryOp]:
306310
"""Convert compressed ops to trajectory-compatible types."""
@@ -343,11 +347,14 @@ def sample(
343347
"""Run trajectory simulation in batches, returning only measurement outcomes.
344348
345349
State vectors are discarded after each batch, making this scalable
346-
to arbitrarily many trajectories.
350+
to arbitrarily many trajectories. When multiple devices are
351+
available, each batch is sharded across them so that every device
352+
processes ``batch_size // n_devices`` trajectories concurrently.
347353
348354
:param params: Flat parameter vector from :meth:`linearize`.
349355
:param num_trajectories: Total number of trajectories to simulate.
350-
:param batch_size: Maximum number of trajectories per batch.
356+
:param batch_size: Maximum number of trajectories per batch
357+
(total across all devices).
351358
:param random_seed: Seed for the JAX PRNG.
352359
:return: Measurement outcomes with shape ``(num_trajectories, n_measurements)``.
353360
"""
@@ -363,6 +370,7 @@ def sample(
363370
random_seed,
364371
keep_states=False,
365372
dims=self.dims,
373+
devices=self._devices,
366374
)
367375

368376
if len(all_outcomes) == 1:
@@ -388,6 +396,10 @@ def _apply_trajectory_operations(
388396
- ``qx.KrausMap``: probabilistic Kraus operator sampling
389397
- ``qx.QuantumInstrument``: measurement with outcome recording
390398
399+
Key generation is sharding-friendly: per-operation keys are derived
400+
lazily via ``jax.random.fold_in`` so that the key array is never
401+
materialised in full on a single device.
402+
391403
:param operations: Ordered list of (operator, subsystem) pairs.
392404
:param psi: Initial state vector, optionally batched via ensemble dimension.
393405
:param key: JAX PRNG key (scalar typed key). Will be split internally to
@@ -398,33 +410,42 @@ def _apply_trajectory_operations(
398410
"""
399411
measurement_outcomes: list[Array] = []
400412

401-
n_stochastic = sum(1 for op, _ in operations if isinstance(op, (qx.KrausMap, qx.QuantumInstrument)))
402-
403413
ensemble_size = psi.ensemble_size
404414

405-
if n_stochastic > 0:
406-
if ensemble_size:
407-
n_traj = ensemble_size[0]
408-
all_keys = jax.random.split(key, n_stochastic * n_traj)
409-
all_keys = all_keys.reshape(n_stochastic, n_traj)
415+
# Derive per-trajectory base keys once. When the state is sharded
416+
# across devices the resulting key array inherits the same sharding,
417+
# so each device only materialises its own slice.
418+
if ensemble_size:
419+
if key.ndim > 0:
420+
# Already per-trajectory keys (e.g. from multi-device sharding
421+
# or batched ``compute()``).
422+
per_traj_keys = key
410423
else:
411-
all_keys = jax.random.split(key, n_stochastic)
424+
per_traj_keys = jax.random.split(key, ensemble_size[0])
425+
else:
426+
per_traj_keys = None
412427

413-
key_idx = 0
428+
stochastic_idx = 0
414429

415430
for op, subsystem in operations:
416431
match op:
417432
case qx.Unitary():
418433
psi = qx.targeted_apply_unitary(op, psi, subsystem)
419434
case qx.KrausMap():
420-
op_keys = all_keys[key_idx]
435+
if per_traj_keys is not None:
436+
op_keys = jax.vmap(lambda k: jax.random.fold_in(k, stochastic_idx))(per_traj_keys)
437+
else:
438+
op_keys = jax.random.fold_in(key, stochastic_idx)
421439
psi = qx.targeted_apply_kraus_map_trajectory(op, psi, op_keys, subsystem)
422-
key_idx += 1
440+
stochastic_idx += 1
423441
case qx.QuantumInstrument():
424-
op_keys = all_keys[key_idx]
442+
if per_traj_keys is not None:
443+
op_keys = jax.vmap(lambda k: jax.random.fold_in(k, stochastic_idx))(per_traj_keys)
444+
else:
445+
op_keys = jax.random.fold_in(key, stochastic_idx)
425446
psi, outcome = qx.targeted_apply_instrument_to_state_vector(op, psi, op_keys, subsystem)
426447
measurement_outcomes.append(outcome)
427-
key_idx += 1
448+
stochastic_idx += 1
428449
case _:
429450
raise TypeError(f"Unsupported operator type: {type(op)}")
430451

@@ -436,6 +457,20 @@ def _apply_trajectory_operations(
436457
return psi, outcomes
437458

438459

460+
def _make_mesh(devices: list[jax.Device] | None) -> Mesh | None:
461+
"""Build a 1-D ``Mesh`` over *devices*, or ``None`` for single-device."""
462+
if devices is None:
463+
devices = jax.devices()
464+
if len(devices) <= 1:
465+
return None
466+
return Mesh(np.array(devices), axis_names=("traj",))
467+
468+
469+
def _round_up_to(n: int, divisor: int) -> int:
470+
"""Round *n* up to the nearest multiple of *divisor*."""
471+
return ((n + divisor - 1) // divisor) * divisor
472+
473+
439474
def _run_batched_trajectories(
440475
operations: list[TrajectoryOp],
441476
n_qubits: int,
@@ -444,11 +479,21 @@ def _run_batched_trajectories(
444479
random_seed: int,
445480
keep_states: bool = True,
446481
dims: tuple[int, ...] | None = None,
482+
devices: list[jax.Device] | None = None,
447483
) -> tuple[list[qx.StateVector] | None, list[Array]]:
448-
"""Run trajectory simulation in batches."""
484+
"""Run trajectory simulation in batches, optionally sharded across devices.
485+
486+
When *devices* contains more than one device a :class:`jax.sharding.Mesh`
487+
is constructed and both the initial state vector and PRNG keys are sharded
488+
along the trajectory (ensemble) axis. XLA's SPMD partitioner then
489+
distributes the work so that each device processes its own slice.
490+
"""
449491
if dims is None:
450492
dims = (2,) * n_qubits
451493

494+
mesh = _make_mesh(devices)
495+
n_devices = len(mesh.devices.flat) if mesh is not None else 1
496+
452497
key = jax.random.key(random_seed)
453498
all_psis: list[qx.StateVector] = [] if keep_states else []
454499
all_outcomes: list[Array] = []
@@ -458,31 +503,59 @@ def _run_batched_trajectories(
458503
t_total = 0.0
459504
while remaining > 0:
460505
this_batch = min(remaining, batch_size)
506+
507+
# Pad to a multiple of n_devices so the shard split is even.
508+
padded_batch = _round_up_to(this_batch, n_devices) if n_devices > 1 else this_batch
509+
n_pad = padded_batch - this_batch
510+
461511
key, batch_key = jax.random.split(key)
462512

463-
if this_batch == 1:
513+
if padded_batch == 1:
464514
psi = qx.zero_state_vector(dims=dims)
465515
else:
466-
psi = qx.zero_state_vector(dims=dims, ensemble_size=(this_batch,))
516+
psi = qx.zero_state_vector(dims=dims, ensemble_size=(padded_batch,))
517+
518+
# Shard state and key across devices when a mesh is available.
519+
if mesh is not None:
520+
sharding = NamedSharding(mesh, PartitionSpec("traj")) # type: ignore[no-untyped-call]
521+
psi = qx.StateVector.from_matrix(
522+
jax.device_put(psi.matrix, sharding),
523+
psi.dims,
524+
)
525+
# Split a per-trajectory key vector and shard it.
526+
batch_keys = jax.random.split(batch_key, padded_batch)
527+
batch_keys = jax.device_put(batch_keys, sharding)
528+
else:
529+
batch_keys = batch_key
467530

468531
t0 = time.perf_counter()
469-
psi_out, outcomes = _apply_trajectory_operations(operations, psi, batch_key)
532+
psi_out, outcomes = _apply_trajectory_operations(operations, psi, batch_keys)
470533
psi_out.matrix.block_until_ready()
471534
t1 = time.perf_counter()
472535
t_total += t1 - t0
473536

474-
if this_batch == 1:
537+
# Strip padding rows.
538+
if n_pad > 0:
539+
psi_out = qx.StateVector.from_matrix(
540+
psi_out.matrix[:this_batch],
541+
psi_out.dims,
542+
)
543+
outcomes = outcomes[:this_batch]
544+
545+
if this_batch == 1 and padded_batch == 1:
475546
psi_out = qx.StateVector.from_matrix(
476547
psi_out.matrix[jnp.newaxis],
477548
psi_out.dims,
478549
)
479550
outcomes = outcomes[jnp.newaxis]
480551

481552
logger.debug(
482-
"Batch %d: %d trajectories, %d qubits, %.3f s",
553+
"Batch %d: %d trajectories (%d padded), %d qubits, %d device(s), %.3f s",
483554
batch_idx,
484555
this_batch,
556+
padded_batch,
485557
n_qubits,
558+
n_devices,
486559
t1 - t0,
487560
)
488561

@@ -493,11 +566,12 @@ def _run_batched_trajectories(
493566
batch_idx += 1
494567

495568
logger.info(
496-
"Trajectories complete: %d total, %d batches (size=%d), n_qubits=%d, %.3f s total, %.1f traj/s",
569+
"Trajectories complete: %d total, %d batches (size=%d), n_qubits=%d, %d device(s), %.3f s total, %.1f traj/s",
497570
num_trajectories,
498571
batch_idx,
499572
batch_size,
500573
n_qubits,
574+
n_devices,
501575
t_total,
502576
num_trajectories / t_total if t_total > 0 else float("inf"),
503577
)

test/unit/test_state_vector.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
PureStateVectorSimulator,
2828
TrajectorySimulator,
2929
_run_batched_trajectories,
30+
_make_mesh,
31+
_round_up_to,
3032
)
3133
from pyquil.simulation._simulator import (
3234
_apply_trajectory_operations as apply_trajectory_operations,
@@ -1042,3 +1044,99 @@ def test_random_circuit_compression_summary(self, capsys):
10421044
line += f" {counts[s]:>4} ({ratio:.2f})"
10431045
# line += f" {counts[s]:>8}"
10441046
print(line)
1047+
1048+
1049+
# ──────────────────────────────────────────────────────────────────────────────
1050+
# Multi-device / sharding tests
1051+
# ──────────────────────────────────────────────────────────────────────────────
1052+
1053+
1054+
class TestMultiDeviceHelpers:
1055+
def test_round_up_to(self):
1056+
assert _round_up_to(7, 4) == 8
1057+
assert _round_up_to(8, 4) == 8
1058+
assert _round_up_to(1, 3) == 3
1059+
assert _round_up_to(0, 5) == 0
1060+
1061+
def test_make_mesh_single_device_returns_none(self):
1062+
"""A single device should return None (no mesh needed)."""
1063+
devices = jax.devices()[:1]
1064+
assert _make_mesh(devices) is None
1065+
1066+
def test_make_mesh_none_uses_default(self):
1067+
"""Passing None should query jax.devices()."""
1068+
mesh = _make_mesh(None)
1069+
if len(jax.devices()) <= 1:
1070+
assert mesh is None
1071+
else:
1072+
assert mesh is not None
1073+
1074+
1075+
class TestMultiDeviceTrajectory:
1076+
"""Tests that exercise the multi-device code paths.
1077+
1078+
On a single-CPU host these still validate the padding/unpadding logic
1079+
and the ``devices`` parameter plumbing. On a multi-GPU host they
1080+
exercise real cross-device sharding.
1081+
"""
1082+
1083+
def test_devices_parameter_accepted(self):
1084+
"""TrajectorySimulator should accept a ``devices`` keyword."""
1085+
p = Program(H(0), MEASURE(0, None))
1086+
sim = TrajectorySimulator(p, qubits=[0], devices=jax.devices())
1087+
outcomes = sim.sample(_EMPTY_PARAMS, num_trajectories=10)
1088+
assert outcomes.shape == (10, 1)
1089+
1090+
def test_sample_results_match_single_device(self):
1091+
"""Outcomes shape and value range must be the same regardless of device list."""
1092+
p = Program(H(0), MEASURE(0, None))
1093+
sim_default = TrajectorySimulator(p, qubits=[0])
1094+
sim_explicit = TrajectorySimulator(p, qubits=[0], devices=jax.devices())
1095+
1096+
out_default = sim_default.sample(_EMPTY_PARAMS, num_trajectories=64, batch_size=16, random_seed=99)
1097+
out_explicit = sim_explicit.sample(_EMPTY_PARAMS, num_trajectories=64, batch_size=16, random_seed=99)
1098+
1099+
assert out_default.shape == out_explicit.shape
1100+
assert jnp.all((out_default == 0) | (out_default == 1))
1101+
assert jnp.all((out_explicit == 0) | (out_explicit == 1))
1102+
1103+
def test_padding_stripped_correctly(self):
1104+
"""When num_trajectories is not a multiple of n_devices, padding must be removed."""
1105+
p = Program(H(0), MEASURE(0, None))
1106+
sim = TrajectorySimulator(p, qubits=[0], devices=jax.devices())
1107+
# 7 is unlikely to be a multiple of any device count > 1
1108+
outcomes = sim.sample(_EMPTY_PARAMS, num_trajectories=7, batch_size=7)
1109+
assert outcomes.shape == (7, 1)
1110+
1111+
def test_batched_trajectories_with_devices(self):
1112+
"""_run_batched_trajectories should accept and use devices parameter."""
1113+
p = Program(H(0), MEASURE(0, None))
1114+
sim = TrajectorySimulator(p, qubits=[0])
1115+
resolved = sim.resolve(_EMPTY_PARAMS)
1116+
compressed = sim.compress(resolved)
1117+
operations = sim.adapt(compressed)
1118+
1119+
_, outcomes = _run_batched_trajectories(
1120+
operations,
1121+
sim.n_qubits,
1122+
num_trajectories=20,
1123+
batch_size=8,
1124+
random_seed=42,
1125+
keep_states=False,
1126+
dims=sim.dims,
1127+
devices=jax.devices(),
1128+
)
1129+
total = sum(o.shape[0] for o in outcomes)
1130+
assert total == 20
1131+
1132+
def test_noisy_sample_with_devices(self):
1133+
"""Multi-device path should work with noise models."""
1134+
p_error = 0.3
1135+
ch = Channel.from_pauli_noise(inst=X(0), pauli_noise={"X": p_error})
1136+
noise_model = NoiseModel(channels=[ch])
1137+
p = Program(X(0), MEASURE(0, None))
1138+
sim = TrajectorySimulator(p, noise_model=noise_model, qubits=[0], devices=jax.devices())
1139+
outcomes = sim.sample(_EMPTY_PARAMS, num_trajectories=1024, batch_size=256, random_seed=7)
1140+
assert outcomes.shape == (1024, 1)
1141+
frac_0 = float(jnp.mean(outcomes == 0))
1142+
assert abs(frac_0 - p_error) < 0.05

0 commit comments

Comments
 (0)