forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_vision_cuda_graphs.py
More file actions
680 lines (565 loc) · 25.9 KB
/
test_vision_cuda_graphs.py
File metadata and controls
680 lines (565 loc) · 25.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import gc
import os
from copy import deepcopy
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import torch
from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec
from megatron.core.tensor_parallel.random import (
HAVE_TE,
initialize_rng_tracker,
model_parallel_cuda_manual_seed,
)
from megatron.core.transformer.cuda_graphs import (
HAVE_TE_GRAPHS,
VisionTECudaGraphHelper,
_layer_is_graphable,
_wrap_graph_for_vision,
get_vision_cuda_graph_seq_length,
set_current_microbatch,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import is_te_min_version
from tests.unit_tests.test_utilities import Utils
# Oldest TransformerEngine version these tests run against; used both for the
# version check below and in the skip message.
TE_MIN_VERSION = "2.13.0"

# Truthy only when TransformerEngine is importable and recent enough.
_te_version_ok = HAVE_TE and is_te_min_version(TE_MIN_VERSION)

# When the module is imported (e.g. collected by pytest) rather than run as a
# script, skip every test in it if the requirement is unmet. The script entry
# point at the bottom of the file reports the skip itself.
if __name__ != "__main__" and not _te_version_ok:
    pytest.skip(
        f"Vision CUDA graph tests require TransformerEngine >= {TE_MIN_VERSION}",
        allow_module_level=True,
    )
# ---------------------------------------------------------------------------
# Tests for _layer_is_graphable
# ---------------------------------------------------------------------------
class TestVisionLayerIsGraphable:
    """Unit tests for ``_layer_is_graphable``."""

    def test_non_transformer_layer_returns_false(self):
        """A plain ``torch.nn.Linear`` is not graphable, even with the TE graph impl."""
        config = SimpleNamespace(cuda_graph_impl="transformer_engine")
        layer = torch.nn.Linear(4, 4)
        assert _layer_is_graphable(layer, config) is False

    @pytest.mark.flaky
    @pytest.mark.flaky_in_dev
    def test_wrong_cuda_graph_impl_returns_false(self):
        """A TransformerLayer stand-in with ``cuda_graph_impl='local'`` is not graphable."""
        from megatron.core.transformer.transformer_layer import TransformerLayer

        config = SimpleNamespace(cuda_graph_impl="local")
        layer = MagicMock(spec=TransformerLayer)
        # isinstance check with MagicMock(spec=...) should pass, so a False
        # result here is attributable to the config value.
        assert _layer_is_graphable(layer, config) is False

    def test_correct_config_with_transformer_layer(self):
        """Real TransformerLayer + cuda_graph_impl='transformer_engine' -> True."""
        initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True)
        Utils.initialize_model_parallel(
            tensor_model_parallel_size=1, pipeline_model_parallel_size=1
        )
        # Everything after initialization runs under try/finally so that a
        # failed assertion or a construction error cannot leave model-parallel
        # state initialized for the tests that follow.
        try:
            model_parallel_cuda_manual_seed(123)
            config = TransformerConfig(
                num_layers=1,
                hidden_size=16,
                num_attention_heads=2,
                use_cpu_initialization=True,
                cuda_graph_impl="transformer_engine",
            )
            from megatron.core.transformer.transformer_block import TransformerBlock

            block = TransformerBlock(config, get_vit_layer_with_transformer_engine_spec())
            layer = block.layers[0]
            assert _layer_is_graphable(layer, config) is True
        finally:
            Utils.destroy_model_parallel()
