@@ -243,20 +243,6 @@ Run the real workflow, if you found these loggings in the running log, it means
243243
244244 By adjusting this factor, users can fine-tune the trade-off between memory efficiency
245245 and performance optimizations.
246- * **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism,
247- this flag enables overlapping the (i+1)-th layer weight `AllGather` with the
248- i-th layer computation. It also enables overlapping (i+1)-th layer
249- weight `Reduce`/`ReduceScatter` with i-th layer's computation. The default
250- value is False. **There are some bugs when this flag is turned on.**
251- * **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when
252- performing [GSPMD pipelining](https://arxiv.org/abs/2105.04663). Setting a
253- nonzero threshold decomposes `CollectivePermute`s into
254- `CollectivePermuteReceiveDone` and `CollectivePermuteSendDone` pairs, so that
255- computation can be performed between each corresponding
256- `ReceiveDone`/`SendDone` pair and hence achieve more overlap. By default the
257- threshold is 0 and there is no decomposition. Setting it to a threshold > 0 such
258- as `--xla_gpu_collective_permute_decomposer_threshold=1024` can enable this
259- feature.
260246* **--xla_gpu_all_gather_combine_threshold_bytes**
261247 **--xla_gpu_reduce_scatter_combine_threshold_bytes**
262248 **--xla_gpu_all_reduce_combine_threshold_bytes**
@@ -268,6 +254,227 @@ Run the real workflow, if you found these loggings in the running log, it means
268254 combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By
269255 default, the `combine_threshold_bytes` is set to 256.
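
To make the recommendation above concrete, the sketch below estimates the byte
size of one Transformer layer's weights and prints it as a candidate threshold.
The hidden size, feed-forward width, and bf16 dtype are illustrative
assumptions, not values taken from this document.

```
# A minimal sketch, assuming a hypothetical Transformer layer shape and bf16
# weights; choose threshold values based on your own model.
HIDDEN_SIZE = 8192          # assumed model hidden size
FFN_SIZE = 4 * HIDDEN_SIZE  # assumed feed-forward width
BYTES_PER_PARAM = 2         # bf16

# Q/K/V/output projections plus the two feed-forward matmuls of one layer.
layer_params = 4 * HIDDEN_SIZE * HIDDEN_SIZE + 2 * HIDDEN_SIZE * FFN_SIZE
layer_bytes = layer_params * BYTES_PER_PARAM

print(f"--xla_gpu_all_gather_combine_threshold_bytes={layer_bytes}")
print(f"--xla_gpu_reduce_scatter_combine_threshold_bytes={layer_bytes}")
print(f"--xla_gpu_all_reduce_combine_threshold_bytes={layer_bytes}")
```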
270256
257+ ### Pipeline Parallelism on GPU
258+
259+ XLA implements SPMD-based pipeline parallelism optimizations. This is a scaling technique
260+ where the forward and backward passes are split into multiple pipeline stages.
261+ Each device (or device group) processes the result of the previous
262+ pipeline stage (or the pipeline input) and sends its partial result to the next
263+ stage until the end of the pipeline is reached. This optimization works best
264+ when the latency of the computation is larger than that of the communication.
265+ At compile time, the operations will be rearranged to overlap communication
266+ with computation.
267+
268+ For an optimized schedule, we recommend these XLA flags:
269+ ```
270+ --xla_gpu_enable_latency_hiding_scheduler=true
271+ --xla_gpu_enable_command_buffer=''
272+ --xla_disable_hlo_passes=collective-permute-motion
273+ --xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE
274+ ```
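
One common way to supply these flags (a sketch of one option, not the only
one) is through the `XLA_FLAGS` environment variable, set before JAX
initializes its XLA backend:

```
# Minimal sketch: set XLA_FLAGS before JAX creates its backends.
# The '' in the flag list above denotes an empty value, written here as
# "--xla_gpu_enable_command_buffer=".
import os

os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_command_buffer=",
    "--xla_disable_hlo_passes=collective-permute-motion",
    "--xla_gpu_experimental_pipeline_parallelism_opt_level="
    "PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE",
])

import jax  # import JAX only after XLA_FLAGS is set
```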
275+
276+ The following JAX example demonstrates a pattern where communication operations
277+ are scheduled to overlap with computations. It illustrates how to set up an
278+ optimized pipeline parallelism schedule using 4 GPUs that
279+ form a communication ring (device 0 -> device 1 -> device 2 -> device 3 ->
280+ device 0). We refer to the pattern `0 -> 1 -> 2 -> 3` as the forward edge, and
281+ `3 -> 0` as the back edge.
282+
283+ ```
284+ # Imports and setup
285+ import functools
286+ import jax
287+ from jax import sharding
288+ from jax.experimental import mesh_utils
289+ import jax.numpy as jnp
290+ import jax.random
291+
292+ NUM_DEVICES = 4
293+ NUM_MICROBATCHES = 5
294+ NUM_CIRC_REPEATS = 2
295+ CONTRACTING_DIM_SIZE = 4096
296+ NON_CONTRACTING_DIM_SIZE = 8192
297+ COMPUTE_INTENSITY = 32
298+
299+ # Creates a collective permute for the "forward edge".
300+ # 0->1, 1->2, ... (N-2)->(N-1)
301+ def shift_right(arr):
302+ padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
303+ # Use lax.slice to guarantee the gradient is a pad.
304+ return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)
305+
306+
307+ # Creates a collective permute for the "back edge".
308+ # (N-1)->0
309+ def cycle_back(arr):
310+ padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1)
311+ return jax.lax.slice(
312+ jnp.pad(arr, padding),
313+ [NUM_DEVICES - 1] + [0] * (arr.ndim - 1),
314+ (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:],
315+ )
316+
317+
318+ def select_on_first_device(then_value, else_value):
319+ assert then_value.shape == else_value.shape
320+ is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0
321+ return jnp.where(is_first_device, then_value, else_value)
322+
323+
324+ def select_on_last_device(then_value, else_value):
325+ assert then_value.shape == else_value.shape
326+ is_last_device = (
327+ jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1
328+ )
329+ return jnp.where(is_last_device, then_value, else_value)
330+
331+
332+ def select_on_first_cycle(i, then_value, else_value):
333+ assert then_value.shape == else_value.shape
334+ is_first_cycle = i < NUM_MICROBATCHES
335+ return jnp.where(is_first_cycle, then_value, else_value)
336+
337+
338+ def while_body(carry, i):
339+ """Body of the pipeline while loop."""
340+ weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry
341+
342+ # Read input data from input buffer.
343+ input_data = jax.lax.dynamic_slice(
344+ input_buffer,
345+ (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
346+ (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
347+ )
348+
349+ # Collective permute on the "forward edge" shifts data to the next stage.
350+ fwd_edge_data = shift_right(fwd_edge_data)
351+
352+ # Select compute argument based on device and pipeline cycle.
353+ compute_argument = select_on_first_device(
354+ select_on_first_cycle(i, input_data, bwd_edge_data),
355+ fwd_edge_data,
356+ ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))
357+
358+ # A few matmuls to simulate compute.
359+ tmp = compute_argument
360+ for _ in range(COMPUTE_INTENSITY):
361+ tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
362+ compute_result = tmp.reshape(
363+ (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
364+ )
365+
366+ # Read data from buffer to pass it to the first device of the pipeline on the
367+ # "back edge".
368+ bwd_edge_data = jax.lax.dynamic_slice(
369+ output_buffer,
370+ (0, (1 + i) % NUM_MICROBATCHES, 0, 0),
371+ (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
372+ )
373+
374+ # Collective permute on the "back edge" passes data to the first device.
375+ bwd_edge_data = cycle_back(bwd_edge_data)
376+
377+ # Update output buffer. We do this after reading from it to avoid the data
378+ # dependency.
379+ output_buffer = jax.lax.dynamic_update_slice(
380+ output_buffer,
381+ compute_result,
382+ (0, (2 + i) % NUM_MICROBATCHES, 0, 0),
383+ )
384+
385+ fwd_edge_data = compute_result
386+ carry = (
387+ weights,
388+ input_buffer,
389+ output_buffer,
390+ fwd_edge_data,
391+ bwd_edge_data,
392+ )
393+ return carry, i
394+
395+
396+ @functools.partial(jax.jit, static_argnames=["mesh"])
397+ def entry_computation(weights, input_buffer, mesh):
398+
399+ # Init output buffer.
400+ output_buffer = jnp.zeros_like(input_buffer)
401+
402+ # Init dummy data for forward and backward edge passed through the while loop.
403+ dummy_data = jnp.zeros(
404+ shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
405+ ).astype(jnp.float32)
406+ dummy_data = jax.device_put(
407+ dummy_data,
408+ sharding.NamedSharding(
409+ mesh, sharding.PartitionSpec("the_one_and_only_axis")
410+ ),
411+ )
412+
413+ # Start pipeline.
414+ carry = weights, input_buffer, output_buffer, dummy_data, dummy_data
415+ num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
416+ carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))
417+ _, _, output_buffer, _, _ = carry
418+
419+ return output_buffer
420+
421+
422+ def main(_):
423+
424+ # Expect constant number of devices.
425+ assert NUM_DEVICES == jax.local_device_count()
426+
427+ # Create mesh.
428+ mesh = sharding.Mesh(
429+ mesh_utils.create_device_mesh([NUM_DEVICES]),
430+ axis_names=["the_one_and_only_axis"],
431+ )
432+
433+ # Init weights.
434+ weights = 1.0 / CONTRACTING_DIM_SIZE
435+ weights = jax.lax.broadcast_in_dim(
436+ weights,
437+ shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
438+ broadcast_dimensions=(),
439+ )
440+ weights = jax.device_put(
441+ weights,
442+ sharding.NamedSharding(
443+ mesh, sharding.PartitionSpec("the_one_and_only_axis")
444+ ),
445+ )
446+
447+ # Init random input and replicate it across all devices.
448+ random_key = jax.random.key(0)
449+ input_buffer = jax.random.uniform(
450+ random_key,
451+ shape=(
452+ NUM_MICROBATCHES,
453+ CONTRACTING_DIM_SIZE,
454+ NON_CONTRACTING_DIM_SIZE,
455+ ),
456+ )
457+ input_buffer = jax.lax.broadcast_in_dim(
458+ input_buffer,
459+ shape=(
460+ NUM_DEVICES,
461+ NUM_MICROBATCHES,
462+ CONTRACTING_DIM_SIZE,
463+ NON_CONTRACTING_DIM_SIZE,
464+ ),
465+ broadcast_dimensions=[1, 2, 3],
466+ )
467+ input_buffer = jax.device_put(
468+ input_buffer,
469+ sharding.NamedSharding(
470+ mesh, sharding.PartitionSpec("the_one_and_only_axis")
471+ ),
472+ )
473+
474+ # Run computation.
475+ output_buffer = entry_computation(weights, input_buffer, mesh)
476+ print(f"output_buffer = \n{output_buffer}")
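

# The example above defines main() but never calls it. A minimal entry point
# (an addition of this sketch; the original does not show how main is invoked)
# could look like this:
if __name__ == "__main__":
  main(None)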
477+ ```
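
To check how the compiler actually scheduled the collectives, one option (a
sketch that assumes the `weights`, `input_buffer`, and `mesh` objects from
`main` above are in scope) is to look at the optimized HLO of the jitted
`entry_computation`:

```
# Sketch: dump the optimized HLO produced for the pipeline loop.
lowered = entry_computation.lower(weights, input_buffer, mesh)
compiled = lowered.compile()
hlo_text = compiled.as_text()

# Rough check: how many collective-permute instructions survived optimization.
print(hlo_text.count("collective-permute"))
```
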
271478## NCCL flags
272479
273480These Nvidia NCCL flag values may be useful for single-host multi-device