3838
3939import jax
4040import jax .numpy as jnp
41+ import numpy as np
4142import quax as qx
4243from jax import Array
44+ from jax .sharding import Mesh , NamedSharding , PartitionSpec
4345
4446from pyquil .api import MemoryMap
4547from pyquil .noise ._channels import get_custom_gates_from_program
@@ -288,7 +290,7 @@ class TrajectorySimulator(ProgramSimulator):
288290 outcomes.
289291 """
290292
291- __slots__ = ("_kraus_truncation_threshold" ,)
293+ __slots__ = ("_kraus_truncation_threshold" , "_devices" )
292294
293295 def __init__ (
294296 self ,
@@ -298,9 +300,11 @@ def __init__(
298300 noise_model : NoiseModelLike | None = None ,
299301 max_subsystem_size : int = 0 ,
300302 kraus_truncation_threshold : float = 1e-6 ,
303+ devices : list [jax .Device ] | None = None ,
301304 ) -> None :
302305 super ().__init__ (program , qubits , noise_model = noise_model , max_subsystem_size = max_subsystem_size )
303306 self ._kraus_truncation_threshold = kraus_truncation_threshold
307+ self ._devices = devices if devices is not None else jax .devices ()
304308
305309 def adapt (self , compressed : list [ResolvedOp ]) -> list [TrajectoryOp ]:
306310 """Convert compressed ops to trajectory-compatible types."""
@@ -343,11 +347,14 @@ def sample(
343347 """Run trajectory simulation in batches, returning only measurement outcomes.
344348
345349 State vectors are discarded after each batch, making this scalable
346- to arbitrarily many trajectories.
350+ to arbitrarily many trajectories. When multiple devices are
351+ available, each batch is sharded across them so that every device
352+ processes ``batch_size // n_devices`` trajectories concurrently.
347353
348354 :param params: Flat parameter vector from :meth:`linearize`.
349355 :param num_trajectories: Total number of trajectories to simulate.
350- :param batch_size: Maximum number of trajectories per batch.
356+ :param batch_size: Maximum number of trajectories per batch
357+ (total across all devices).
351358 :param random_seed: Seed for the JAX PRNG.
352359 :return: Measurement outcomes with shape ``(num_trajectories, n_measurements)``.
353360 """
@@ -363,6 +370,7 @@ def sample(
363370 random_seed ,
364371 keep_states = False ,
365372 dims = self .dims ,
373+ devices = self ._devices ,
366374 )
367375
368376 if len (all_outcomes ) == 1 :
@@ -388,6 +396,10 @@ def _apply_trajectory_operations(
388396 - ``qx.KrausMap``: probabilistic Kraus operator sampling
389397 - ``qx.QuantumInstrument``: measurement with outcome recording
390398
399+ Key generation is sharding-friendly: per-operation keys are derived
400+ lazily via ``jax.random.fold_in`` so that the key array is never
401+ materialised in full on a single device.
402+
391403 :param operations: Ordered list of (operator, subsystem) pairs.
392404 :param psi: Initial state vector, optionally batched via ensemble dimension.
393405 :param key: JAX PRNG key (scalar typed key). Will be split internally to
@@ -398,33 +410,42 @@ def _apply_trajectory_operations(
398410 """
399411 measurement_outcomes : list [Array ] = []
400412
401- n_stochastic = sum (1 for op , _ in operations if isinstance (op , (qx .KrausMap , qx .QuantumInstrument )))
402-
403413 ensemble_size = psi .ensemble_size
404414
405- if n_stochastic > 0 :
406- if ensemble_size :
407- n_traj = ensemble_size [0 ]
408- all_keys = jax .random .split (key , n_stochastic * n_traj )
409- all_keys = all_keys .reshape (n_stochastic , n_traj )
415+ # Derive per-trajectory base keys once. When the state is sharded
416+ # across devices the resulting key array inherits the same sharding,
417+ # so each device only materialises its own slice.
418+ if ensemble_size :
419+ if key .ndim > 0 :
420+ # Already per-trajectory keys (e.g. from multi-device sharding
421+ # or batched ``compute()``).
422+ per_traj_keys = key
410423 else :
411- all_keys = jax .random .split (key , n_stochastic )
424+ per_traj_keys = jax .random .split (key , ensemble_size [0 ])
425+ else :
426+ per_traj_keys = None
412427
413- key_idx = 0
428+ stochastic_idx = 0
414429
415430 for op , subsystem in operations :
416431 match op :
417432 case qx .Unitary ():
418433 psi = qx .targeted_apply_unitary (op , psi , subsystem )
419434 case qx .KrausMap ():
420- op_keys = all_keys [key_idx ]
435+ if per_traj_keys is not None :
436+ op_keys = jax .vmap (lambda k : jax .random .fold_in (k , stochastic_idx ))(per_traj_keys )
437+ else :
438+ op_keys = jax .random .fold_in (key , stochastic_idx )
421439 psi = qx .targeted_apply_kraus_map_trajectory (op , psi , op_keys , subsystem )
422- key_idx += 1
440+ stochastic_idx += 1
423441 case qx .QuantumInstrument ():
424- op_keys = all_keys [key_idx ]
442+ if per_traj_keys is not None :
443+ op_keys = jax .vmap (lambda k : jax .random .fold_in (k , stochastic_idx ))(per_traj_keys )
444+ else :
445+ op_keys = jax .random .fold_in (key , stochastic_idx )
425446 psi , outcome = qx .targeted_apply_instrument_to_state_vector (op , psi , op_keys , subsystem )
426447 measurement_outcomes .append (outcome )
427- key_idx += 1
448+ stochastic_idx += 1
428449 case _:
429450 raise TypeError (f"Unsupported operator type: { type (op )} " )
430451
@@ -436,6 +457,20 @@ def _apply_trajectory_operations(
436457 return psi , outcomes
437458
438459
460+ def _make_mesh (devices : list [jax .Device ] | None ) -> Mesh | None :
461+ """Build a 1-D ``Mesh`` over *devices*, or ``None`` for single-device."""
462+ if devices is None :
463+ devices = jax .devices ()
464+ if len (devices ) <= 1 :
465+ return None
466+ return Mesh (np .array (devices ), axis_names = ("traj" ,))
467+
468+
469+ def _round_up_to (n : int , divisor : int ) -> int :
470+ """Round *n* up to the nearest multiple of *divisor*."""
471+ return ((n + divisor - 1 ) // divisor ) * divisor
472+
473+
439474def _run_batched_trajectories (
440475 operations : list [TrajectoryOp ],
441476 n_qubits : int ,
@@ -444,11 +479,21 @@ def _run_batched_trajectories(
444479 random_seed : int ,
445480 keep_states : bool = True ,
446481 dims : tuple [int , ...] | None = None ,
482+ devices : list [jax .Device ] | None = None ,
447483) -> tuple [list [qx .StateVector ] | None , list [Array ]]:
448- """Run trajectory simulation in batches."""
484+ """Run trajectory simulation in batches, optionally sharded across devices.
485+
486+ When *devices* contains more than one device a :class:`jax.sharding.Mesh`
487+ is constructed and both the initial state vector and PRNG keys are sharded
488+ along the trajectory (ensemble) axis. XLA's SPMD partitioner then
489+ distributes the work so that each device processes its own slice.
490+ """
449491 if dims is None :
450492 dims = (2 ,) * n_qubits
451493
494+ mesh = _make_mesh (devices )
495+ n_devices = len (mesh .devices .flat ) if mesh is not None else 1
496+
452497 key = jax .random .key (random_seed )
453498 all_psis : list [qx .StateVector ] = [] if keep_states else []
454499 all_outcomes : list [Array ] = []
@@ -458,31 +503,59 @@ def _run_batched_trajectories(
458503 t_total = 0.0
459504 while remaining > 0 :
460505 this_batch = min (remaining , batch_size )
506+
507+ # Pad to a multiple of n_devices so the shard split is even.
508+ padded_batch = _round_up_to (this_batch , n_devices ) if n_devices > 1 else this_batch
509+ n_pad = padded_batch - this_batch
510+
461511 key , batch_key = jax .random .split (key )
462512
463- if this_batch == 1 :
513+ if padded_batch == 1 :
464514 psi = qx .zero_state_vector (dims = dims )
465515 else :
466- psi = qx .zero_state_vector (dims = dims , ensemble_size = (this_batch ,))
516+ psi = qx .zero_state_vector (dims = dims , ensemble_size = (padded_batch ,))
517+
518+ # Shard state and key across devices when a mesh is available.
519+ if mesh is not None :
520+ sharding = NamedSharding (mesh , PartitionSpec ("traj" )) # type: ignore[no-untyped-call]
521+ psi = qx .StateVector .from_matrix (
522+ jax .device_put (psi .matrix , sharding ),
523+ psi .dims ,
524+ )
525+ # Split a per-trajectory key vector and shard it.
526+ batch_keys = jax .random .split (batch_key , padded_batch )
527+ batch_keys = jax .device_put (batch_keys , sharding )
528+ else :
529+ batch_keys = batch_key
467530
468531 t0 = time .perf_counter ()
469- psi_out , outcomes = _apply_trajectory_operations (operations , psi , batch_key )
532+ psi_out , outcomes = _apply_trajectory_operations (operations , psi , batch_keys )
470533 psi_out .matrix .block_until_ready ()
471534 t1 = time .perf_counter ()
472535 t_total += t1 - t0
473536
474- if this_batch == 1 :
537+ # Strip padding rows.
538+ if n_pad > 0 :
539+ psi_out = qx .StateVector .from_matrix (
540+ psi_out .matrix [:this_batch ],
541+ psi_out .dims ,
542+ )
543+ outcomes = outcomes [:this_batch ]
544+
545+ if this_batch == 1 and padded_batch == 1 :
475546 psi_out = qx .StateVector .from_matrix (
476547 psi_out .matrix [jnp .newaxis ],
477548 psi_out .dims ,
478549 )
479550 outcomes = outcomes [jnp .newaxis ]
480551
481552 logger .debug (
482- "Batch %d: %d trajectories, %d qubits, %.3f s" ,
553+ "Batch %d: %d trajectories (%d padded) , %d qubits, %d device(s) , %.3f s" ,
483554 batch_idx ,
484555 this_batch ,
556+ padded_batch ,
485557 n_qubits ,
558+ n_devices ,
486559 t1 - t0 ,
487560 )
488561
@@ -493,11 +566,12 @@ def _run_batched_trajectories(
493566 batch_idx += 1
494567
495568 logger .info (
496- "Trajectories complete: %d total, %d batches (size=%d), n_qubits=%d, %.3f s total, %.1f traj/s" ,
569+ "Trajectories complete: %d total, %d batches (size=%d), n_qubits=%d, %d device(s), % .3f s total, %.1f traj/s" ,
497570 num_trajectories ,
498571 batch_idx ,
499572 batch_size ,
500573 n_qubits ,
574+ n_devices ,
501575 t_total ,
502576 num_trajectories / t_total if t_total > 0 else float ("inf" ),
503577 )
0 commit comments