# ---------------------------------------------------------------------------
# Tests for _wrap_graph_for_vision
# ---------------------------------------------------------------------------
class TestWrapGraphForVision:
    """Unit tests for ``_wrap_graph_for_vision`` using plain-function stand-ins."""

    def test_filters_none_from_tuple(self):
        """``None`` entries are dropped from a tuple result."""

        def stub_graph(*args, **kwargs):
            return (torch.tensor(1.0), None)

        output = _wrap_graph_for_vision(stub_graph)()
        expected = (torch.tensor(1.0),)
        assert output == expected

    def test_returns_non_tuple_unchanged(self):
        """A non-tuple result is handed back as the very same object."""
        scalar = torch.tensor(42.0)

        def stub_graph(*args, **kwargs):
            return scalar

        output = _wrap_graph_for_vision(stub_graph)()
        assert output is scalar

    def test_preserves_all_non_none(self):
        """A tuple without ``None`` entries keeps all of its elements."""
        first = torch.tensor(1.0)
        second = torch.tensor(2.0)

        def stub_graph(*args, **kwargs):
            return (first, second)

        output = _wrap_graph_for_vision(stub_graph)()
        assert output == (first, second)

    def test_all_none_returns_original(self):
        """An all-``None`` tuple is returned as-is."""

        def stub_graph(*args, **kwargs):
            return (None, None)

        output = _wrap_graph_for_vision(stub_graph)()
        # filtered is empty -> returns original tuple
        assert output == (None, None)

    def test_preserves_te_attributes(self):
        """``backward_dw`` and ``reset`` set on the callable are visible on the wrapper."""

        def stub_graph(*args, **kwargs):
            return (torch.tensor(1.0),)

        stub_graph.backward_dw = "bwd_dw_fn"
        stub_graph.reset = "reset_fn"

        wrapper = _wrap_graph_for_vision(stub_graph)
        assert wrapper.backward_dw == "bwd_dw_fn"
        assert wrapper.reset == "reset_fn"

    def test_missing_te_attributes_not_set(self):
        """The wrapper does not gain ``backward_dw``/``reset`` when the callable lacks them."""

        def stub_graph(*args, **kwargs):
            return (torch.tensor(1.0),)

        wrapper = _wrap_graph_for_vision(stub_graph)
        assert not hasattr(wrapper, 'backward_dw')
        assert not hasattr(wrapper, 'reset')
# ---------------------------------------------------------------------------
# Tests for get_vision_cuda_graph_seq_length
# ---------------------------------------------------------------------------
class TestGetVisionCudaGraphSeqLength:
    """Unit tests for ``get_vision_cuda_graph_seq_length`` on namespace configs."""

    def test_explicit_max_seq_length(self):
        """An explicit ``max_vision_cuda_graph_seq_length`` is returned directly."""
        cfg = SimpleNamespace(max_vision_cuda_graph_seq_length=2048)
        seq_len = get_vision_cuda_graph_seq_length(cfg)
        assert seq_len == 2048

    def test_explicit_max_seq_length_zero_falls_through(self):
        """max_vision_cuda_graph_seq_length=0 is falsy, should fall through."""
        cfg = SimpleNamespace(max_vision_cuda_graph_seq_length=0)
        seq_len = get_vision_cuda_graph_seq_length(cfg, default_seq_length=999)
        assert seq_len == 999

    def test_num_position_embeddings_only(self):
        """Without a merge size, ``num_position_embeddings`` is used as-is."""
        cfg = SimpleNamespace(num_position_embeddings=1024)
        seq_len = get_vision_cuda_graph_seq_length(cfg)
        assert seq_len == 1024

    def test_num_position_embeddings_with_spatial_merge(self):
        """The position-embedding count is divided by ``spatial_merge_size`` squared."""
        cfg = SimpleNamespace(num_position_embeddings=1024, spatial_merge_size=2)
        seq_len = get_vision_cuda_graph_seq_length(cfg)
        # merge_factor = 2**2 = 4, seq = 1024 // 4 = 256
        assert seq_len == 256

    def test_spatial_merge_size_3(self):
        """Same as above with a merge size of 3."""
        cfg = SimpleNamespace(num_position_embeddings=900, spatial_merge_size=3)
        seq_len = get_vision_cuda_graph_seq_length(cfg)
        # merge_factor = 9, seq = 900 // 9 = 100
        assert seq_len == 100

    def test_default_seq_length(self):
        """An empty config yields the built-in default of 4096."""
        cfg = SimpleNamespace()
        seq_len = get_vision_cuda_graph_seq_length(cfg)
        assert seq_len == 4096

    def test_custom_default(self):
        """An empty config yields the caller-supplied ``default_seq_length``."""
        cfg = SimpleNamespace()
        seq_len = get_vision_cuda_graph_seq_length(cfg, default_seq_length=512)
        assert seq_len == 512

    def test_explicit_overrides_position_embeddings(self):
        """The explicit maximum wins over ``num_position_embeddings``."""
        cfg = SimpleNamespace(
            max_vision_cuda_graph_seq_length=8192, num_position_embeddings=1024
        )
        seq_len = get_vision_cuda_graph_seq_length(cfg)
        assert seq_len == 8192
