Skip to content

Add support for other pytorch device types, including MPS#1445

Open
adamobeng wants to merge 1 commit intojunyanz:masterfrom
adamobeng:mps
Open

Add support for other pytorch device types, including MPS#1445
adamobeng wants to merge 1 commit intojunyanz:masterfrom
adamobeng:mps

Conversation

@adamobeng
Copy link

Fixes (#1441)

Change list

  1. Add command line arguments --device_type and --device_ids which allow torch backend and device ordinals to be specified
  2. Make code specific to GPUs/cuda device-agnostic (in particular by using a list of torch devices rather than GPU ids)
  3. Maintain support for --gpu_ids argument with some special logic (it would be cleaner but non-backwards compatible to remove it)
  4. Add some tests of the argument parsing

Testing

  • Unit tests pass
  • Results generated with python train.py --dataroot ./datasets/maps --name maps --model pix2pix --direction AtoB --device_type mps seem reasonable.
  • Suggestions on more rigorous testing are welcomed!

NB: On my specific setup, loading a model trained with MPS fails with RuntimeError: don't know how to restore data location of torch.storage._UntypedStorage (tagged with mps:0), but it seems like this is a known and intermittent issue.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant