@@ -1801,7 +1801,11 @@ def pytest_generate_tests(metafunc):
18011801
18021802 # Parametrize framework specific tests
18031803 for fixture in FRAMEWORK_FIXTURES :
1804+ if fixture in ["gpu" ]:
1805+ LOGGER .info (f"Checking gpu fixture 1: { fixture } " )
18041806 if fixture in metafunc .fixturenames :
1807+ if fixture in ["gpu" ]:
1808+ LOGGER .info (f"Checking gpu fixture 2: { fixture } " )
18051809 lookup = fixture .replace ("___" , ":" ).replace ("__" , "." ).replace ("_" , "-" )
18061810 images_to_parametrize = []
18071811 for image in images :
@@ -1840,13 +1844,8 @@ def pytest_generate_tests(metafunc):
18401844 continue
18411845 if not framework_version_within_limit (metafunc , image ):
18421846 continue
1843- if "below_cuda129_only" in metafunc .fixturenames :
1844- cuda_version = get_cuda_version_from_tag (image )
1845- is_below = is_below_cuda_version ("12.9" , image )
1846- LOGGER .info (f"CUDA version check for image { image } :" )
1847- LOGGER .info (f"- CUDA version: { cuda_version } " )
1848- LOGGER .info (f"- Is below CUDA 12.9: { is_below } " )
1849- LOGGER .info (f"- Final CUDA check result: { cuda_version_within_limit (metafunc , image )} " )
1847+ if fixture in ["gpu" ]:
1848+ LOGGER .info (f"Checking gpu fixture 3: { fixture } " )
18501849 if not cuda_version_within_limit (metafunc , image ):
18511850 continue
18521851 if "non_huggingface_only" in metafunc .fixturenames and "huggingface" in image :
@@ -1862,6 +1861,8 @@ def pytest_generate_tests(metafunc):
18621861 "graviton" in image or "arm64" in image
18631862 ):
18641863 continue
1864+ if fixture in ["gpu" ]:
1865+ LOGGER .info (f"Checking gpu fixture 4: { fixture } " )
18651866 if "training_compiler_only" in metafunc .fixturenames and not (
18661867 "trcomp" in image
18671868 ):
@@ -1872,6 +1873,8 @@ def pytest_generate_tests(metafunc):
18721873 or is_standard_lookup
18731874 or is_trcomp_lookup
18741875 ):
1876+ if fixture in ["gpu" ]:
1877+ LOGGER .info (f"Checking gpu fixture 5: { fixture } " )
18751878 if (
18761879 "cpu_only" in metafunc .fixturenames
18771880 and "cpu" in image
@@ -1880,6 +1883,8 @@ def pytest_generate_tests(metafunc):
18801883 images_to_parametrize .append (image )
18811884 elif "gpu_only" in metafunc .fixturenames and "gpu" in image :
18821885 images_to_parametrize .append (image )
1886+ if fixture in ["gpu" ]:
1887+ LOGGER .info (f"Checking gpu fixture 6: { fixture } " )
18831888 elif (
18841889 "graviton_compatible_only" in metafunc .fixturenames
18851890 and "graviton" in image
@@ -1895,6 +1900,9 @@ def pytest_generate_tests(metafunc):
18951900 ):
18961901 images_to_parametrize .append (image )
18971902
1903+
1904+ if fixture in ["gpu" ]:
1905+ LOGGER .info (f"Checking gpu fixture 7: { fixture } { images_to_parametrize } " )
18981906 # Remove all images tagged as "py2" if py3_only is a fixture
18991907 if images_to_parametrize and "py3_only" in metafunc .fixturenames :
19001908 images_to_parametrize = [
@@ -1920,6 +1928,9 @@ def pytest_generate_tests(metafunc):
19201928 nightly_images_to_parametrize .append (image_candidate )
19211929 images_to_parametrize = nightly_images_to_parametrize
19221930
1931+ if fixture in ["gpu" ]:
1932+ LOGGER .info (f"Checking gpu fixture 8: { fixture } { images_to_parametrize } " )
1933+
19231934 # Parametrize tests that spin up an ecs cluster or tests that spin up an EC2 instance with a unique name
19241935 values_to_generate_for_fixture = {
19251936 "ecs_container_instance" : "ecs_cluster_name" ,
@@ -1930,6 +1941,8 @@ def pytest_generate_tests(metafunc):
19301941 fixtures_parametrized = generate_unique_values_for_fixtures (
19311942 metafunc , images_to_parametrize , values_to_generate_for_fixture
19321943 )
1944+ if fixture in ["gpu" ]:
1945+ LOGGER .info (f"Checking gpu fixture 9: { fixture } { fixtures_parametrized } " )
19331946 if fixtures_parametrized :
19341947 for new_fixture_name , test_parametrization in fixtures_parametrized .items ():
19351948 metafunc .parametrize (f"{ fixture } ,{ new_fixture_name } " , test_parametrization )
0 commit comments