# ---------------------------------------------------------------------------
# Integration test for VisionTECudaGraphHelper with LLaVA model
# ---------------------------------------------------------------------------
@pytest.mark.skipif(
    not (HAVE_TE and is_te_min_version("1.5.0")),
    reason="use_te_rng_tracker requires TransformerEngine version >= 1.5",
)
class TestVisionTECudaGraphHelper:
    """Test VisionTECudaGraphHelper initialization, sample args, and graph lifecycle."""

    def setup_method(self, method):
        """Initialize TP=1/PP=1 model parallelism and build a small bf16 LLaVA model."""
        initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True)
        Utils.initialize_model_parallel(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
            virtual_pipeline_model_parallel_size=None,
        )
        model_parallel_cuda_manual_seed(123)

        from megatron.core.models.multimodal.llava_model import LLaVAModel

        self.language_hidden_size = 64
        self.vision_hidden_size = 16
        self.vision_num_layers = 2

        # Precision settings shared by the vision and projection configs.
        bf16_kwargs = dict(bf16=True, pipeline_dtype=torch.bfloat16)

        language_config = TransformerConfig(
            num_layers=2,
            hidden_size=self.language_hidden_size,
            num_attention_heads=4,
            use_cpu_initialization=True,
        )
        self.vision_config = TransformerConfig(
            num_layers=self.vision_num_layers,
            hidden_size=self.vision_hidden_size,
            num_attention_heads=2,
            use_cpu_initialization=True,
            cuda_graph_impl="transformer_engine",
            **bf16_kwargs,
        )
        vision_projection_config = TransformerConfig(
            num_layers=1,
            hidden_size=self.language_hidden_size,
            ffn_hidden_size=32,
            num_attention_heads=1,
            use_cpu_initialization=True,
            **bf16_kwargs,
        )

        language_layer_spec = get_gpt_layer_with_transformer_engine_spec()
        vision_layer_spec = get_vit_layer_with_transformer_engine_spec()
        # The projection reuses a copy of the language MLP submodules spec.
        vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules)

        # Model-type tags are attached to the configs after construction.
        self.vision_config.vision_model_type = "clip"
        language_config.language_model_type = "dummy"

        llava_kwargs = dict(
            language_transformer_config=language_config,
            language_transformer_layer_spec=language_layer_spec,
            language_vocab_size=8192,
            language_max_sequence_length=4096,
            vision_transformer_config=self.vision_config,
            vision_transformer_layer_spec=vision_layer_spec,
            drop_vision_class_token=False,
            vision_projection_config=vision_projection_config,
            vision_projection_layer_spec=vision_projection_spec,
            img_h=336,
            img_w=336,
            patch_dim=14,
            pre_process=True,
            post_process=True,
            add_encoder=True,
            add_decoder=True,
        )
        self.llava_model = LLaVAModel(**llava_kwargs)
        self.llava_model.bfloat16()

        self.vision_seq_length = 576
        self.micro_batch_size = 2

    def teardown_method(self, method):
        """Tear down model parallelism and force a garbage-collection pass."""
        Utils.destroy_model_parallel()
        gc.collect()

    def _make_helper(self, num_microbatches=1):
        """Return a helper wrapping the LLaVA model built in ``setup_method``."""
        return VisionTECudaGraphHelper(
            model=[self.llava_model],
            vision_config=self.vision_config,
            vision_seq_length=self.vision_seq_length,
            micro_batch_size=self.micro_batch_size,
            num_microbatches=num_microbatches,
        )

    def _make_helper_without_vision_model(self):
        """Return a helper wrapping a bare ``torch.nn.Linear`` (no ``vision_model``)."""
        linear_only = torch.nn.Linear(4, 4)
        return VisionTECudaGraphHelper(
            model=[linear_only],
            vision_config=self.vision_config,
            vision_seq_length=self.vision_seq_length,
            micro_batch_size=self.micro_batch_size,
        )

    # -- Initialization tests --

    def test_init_finds_vision_layers(self):
        """The helper discovers the vision model and one callable per vision layer."""
        helper = self._make_helper()
        assert helper.vision_model is not None, "Should find vision_model"
        assert helper.num_layers == self.vision_num_layers
        assert len(helper.callables) == self.vision_num_layers
        assert helper.graphs_created() is False

    def test_init_no_vision_model_warns(self):
        """When model has no vision_model attr, helper should degrade gracefully."""
        helper = self._make_helper_without_vision_model()
        assert helper.vision_model is None
        assert len(helper.callables) == 0
        assert helper.graphs_created() is False

    # -- _get_sample_arguments tests --

    def test_get_sample_arguments_shapes(self):
        """Each sample argument is a 1-tuple holding a bf16 CUDA tensor that requires grad."""
        helper = self._make_helper(num_microbatches=1)
        # order is unused by vision override; pass a dummy
        sample_args, sample_kwargs_list = helper._get_sample_arguments(order=[1, -1])

        expected_count = self.vision_num_layers * 1  # layers * microbatches
        assert len(sample_args) == expected_count
        assert len(sample_kwargs_list) == expected_count

        # Both lists were just checked to have the same length, so walking
        # sample_args alone visits the same indices as zipping the two.
        for i, args_item in enumerate(sample_args):
            assert isinstance(args_item, tuple), f"sample_args[{i}] should be tuple"
            assert len(args_item) == 1, f"sample_args[{i}] should have one element (hidden_states)"
            hs = args_item[0]
            assert hs.shape == (self.vision_seq_length, 1, self.vision_hidden_size), (
                f"Expected ({self.vision_seq_length}, 1, {self.vision_hidden_size}), "
                f"got {hs.shape}"
            )
            assert hs.dtype == torch.bfloat16
            assert hs.device.type == 'cuda'
            assert hs.requires_grad is True

    def test_get_sample_arguments_multi_microbatch(self):
        """Three microbatches yield ``layers * 3`` sample entries."""
        helper = self._make_helper(num_microbatches=3)
        sample_args, sample_kwargs_list = helper._get_sample_arguments(order=[1, -1])

        expected_count = self.vision_num_layers * 3
        assert len(sample_args) == expected_count
        assert len(sample_kwargs_list) == expected_count

    def test_get_sample_arguments_empty_when_no_callables(self):
        """With no vision layers found, both sample lists are empty."""
        helper = self._make_helper_without_vision_model()
        sample_args, sample_kwargs_list = helper._get_sample_arguments(order=[1, -1])
        assert sample_args == []
        assert sample_kwargs_list == []

    # -- create_cudagraphs / delete_cuda_graphs lifecycle --

    @pytest.mark.flaky
    @pytest.mark.flaky_in_dev
    @pytest.mark.skipif(
        not (HAVE_TE_GRAPHS and is_te_min_version("2.7.0")),
        reason="TE CUDA graph capture requires TransformerEngine >= 2.7.0",
    )
    def test_create_and_delete_cudagraphs(self):
        """Full lifecycle: create graphs, verify state, delete, verify cleanup."""
        self.llava_model.cuda()
        helper = self._make_helper(num_microbatches=1)
        assert not helper.graphs_created()

        helper.create_cudagraphs()
        assert helper.graphs_created()

        # Each vision layer should have cuda_graphs attached
        for vision_layer in helper.callables:
            assert hasattr(
                vision_layer, 'cuda_graphs'
            ), "Layer should have cuda_graphs after capture"
            assert len(vision_layer.cuda_graphs) == 1  # 1 microbatch

        # cudagraph_manager should have been removed during capture
        for vision_layer in helper.callables:
            assert not hasattr(
                vision_layer, 'cudagraph_manager'
            ), "cudagraph_manager should be removed before TE capture"

        helper.delete_cuda_graphs()
        assert not helper.graphs_created()

        # cuda_graphs should be empty after delete
        for vision_layer in helper.callables:
            assert vision_layer.cuda_graphs == [], "cuda_graphs should be empty after delete"

    @pytest.mark.skipif(
        not (HAVE_TE_GRAPHS and is_te_min_version("2.7.0")),
        reason="TE CUDA graph capture requires TransformerEngine >= 2.7.0",
    )
    @pytest.mark.flaky
    @pytest.mark.flaky_in_dev
    def test_create_cudagraphs_multi_microbatch(self):
        """Verify that graphs are created per-microbatch per-layer."""
        self.llava_model.cuda()
        num_mb = 2
        helper = self._make_helper(num_microbatches=num_mb)

        helper.create_cudagraphs()
        assert helper.graphs_created()

        for vision_layer in helper.callables:
            assert hasattr(vision_layer, 'cuda_graphs')
            # PP=1 collapses to 1 microbatch internally
            assert len(vision_layer.cuda_graphs) == helper.num_microbatches

        helper.delete_cuda_graphs()

    def test_create_cudagraphs_no_callables_is_noop(self):
        """create_cudagraphs on empty helper should not crash."""
        helper = self._make_helper_without_vision_model()
        helper.create_cudagraphs()
        assert not helper.graphs_created()

    def test_delete_cudagraphs_before_create_asserts(self):
        """delete_cuda_graphs before creation should raise AssertionError."""
        helper = self._make_helper()
        with pytest.raises(AssertionError):
            helper.delete_cuda_graphs()
