
Migrate torch_xla.device() to torch.device('xla') #9253

Draft: wants to merge 3 commits into base: master
12 changes: 6 additions & 6 deletions API_GUIDE.md
@@ -22,7 +22,7 @@ print(t)

This code should look familiar. PyTorch/XLA uses the same interface as regular
PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and
`torch_xla.device()` returns the current XLA device. This may be a CPU or TPU
`torch.device('xla')` returns the current XLA device. This may be a CPU or TPU
depending on your environment.
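
For reference, a minimal sketch of the updated spelling this PR migrates to, assuming a recent PyTorch/XLA build where importing `torch_xla` registers the `xla` device type:

```python
import torch
import torch_xla  # registers the 'xla' device type

# Both spellings refer to the current XLA device (CPU, GPU, or TPU backed).
device = torch.device('xla')
t = torch.randn(2, 2, device=device)  # explicit torch.device object
u = torch.randn(2, 2, device='xla')   # device string shorthand
print(t + u)
```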

## XLA Tensors are PyTorch Tensors
@@ -47,7 +47,7 @@ Or used with neural network modules:

```python
l_in = torch.randn(10, device='xla')
linear = torch.nn.Linear(10, 20).to(torch_xla.device())
linear = torch.nn.Linear(10, 20).to('xla')
l_out = linear(l_in)
print(l_out)
```
@@ -112,7 +112,7 @@ train_loader = xu.SampleGenerator(
torch.zeros(batch_size, dtype=torch.int64)),
sample_count=60000 // batch_size // xr.world_size())

device = torch_xla.device() # Get the XLA device (TPU).
device = torch.device('xla') # Get the XLA device (TPU).
model = MNIST().train().to(device) # Create a model and move it to the device.
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
@@ -169,7 +169,7 @@ def _mp_fn(index):
index: Index of the process.
"""

device = torch_xla.device() # Get the device assigned to this process.
device = torch.device('xla') # Get the device assigned to this process.
# Wrap the loader for multi-device.
mp_device_loader = pl.MpDeviceLoader(train_loader, device)

@@ -197,7 +197,7 @@ single device snippet. Let's go over them one by one.
- `torch_xla.launch()`
- Creates the processes that each run an XLA device.
- This function is a wrapper around multiprocessing spawn that also lets the user run the script with the torchrun command line. Each process will only be able to access the device assigned to the current process. For example, on a TPU v4-8 there will be 4 processes spawned, and each process will own a TPU device.
- Note that if you print `torch_xla.device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exception is the PJRT runtime on TPU v2 and TPU v3, where there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
- Note that if you print `torch.device('xla')` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exception is the PJRT runtime on TPU v2 and TPU v3, where there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
- `MpDeviceLoader`
- Loads the training data onto each device.
- `MpDeviceLoader` can wrap a torch dataloader. It preloads the data to the device and overlaps the data loading with device execution to improve performance.
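
A minimal multi-process sketch of the two pieces described above, assuming a recent PyTorch/XLA build that provides `torch_xla.launch`; the dataset here is a throwaway placeholder:

```python
import torch
import torch_xla
import torch_xla.distributed.parallel_loader as pl

# Placeholder dataset/loader, just to keep the sketch self-contained.
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.randn(64, 10), torch.zeros(64, dtype=torch.int64)),
    batch_size=8)

def _mp_fn(index):
    # Each spawned process sees exactly one XLA device.
    device = torch.device('xla')
    # Preloads batches onto the device and overlaps loading with execution.
    mp_device_loader = pl.MpDeviceLoader(train_loader, device)
    for data, target in mp_device_loader:
        ...  # training step goes here

if __name__ == '__main__':
    torch_xla.launch(_mp_fn, args=())
```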
@@ -290,7 +290,7 @@ import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()
device = torch.device('xla')

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
6 changes: 3 additions & 3 deletions README.md
@@ -158,12 +158,12 @@ To update your existing training loop, make the following changes:
...

+ # Move the model parameters to your XLA device
+ model.to(torch_xla.device())
+ model.to('xla')

for inputs, labels in train_loader:
+ with torch_xla.step():
+ # Transfer data to the XLA device. This happens asynchronously.
+ inputs, labels = inputs.to(torch_xla.device()), labels.to(torch_xla.device())
+ inputs, labels = inputs.to('xla'), labels.to('xla')
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
@@ -196,7 +196,7 @@ If you're using `DistributedDataParallel`, make the following changes:
+ # Rank and world size are inferred from the XLA device runtime
+ dist.init_process_group("xla", init_method='xla://')
+
+ model.to(torch_xla.device())
+ model.to('xla')
+ ddp_model = DDP(model, gradient_as_bucket_view=True)

- model = model.to(rank)
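
Putting the README's DDP changes together, a minimal sketch; `ToyModel` is a placeholder class, and the `xla_backend` import mirrors the docs elsewhere in this PR:

```python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.distributed.xla_backend  # registers the 'xla' process group backend
from torch.nn.parallel import DistributedDataParallel as DDP

class ToyModel(torch.nn.Module):  # placeholder model
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.net(x)

def _mp_fn(index):
    # Rank and world size are inferred from the XLA device runtime.
    dist.init_process_group("xla", init_method='xla://')

    model = ToyModel().to('xla')
    ddp_model = DDP(model, gradient_as_bucket_view=True)
    out = ddp_model(torch.randn(8, 10, device='xla'))

if __name__ == '__main__':
    torch_xla.launch(_mp_fn, args=())
```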
2 changes: 1 addition & 1 deletion benchmarks/experiment_runner.py
@@ -255,7 +255,7 @@ def _default_iter_fn(self, benchmark_experiment: BenchmarkExperiment,

def _pure_wall_time_iter_fn(self, benchmark_experiment: BenchmarkExperiment,
benchmark_model: BenchmarkModel, input_tensor):
device = torch_xla.device() if benchmark_experiment.xla else 'cuda'
device = torch.device('xla') if benchmark_experiment.xla else 'cuda'
Collaborator: we may have a linting rule about single vs double quotes - whatever is consistent is fine.

sync_fn = xm.wait_device_ops if benchmark_experiment.xla else torch.cuda.synchronize
timing, output = bench.do_bench(
lambda: benchmark_model.model_iter_fn(
6 changes: 3 additions & 3 deletions contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb
@@ -193,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T19:30:28.607393Z",
@@ -210,7 +210,7 @@
"lock = mp.Manager().Lock()\n",
"\n",
"def print_device(i, lock):\n",
" device = torch_xla.device()\n",
" device = torch.device('xla')\n",
" with lock:\n",
" print('process', i, device)"
]
@@ -454,7 +454,7 @@
"import torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n",
"\n",
"def toy_model(index, lock):\n",
" device = torch_xla.device()\n",
" device = torch.device('xla')\n",
" dist.init_process_group('xla', init_method='xla://')\n",
"\n",
" # Initialize a basic toy model\n",
2 changes: 1 addition & 1 deletion contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb
@@ -172,7 +172,7 @@
"\n",
"pipeline = DiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\")\n",
"# Move the model to the first TPU core\n",
"pipeline = pipeline.to(torch_xla.device())"
"pipeline = pipeline.to('xla')"
]
},
{
2 changes: 1 addition & 1 deletion docs/source/learn/_pjrt.md
@@ -73,7 +73,7 @@ import torch_xla.distributed.xla_backend


def _mp_fn(index):
device = torch_xla.device()
device = torch.device('xla')
- dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
+ dist.init_process_group('xla', init_method='xla://')

4 changes: 2 additions & 2 deletions docs/source/learn/eager.md
@@ -13,7 +13,7 @@ import torch
import torch_xla
import torchvision

device = torch_xla.device()
device = torch.device('xla')
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

@@ -71,7 +71,7 @@ import torchvision
# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
device = torch.device('xla')
model = torchvision.models.resnet18().to(device)

# Mark the function to be compiled
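
A sketch of how the eager-mode snippet above is typically completed, assuming the `torch_xla.compile` wrapper described in the eager docs is available:

```python
import torch
import torch_xla
import torchvision

# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)

device = torch.device('xla')
model = torchvision.models.resnet18().to(device)

def forward(inputs):
    return model(inputs)

# Mark only the hot path for compilation; everything else stays eager.
compiled_forward = torch_xla.compile(forward)

out = compiled_forward(torch.randn(4, 3, 224, 224, device=device))
print(out.shape)
```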
12 changes: 6 additions & 6 deletions docs/source/learn/pytorch-on-xla-devices.md
@@ -21,7 +21,7 @@ print(t)

This code should look familiar. PyTorch/XLA uses the same interface as
regular PyTorch with a few additions. Importing `torch_xla` initializes
PyTorch/XLA, and `torch_xla.device()` returns the current XLA device. This
PyTorch/XLA, and `torch.device('xla')` returns the current XLA device. This
may be a CPU or TPU depending on your environment.

## XLA Tensors are PyTorch Tensors
@@ -47,7 +47,7 @@ Or used with neural network modules:

``` python
l_in = torch.randn(10, device='xla')
linear = torch.nn.Linear(10, 20).to(torch_xla.device())
linear = torch.nn.Linear(10, 20).to('xla')
l_out = linear(l_in)
print(l_out)
```
@@ -81,7 +81,7 @@ The following snippet shows a network training on a single XLA device:
``` python
import torch_xla.core.xla_model as xm

device = torch_xla.device()
device = torch.device('xla')
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
@@ -120,7 +120,7 @@ import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def _mp_fn(index):
device = torch_xla.device()
device = torch.device('xla')
mp_device_loader = pl.MpDeviceLoader(train_loader, device)

model = MNIST().train().to(device)
@@ -148,7 +148,7 @@ previous single device snippet. Let's go over them one by one.
will only be able to access the device assigned to the current
process. For example, on a TPU v4-8 there will be 4 processes
spawned, and each process will own a TPU device.
- Note that if you print the `torch_xla.device()` on each process you
- Note that if you print the `torch.device('xla')` on each process you
will see `xla:0` on all devices. This is because each process
can only see one device. This does not mean multi-process is not
functioning. The only exception is the PJRT runtime on TPU v2
@@ -283,7 +283,7 @@ import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()
device = torch.device('xla')

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
8 changes: 4 additions & 4 deletions docs/source/learn/xla-overview.md
@@ -184,7 +184,7 @@ repo. contains examples for training and serving many LLM and diffusion models.

General guidelines to modify your code:

- Replace `cuda` with `torch_xla.device()`
- Replace `cuda` with `torch.device('xla')`
- Remove progress bar, printing that would access the XLA tensor
values
- Reduce logging and callbacks that would access the XLA tensor values
@@ -227,7 +227,7 @@ tutorial, but you can pass the `device` value to the function as well.

``` python
import torch_xla.core.xla_model as xm
self.device = torch_xla.device()
self.device = torch.device('xla')
```

Another place in the code that has cuda specific code is DDIM scheduler.
@@ -244,7 +244,7 @@ if attr.device != torch.device("cuda"):
with

``` python
device = torch_xla.device()
device = torch.device('xla')
attr = attr.to(torch.device(device))
```

@@ -339,7 +339,7 @@ with the following lines:

``` python
import torch_xla.core.xla_model as xm
device = torch_xla.device()
device = torch.device('xla')
pipe.to(device)
```

14 changes: 7 additions & 7 deletions docs/source/perf/amp.md
@@ -19,15 +19,15 @@ from torch_xla.amp import syncfree
import torch_xla.core.xla_model as xm

# Creates model and optimizer in default precision
model = Net().to(torch_xla.device())
model = Net().to('xla')
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)

for input, target in data:
optimizer.zero_grad()

# Enables autocasting for the forward pass
with autocast(torch_xla.device()):
with autocast(torch.device('xla')):
output = model(input)
loss = loss_fn(output, target)

@@ -36,7 +36,7 @@ for input, target in data:
xm.optimizer_step(optimizer)
```

`autocast(torch_xla.device())` aliases `torch.autocast('xla')` when the XLA
`autocast(torch.device('xla'))` aliases `torch.autocast('xla')` when the XLA
Device is a TPU. Alternatively, if a script is only used with TPUs, then
`torch.autocast('xla', dtype=torch.bfloat16)` can be directly used.
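
A sketch of the TPU-only spelling mentioned above; the model, data, and loss below are throwaway placeholders just to keep the example self-contained:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.amp import syncfree

# Placeholder model and data for the sketch.
model = torch.nn.Linear(10, 10).to('xla')
optimizer = syncfree.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
data = [(torch.randn(4, 10, device='xla'), torch.randn(4, 10, device='xla'))
        for _ in range(3)]

for input, target in data:
    optimizer.zero_grad()
    # TPU-only scripts can spell the autocast context directly:
    with torch.autocast('xla', dtype=torch.bfloat16):
        output = model(input)
        loss = loss_fn(output, target)
    loss.backward()
    # Sync-free optimizer step provided by torch_xla
    xm.optimizer_step(optimizer)
```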

Expand Down Expand Up @@ -106,7 +106,7 @@ from torch_xla.amp import syncfree
import torch_xla.core.xla_model as xm

# Creates model and optimizer in default precision
model = Net().to(torch_xla.device())
model = Net().to('xla')
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)
scaler = GradScaler()
@@ -115,7 +115,7 @@ for input, target in data:
optimizer.zero_grad()

# Enables autocasting for the forward pass
with autocast(torch_xla.device()):
with autocast(torch.device('xla')):
output = model(input)
loss = loss_fn(output, target)

@@ -127,12 +127,12 @@ for input, target in data:
scaler.update()
```

`autocast(torch_xla.device())` aliases `torch.cuda.amp.autocast()` when the
`autocast(torch.device('xla'))` aliases `torch.cuda.amp.autocast()` when the
XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is
only used with CUDA devices, then `torch.cuda.amp.autocast` can be
directly used, but this requires `torch` to be compiled with `cuda` support for
the `torch.bfloat16` datatype. We recommend using
`autocast(torch_xla.device())` on XLA:GPU as it does not require
`autocast(torch.device('xla'))` on XLA:GPU as it does not require
`torch.cuda` support for any datatypes, including `torch.bfloat16`.

### AMP for XLA:GPU Best Practices
2 changes: 1 addition & 1 deletion docs/source/perf/ddp.md
@@ -105,7 +105,7 @@ def demo_basic(rank):
setup(rank, world_size)

# create model and move it to XLA device
device = torch_xla.device()
device = torch.device('xla')
model = ToyModel().to(device)
ddp_model = DDP(model, gradient_as_bucket_view=True)

8 changes: 4 additions & 4 deletions docs/source/perf/dynamo.md
@@ -23,8 +23,8 @@ import torch
import torch_xla.core.xla_model as xm

def add(a, b):
a_xla = a.to(torch_xla.device())
b_xla = b.to(torch_xla.device())
a_xla = a.to('xla')
b_xla = b.to('xla')
return a_xla + b_xla

compiled_code = torch.compile(add, backend='openxla')
@@ -41,7 +41,7 @@ import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
device = torch_xla.device()
device = torch.device('xla')
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()
dynamo_resnet18 = torch.compile(
@@ -129,7 +129,7 @@ def train_model(model, data, target, optimizer):
return pred

def train_model_main(loader):
device = torch_xla.device()
device = torch.device('xla')
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.train()
dynamo_train_model = torch.compile(
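
A short usage sketch of the `openxla` backend shown in the dynamo doc above; the input tensors here are arbitrary:

```python
import torch
import torch_xla

def add(a, b):
    a_xla = a.to('xla')
    b_xla = b.to('xla')
    return a_xla + b_xla

compiled_add = torch.compile(add, backend='openxla')

# Inputs can start on CPU; the function moves them to the XLA device.
result = compiled_add(torch.randn(4, 4), torch.randn(4, 4))
print(result.device)
```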
4 changes: 2 additions & 2 deletions docs/source/perf/fori_loop.md
@@ -30,7 +30,7 @@ result = while_loop(cond_fn, body_fn, init)
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = torch_xla.device()
>>> device = torch.device('xla')
>>>
>>> def cond_fn(iteri, x):
... return iteri > 0
@@ -60,7 +60,7 @@ with similar logic: cumulative plus 1 for ten times:
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = torch_xla.device()
>>> device = torch.device('xla')
>>>
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
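
A sketch of a matching `body_fn` for the snippet above, assuming the `(cond_fn, body_fn, carried_inputs)` signature of `torch._higher_order_ops.while_loop`:

```python
import torch
import torch_xla
from torch._higher_order_ops.while_loop import while_loop

device = torch.device('xla')

def cond_fn(iteri, x):
    return iteri > 0

def body_fn(iteri, x):
    # Decrement the loop counter and accumulate into x.
    return iteri - 1, x + 1

init_val = torch.tensor(1, device=device)
iteri = torch.tensor(50, device=device)

iteri_out, x_out = while_loop(cond_fn, body_fn, (iteri, init_val))
print(x_out)  # 1 plus 50 increments
```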
2 changes: 1 addition & 1 deletion docs/source/perf/quantized_ops.md
@@ -48,7 +48,7 @@ scaler = torch.randn((N_OUTPUT_FEATURES,), dtype=torch.bfloat16)
# Call with torch CPU tensor (For debugging purpose)
matmul_output = torch.ops.xla.quantized_matmul(x, w_int, scaler)

device = torch_xla.device()
device = torch.device('xla')
x_xla = x.to(device)
w_int_xla = w_int.to(device)
scaler_xla = scaler.to(device)
2 changes: 1 addition & 1 deletion docs/source/perf/spmd_basic.md
@@ -41,7 +41,7 @@ mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

t = torch.randn(8, 4).to(torch_xla.device())
t = torch.randn(8, 4).to('xla')

# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
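
For context, a sketch of how the sharding annotation is typically applied to `t` with the mesh above; `xr.use_spmd` and `xs.mark_sharding` mirror the SPMD basics doc:

```python
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()  # enable SPMD execution mode

num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

t = torch.randn(8, 4).to('xla')

# Shard dim 0 over the 'data' axis and dim 1 over the 'model' axis.
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)
```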
2 changes: 1 addition & 1 deletion examples/train_decoder_only_base.py
@@ -35,7 +35,7 @@ def __init__(self,
torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64)),
sample_count=self.train_dataset_len // self.batch_size)

self.device = torch_xla.device()
self.device = torch.device('xla')
self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
self.model = decoder_cls(self.config).to(self.device)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
2 changes: 1 addition & 1 deletion examples/train_resnet_amp.py
@@ -19,7 +19,7 @@ def train_loop_fn(self, loader, epoch):
for step, (data, target) in enumerate(loader):
self.optimizer.zero_grad()
# Enables autocasting for the forward pass
with autocast(torch_xla.device()):
with autocast(torch.device('xla')):
output = self.model(data)
loss = self.loss_fn(output, target)
# TPU AMP uses bf16, hence gradient scaling is not necessary. If running with XLA:GPU