Module 19: Add Apple Silicon (MPS) device support and fix ART device handling #30

Description

@aksharguptaudacity

Summary

Every device-selection site in Module 19 uses the two-way expression device = "cuda" if torch.cuda.is_available() else "cpu". On Apple Silicon (a meaningful slice of learners) this silently resolves to CPU even though MPS is available, so training and evaluation run far slower than necessary. Separately, the helper functions that bridge to ART move inference tensors to the caller-passed device rather than the model's live device, which breaks once an MPS path is enabled, because ART moves the model to CPU during an attack. Both fixes are covered by existing conventions (STYLE-005, STYLE-006).

Evidence

Confirmed by running the notebooks on an M-series Mac: torch.backends.mps.is_available() is True, but the committed code picks "cpu".

Two-way device selection (should be the three-way cascade):

  • demo/notebooks/robustness_evaluation_pipeline_demo.ipynb (setup cell)
  • demo/scripts/run_robustness_demo.py:41
  • demo/targets/logistics_cifar10_resnet18.py:45
  • exercise/solution/notebooks/traffic_sign_robustness_assessment.ipynb (setup cell)
  • exercise/starter/notebooks/traffic_sign_robustness_assessment.ipynb (setup cell)
  • exercise/solution/scripts/run_assessment.py:46
  • exercise/starter/scripts/run_assessment.py:46

ART interop (should query the model's live device, per STYLE-006):

  • demo/src/robustness_eval_utils.py::predict — uses model(batch.to(device)) with the caller-passed device.
  • exercise/solution/src/traffic_sign_robustness_utils.py::predict (≈line 220) — same.

Proposed change

  1. Replace every two-way selection with the STYLE-005 cascade (cuda → mps → cpu, setting PYTORCH_ENABLE_MPS_FALLBACK=1 on the MPS branch).
  2. In the ART-bridging predict helpers, derive the inference device from next(model.parameters()).device rather than the caller's device arg (STYLE-006), so the helpers stay correct after ART silently relocates the model to CPU.
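A minimal sketch of the three-way cascade from item 1 follows. The exact wording of STYLE-005 is not reproduced in this issue, so the function name select_device and the use of os.environ.setdefault are assumptions made for illustration, not the convention's canonical text.

```python
import os

import torch


def select_device() -> str:
    """Pick a backend in priority order: CUDA, then Apple Silicon MPS, then CPU.

    Hypothetical helper name; the real code may inline this logic instead.
    """
    if torch.cuda.is_available():
        return "cuda"

    # getattr guards against older PyTorch builds that lack torch.backends.mps.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # Ask PyTorch to fall back to CPU for ops that have no MPS kernel.
        # Caveat: some PyTorch versions appear to read this variable when
        # torch is first imported, so setting it here may come too late;
        # exporting it before the import is the safer option.
        os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
        return "mps"

    return "cpu"


device = select_device()
print(f"Using device: {device}")
```

Keeping the cascade in one helper would let the seven listed sites share a single definition, though whether the module prefers a shared helper or an inline snippet per site is a decision for the maintainers.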

Note: ART attacks themselves still run on CPU because ART has no MPS path. That is expected; the fix is about running training/eval on MPS and not crashing on a device mismatch.
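The live-device lookup described for the predict helpers might look like the sketch below. The signature and return value of the real predict functions are not shown in this issue, so the version here (no device argument, logits returned on CPU) is an assumption; only the next(model.parameters()).device lookup comes from the proposal.

```python
import torch
from torch import nn


def predict(model: nn.Module, batch: torch.Tensor) -> torch.Tensor:
    """Run inference on whatever device the model currently lives on.

    Hypothetical signature: the real helpers may take extra arguments.
    """
    # Query the live device rather than trusting a caller-supplied one, so
    # the call stays valid even if the model was relocated (e.g. to CPU).
    # Note: this raises StopIteration for a model with no parameters.
    live_device = next(model.parameters()).device

    with torch.no_grad():
        logits = model(batch.to(live_device))

    # Return on CPU so downstream NumPy-based code can consume the result.
    return logits.cpu()


# Usage: a tiny stand-in model; the input batch can start on any device.
model = nn.Linear(4, 3)
outputs = predict(model, torch.randn(2, 4))
print(outputs.shape)
```

With this shape, a caller that still holds a stale "mps" device string cannot cause a mismatch, because the helper never reads it.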

Acceptance criteria

  • All listed sites use the three-way cascade; no two-way cuda-or-cpu selection remains in the module.
  • Both notebooks run end-to-end on MPS without device-mismatch errors.
  • predict helpers use the model's live device.
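The first criterion could be checked mechanically. The script below is a rough sketch, not part of the module: the function names are hypothetical, and the regular expression only matches the exact two-way form quoted in the summary, so variants written differently would need their own pattern.

```python
import json
import re
from pathlib import Path

# Matches: "cuda" if torch.cuda.is_available() else "cpu" (either quote style).
TWO_WAY = re.compile(
    r"""["']cuda["']\s+if\s+torch\.cuda\.is_available\(\)\s+else\s+["']cpu["']"""
)


def source_text(path: Path) -> str:
    """Return the Python source of a .py file or the code cells of a notebook."""
    raw = path.read_text(encoding="utf-8")
    if path.suffix != ".ipynb":
        return raw
    cells = json.loads(raw).get("cells", [])
    # A cell's "source" may be a string or a list of strings; join handles both.
    return "\n".join(
        "".join(cell.get("source", []))
        for cell in cells
        if cell.get("cell_type") == "code"
    )


def find_two_way(root: Path) -> list[str]:
    """List 'path:line' for each two-way selection under root.

    For notebooks the line number counts within the concatenated code cells.
    """
    hits = []
    for path in sorted(root.rglob("*")):
        if not path.is_file() or path.suffix not in {".py", ".ipynb"}:
            continue
        for number, line in enumerate(source_text(path).splitlines(), start=1):
            if TWO_WAY.search(line):
                hits.append(f"{path}:{number}")
    return hits


if __name__ == "__main__":
    for hit in find_two_way(Path(".")):
        print(hit)
```

Run from the module root, an empty output would suggest no remaining two-way selections of this exact form.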