# ---------------------------------------------------------------------------
# Integration test with PP=2: vision encoder on first pipeline stage only
# ---------------------------------------------------------------------------
@pytest.mark.skipif(
    not (HAVE_TE and is_te_min_version("1.5.0")),
    reason="use_te_rng_tracker requires TransformerEngine version >= 1.5",
)
class TestVisionTECudaGraphHelperPP2:
    """Test VisionTECudaGraphHelper with PP=2.

    With pipeline_model_parallel_size=2 the LLaVA model is split so that the
    vision encoder lives exclusively on the first pipeline stage:

    - pp_rank 0: add_encoder=True, pre_process=True, post_process=False
    - pp_rank 1: add_encoder=False, pre_process=False, post_process=True

    This test verifies that:

    1. On stage 0 the helper finds and captures vision layers.
    2. On stage 1 the helper gracefully finds no vision layers.
    3. With PP>1, num_microbatches is NOT collapsed to 1.
    """

    def setup_method(self, method):
        """Initialize TP=1/PP=2 and build this rank's slice of a small bf16 LLaVA model."""
        initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True)
        Utils.initialize_model_parallel(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=2,
            virtual_pipeline_model_parallel_size=None,
        )
        model_parallel_cuda_manual_seed(123)

        from megatron.core.models.multimodal.llava_model import LLaVAModel

        self.language_hidden_size = 64
        self.vision_hidden_size = 16
        self.vision_num_layers = 2
        self.language_num_layers = 4

        # Work out which pipeline stage this rank is on.
        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
        pp_world_size = parallel_state.get_pipeline_model_parallel_world_size()
        is_first_stage = pp_rank == 0
        is_last_stage = pp_rank == (pp_world_size - 1)

        # Precision settings shared by all three configs.
        bf16_kwargs = dict(bf16=True, pipeline_dtype=torch.bfloat16)

        language_config = TransformerConfig(
            num_layers=self.language_num_layers,
            hidden_size=self.language_hidden_size,
            num_attention_heads=4,
            use_cpu_initialization=True,
            pipeline_model_parallel_size=2,
            **bf16_kwargs,
        )
        self.vision_config = TransformerConfig(
            num_layers=self.vision_num_layers,
            hidden_size=self.vision_hidden_size,
            num_attention_heads=2,
            use_cpu_initialization=True,
            cuda_graph_impl="transformer_engine",
            **bf16_kwargs,
        )
        vision_projection_config = TransformerConfig(
            num_layers=1,
            hidden_size=self.language_hidden_size,
            ffn_hidden_size=32,
            num_attention_heads=1,
            use_cpu_initialization=True,
            **bf16_kwargs,
        )

        language_layer_spec = get_gpt_layer_with_transformer_engine_spec()
        vision_layer_spec = get_vit_layer_with_transformer_engine_spec()
        # The projection reuses a copy of the language MLP submodules spec.
        vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules)

        # Model-type tags are attached to the configs after construction.
        self.vision_config.vision_model_type = "clip"
        language_config.language_model_type = "dummy"

        self.is_first_stage = is_first_stage

        llava_kwargs = dict(
            language_transformer_config=language_config,
            language_transformer_layer_spec=language_layer_spec,
            language_vocab_size=8192,
            language_max_sequence_length=4096,
            vision_transformer_config=self.vision_config,
            vision_transformer_layer_spec=vision_layer_spec,
            drop_vision_class_token=False,
            vision_projection_config=vision_projection_config,
            vision_projection_layer_spec=vision_projection_spec,
            img_h=336,
            img_w=336,
            patch_dim=14,
            # Encoder and pre-processing only on the first stage,
            # post-processing only on the last.
            pre_process=is_first_stage,
            post_process=is_last_stage,
            add_encoder=is_first_stage,
            add_decoder=True,
        )
        self.llava_model = LLaVAModel(**llava_kwargs)
        self.llava_model.bfloat16()

        self.vision_seq_length = 576
        self.micro_batch_size = 2

    def teardown_method(self, method):
        """Tear down model parallelism and force a garbage-collection pass."""
        Utils.destroy_model_parallel()
        gc.collect()

    def _make_helper(self, num_microbatches=4):
        """Return a helper wrapping this rank's LLaVA model slice."""
        return VisionTECudaGraphHelper(
            model=[self.llava_model],
            vision_config=self.vision_config,
            vision_seq_length=self.vision_seq_length,
            micro_batch_size=self.micro_batch_size,
            num_microbatches=num_microbatches,
        )

    def test_pp2_first_stage_finds_vision_layers(self):
        """Stage 0 should discover all vision encoder layers."""
        if not self.is_first_stage:
            pytest.skip("This assertion is only for pp_rank 0")

        helper = self._make_helper(num_microbatches=4)
        assert helper.vision_model is not None
        assert helper.num_layers == self.vision_num_layers
        assert len(helper.callables) == self.vision_num_layers

    def test_pp2_last_stage_has_no_vision_layers(self):
        """Stage 1 should find no vision model (encoder lives on stage 0)."""
        if self.is_first_stage:
            pytest.skip("This assertion is only for pp_rank 1")

        helper = self._make_helper(num_microbatches=4)
        assert helper.vision_model is None
        assert len(helper.callables) == 0
        assert not helper.graphs_created()

    def test_pp2_num_microbatches_preserved(self):
        """With PP>1, num_microbatches should NOT be collapsed to 1."""
        if not self.is_first_stage:
            pytest.skip("Vision layers only on pp_rank 0")

        num_mb = 8
        helper = self._make_helper(num_microbatches=num_mb)
        # _get_sample_arguments generates layers * microbatches entries
        sample_args, sample_kwargs_list = helper._get_sample_arguments(order=[1, -1])

        expected_count = self.vision_num_layers * num_mb
        assert len(sample_args) == expected_count, (
            f"With PP>1, expected {expected_count} sample_args "
            f"(layers={self.vision_num_layers} * mb={num_mb}), got {len(sample_args)}"
        )

    @pytest.mark.skipif(
        not (HAVE_TE_GRAPHS and is_te_min_version("2.7.0")),
        reason="TE CUDA graph capture requires TransformerEngine >= 2.7.0",
    )
    @pytest.mark.flaky
    @pytest.mark.flaky_in_dev
    def test_pp2_create_cudagraphs_first_stage(self):
        """On stage 0, CUDA graphs should be captured with the full pipeline order."""
        if not self.is_first_stage:
            pytest.skip("Vision layers only on pp_rank 0")

        self.llava_model.cuda()
        num_mb = 4
        helper = self._make_helper(num_microbatches=num_mb)
        assert not helper.graphs_created()

        helper.create_cudagraphs()
        assert helper.graphs_created()
        # num_microbatches should be preserved (PP>1 does not collapse)
        assert helper.num_microbatches == num_mb

        # Each layer should have one graph per microbatch
        for vision_layer in helper.callables:
            assert hasattr(vision_layer, 'cuda_graphs')
            assert (
                len(vision_layer.cuda_graphs) == num_mb
            ), f"Expected {num_mb} graphs per layer, got {len(vision_layer.cuda_graphs)}"

        # Cleanup
        helper.delete_cuda_graphs()
        assert not helper.graphs_created()
        for vision_layer in helper.callables:
            assert vision_layer.cuda_graphs == []

    @pytest.mark.skipif(
        not (HAVE_TE_GRAPHS and is_te_min_version("2.7.0")),
        reason="TE CUDA graph capture requires TransformerEngine >= 2.7.0",
    )
    def test_pp2_create_cudagraphs_last_stage_noop(self):
        """On stage 1 (no vision model), create_cudagraphs should be a no-op."""
        if self.is_first_stage:
            pytest.skip("This assertion is only for pp_rank 1")

        helper = self._make_helper(num_microbatches=4)
        helper.create_cudagraphs()
        assert not helper.graphs_created()
