Skip to content

Commit 1e2583c

Browse files
authored
Fix failing distributed sampler test (#3453)
- Fix failing distributed sampler test - Updated pytorch docker image in gpu GHA workflows
1 parent d54f710 commit 1e2583c

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

.github/workflows/gpu-hvd-tests.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
pytorch-channel: [pytorch]
2525
fail-fast: false
2626
env:
27-
DOCKER_IMAGE: "pytorch/conda-builder:cuda12.1"
27+
DOCKER_IMAGE: "pytorch/almalinux-builder:cuda12.8"
2828
REPOSITORY: ${{ github.repository }}
2929
PR_NUMBER: ${{ github.event.pull_request.number }}
3030
runs-on: linux.8xlarge.nvidia.gpu
@@ -113,6 +113,10 @@ jobs:
113113
pip install -r requirements-dev.txt
114114
pip install -e .
115115
116+
# Upgrade pyOpenSSL to avoid issue:
117+
# AttributeError: module 'lib' has no attribute 'X509_V_FLAG_NOTIFY_POLICY'. Did you mean: 'X509_V_FLAG_EXPLICIT_POLICY'?
118+
pip install -U pyOpenSSL
119+
116120
EOF
117121
)
118122

.github/workflows/gpu-tests.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
pytorch-channel: [pytorch, pytorch-nightly]
2525
fail-fast: false
2626
env:
27-
DOCKER_IMAGE: "pytorch/almalinux-builder:cuda12.4"
27+
DOCKER_IMAGE: "pytorch/almalinux-builder:cuda12.8"
2828
REPOSITORY: ${{ github.repository }}
2929
PR_NUMBER: ${{ github.event.pull_request.number }}
3030
runs-on: linux.g4dn.12xlarge.nvidia.gpu
@@ -113,6 +113,10 @@ jobs:
113113
pip install -r requirements-dev.txt
114114
pip install -e .
115115
116+
# Upgrade pyOpenSSL to avoid issue:
117+
# AttributeError: module 'lib' has no attribute 'X509_V_FLAG_NOTIFY_POLICY'. Did you mean: 'X509_V_FLAG_EXPLICIT_POLICY'?
118+
pip install -U pyOpenSSL
119+
116120
EOF
117121
)
118122

tests/ignite/distributed/test_auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def test_dist_proxy_sampler():
311311
DistributedProxySampler(None)
312312

313313
with pytest.raises(TypeError, match=r"Argument sampler should have length"):
314-
DistributedProxySampler(Sampler([1]))
314+
DistributedProxySampler(Sampler())
315315

316316
with pytest.raises(TypeError, match=r"Argument sampler must not be a distributed sampler already"):
317317
DistributedProxySampler(DistributedSampler(sampler, num_replicas=num_replicas, rank=0))

0 commit comments

Comments
 (0)