if __name__ == "__main__":
    # Script entry point: runs a hand-picked subset of the tests above without
    # going through pytest collection.
    if not _te_version_ok:
        print(f"SKIPPED: Vision CUDA graph tests require TransformerEngine >= {TE_MIN_VERSION}")
        # Raise SystemExit directly: the bare ``exit`` name is injected by the
        # ``site`` module for interactive use and is not guaranteed to exist.
        raise SystemExit(0)

    # Public handle on the exception raised by ``pytest.skip()``; avoids
    # importing from the private ``_pytest`` package.
    Skipped = pytest.skip.Exception

    # ``skipif`` marks are not evaluated when test methods are called directly,
    # so the graph-capture gate used by those marks is re-checked here.
    graph_capture_supported = HAVE_TE_GRAPHS and is_te_min_version("2.7.0")
    graph_capture_reason = "TE CUDA graph capture requires TransformerEngine >= 2.7.0"

    def run_test(test_obj, test_fn_name):
        """Run a test method, treating pytest.skip() as a non-error.

        ``setup_method`` runs before the test and ``teardown_method`` always
        runs afterwards, whether the test passed, skipped, or raised.
        """
        test_obj.setup_method(method=None)
        try:
            getattr(test_obj, test_fn_name)()
        except Skipped as e:
            print(f"  SKIPPED {test_fn_name}: {e}")
        finally:
            test_obj.teardown_method(method=None)

    def run_graph_capture_test(test_obj, test_fn_name):
        """Run a graph-capture test only when TE graph capture is available."""
        if graph_capture_supported:
            run_test(test_obj, test_fn_name)
        else:
            print(f"  SKIPPED {test_fn_name}: {graph_capture_reason}")

    # Quick smoke tests for pure functions
    t = TestWrapGraphForVision()
    t.test_filters_none_from_tuple()
    t.test_returns_non_tuple_unchanged()
    t.test_preserves_all_non_none()
    t.test_all_none_returns_original()
    t.test_preserves_te_attributes()
    t.test_missing_te_attributes_not_set()
    print("_wrap_graph_for_vision tests passed.")

    t2 = TestGetVisionCudaGraphSeqLength()
    t2.test_explicit_max_seq_length()
    t2.test_explicit_max_seq_length_zero_falls_through()
    t2.test_num_position_embeddings_only()
    t2.test_num_position_embeddings_with_spatial_merge()
    t2.test_spatial_merge_size_3()
    t2.test_default_seq_length()
    t2.test_custom_default()
    t2.test_explicit_overrides_position_embeddings()
    print("get_vision_cuda_graph_seq_length tests passed.")

    # Integration tests (require GPU + distributed init)
    t3 = TestVisionTECudaGraphHelper()
    run_test(t3, "test_init_finds_vision_layers")
    run_test(t3, "test_get_sample_arguments_shapes")
    run_graph_capture_test(t3, "test_create_and_delete_cudagraphs")
    print("TestVisionTECudaGraphHelper tests passed.")

    # PP=2 integration tests (require 2+ GPUs)
    if Utils.world_size >= 2:
        t4 = TestVisionTECudaGraphHelperPP2()
        run_test(t4, "test_pp2_first_stage_finds_vision_layers")
        run_test(t4, "test_pp2_last_stage_has_no_vision_layers")
        run_test(t4, "test_pp2_num_microbatches_preserved")
        run_graph_capture_test(t4, "test_pp2_create_cudagraphs_first_stage")
        run_graph_capture_test(t4, "test_pp2_create_cudagraphs_last_stage_noop")
        print("TestVisionTECudaGraphHelperPP2 tests passed.")
    else:
        print("SKIPPED TestVisionTECudaGraphHelperPP2 (requires 2+ GPUs)")

    print("All vision CUDA graph tests passed.")