Merged
2 changes: 1 addition & 1 deletion benchmarks/fp8/torchao/ddp.py
@@ -85,7 +85,7 @@ def train_baseline():
model.train()

for batch in train_dataloader:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
batch = batch.to(device)
outputs = model(**batch)
loss = outputs.loss
2 changes: 1 addition & 1 deletion benchmarks/fp8/torchao/fsdp.py
@@ -95,7 +95,7 @@ def train_baseline():
model.train()

for batch in train_dataloader:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
batch = batch.to(device)
outputs = model(**batch)
loss = outputs.loss
6 changes: 4 additions & 2 deletions benchmarks/fp8/torchao/non_distributed.py
@@ -71,13 +71,15 @@ def train_baseline():
last_linear = name

func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
model.to("cuda")
accelerator = Accelerator()
device = accelerator.device
model.to(device)
convert_to_float8_training(model, module_filter_fn=func)
base_model_results = evaluate_model(model, eval_dataloader, METRIC)
model.train()

for batch in train_dataloader:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
outputs = model(**batch)
loss = outputs.loss
loss.backward()
16 changes: 8 additions & 8 deletions docs/source/basic_tutorials/notebook.md
@@ -32,13 +32,13 @@ Before any training can be performed, an Accelerate config file must exist in th
accelerate config
```

However, if general defaults are fine and you are *not* running on a TPU, Accelerate has a utility to quickly write your GPU configuration into a config file via [`utils.write_basic_config`].
However, if general defaults are fine and you are *not* running on a TPU, Accelerate has a utility to quickly write your device configuration into a config file via [`utils.write_basic_config`].

The following code will restart Jupyter after writing the configuration, as CUDA code was called to perform this.
The following code will restart Jupyter after writing the configuration, as the CUDA or XPU runtime was called to perform this.

<Tip warning={true}>

CUDA can't be initialized more than once on a multi-GPU system. It's fine to debug in the notebook and have calls to CUDA, but in order to finally train a full cleanup and restart will need to be performed.
CUDA and XPU can't be initialized more than once on a multi-device system. It's fine to debug in the notebook and have calls to CUDA/XPU, but in order to finally train, a full cleanup and restart will need to be performed.

</Tip>

@@ -462,15 +462,15 @@ accelerate launch

## Debugging

A common issue when running the `notebook_launcher` is receiving a CUDA has already been initialized issue. This usually stems
from an import or prior code in the notebook that makes a call to the PyTorch `torch.cuda` sublibrary. To help narrow down what went wrong,
A common issue when running the `notebook_launcher` is receiving a "CUDA/XPU has already been initialized" error. This usually stems
from an import or prior code in the notebook that makes a call to PyTorch's `torch.cuda` or `torch.xpu` submodule. To help narrow down what went wrong,
you can launch the `notebook_launcher` with `ACCELERATE_DEBUG_MODE=yes` in your environment and an additional check
will be made when spawning that a regular process can be created and utilize CUDA without issue. (Your CUDA code can still be ran afterwards).
will be made when spawning that a regular process can be created and utilize CUDA/XPU without issue. (Your CUDA/XPU code can still be run afterwards).
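A minimal sketch of enabling this check from a notebook cell (the environment variable name comes from the paragraph above; the launcher call is shown commented, since `training_function` is whatever you defined earlier):

```python
import os

# Set this before calling notebook_launcher so the spawn-time
# device check described above is performed.
os.environ["ACCELERATE_DEBUG_MODE"] = "yes"

# from accelerate import notebook_launcher
# notebook_launcher(training_function, num_processes=2)
```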

## Conclusion

This notebook showed how to perform distributed training from inside of a Jupyter Notebook. Some key notes to remember:

- Make sure to save any code that use CUDA (or CUDA imports) for the function passed to [`notebook_launcher`]
- Set the `num_processes` to be the number of devices used for training (such as number of GPUs, CPUs, TPUs, etc)
- Make sure to save any code that uses CUDA/XPU (or CUDA/XPU imports) for the function passed to [`notebook_launcher`]
- Set the `num_processes` to be the number of devices used for training (such as number of GPUs, XPUs, CPUs, TPUs, etc)
- If using the TPU, declare your model outside the training loop function
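The checklist above can be sketched with a stand-in for [`notebook_launcher`] (the `launch` helper below is a hypothetical stub for illustration, not Accelerate's API — the real launcher spawns separate processes rather than looping in-process):

```python
def training_function():
    # All CUDA/XPU-touching code (and imports) belongs inside this
    # function, so each spawned process initializes its own device.
    return "trained"

def launch(fn, num_processes=2):
    # Hypothetical stand-in for notebook_launcher, showing only the
    # calling pattern: pass the function itself, not its result.
    return [fn() for _ in range(num_processes)]

results = launch(training_function, num_processes=2)
```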
2 changes: 1 addition & 1 deletion docs/source/usage_guides/model_size_estimator.md
@@ -15,7 +15,7 @@ rendered properly in your Markdown viewer.

# Model memory estimator

One very difficult aspect when exploring potential models to use on your machine is knowing just how big of a model will *fit* into memory with your current graphics card (such as loading the model onto CUDA).
One very difficult aspect when exploring potential models to use on your machine is knowing just how big a model will *fit* into memory with your current device (such as loading the model onto CUDA or XPU).

To help alleviate this, Accelerate has a CLI interface through `accelerate estimate-memory`. This tutorial will
help walk you through using it, what to expect, and at the end link to the interactive demo hosted on the Hub which will
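As a rough mental model (an approximation for intuition, not what `accelerate estimate-memory` actually computes), a dense model's weight footprint is its parameter count times the bytes per element of its dtype:

```python
BYTES_PER_DTYPE = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5}

def rough_weight_memory(num_parameters, dtype="float32"):
    # Weights only; gradients, optimizer states, and activations
    # add substantially more on top of this during training.
    return num_parameters * BYTES_PER_DTYPE[dtype]

rough_weight_memory(7_000_000_000, "float16")  # 14e9 bytes, roughly 13 GiB
```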
4 changes: 2 additions & 2 deletions examples/by_feature/automatic_gradient_accumulation.py
@@ -133,7 +133,7 @@ def training_function(config, args):

# New Code #
# We use the `find_executable_batch_size` decorator, passing in the desired observed batch size
# to train on. If a CUDA OOM error occurs, it will retry this loop cutting the batch size in
# to train on. If a device OOM error occurs, it will retry this loop cutting the batch size in
# half each time. From this, we can calculate the number of gradient accumulation steps needed
# and modify the Accelerator object as a result
@find_executable_batch_size(starting_batch_size=int(observed_batch_size))
@@ -234,7 +234,7 @@ def main():
parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
args = parser.parse_args()
# New Code #
# We modify the starting batch size to be an observed batch size of 256, to guarentee an initial CUDA OOM
# We modify the starting batch size to be an observed batch size of 256, to guarantee an initial device OOM
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 256}
training_function(config, args)

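The retry-and-halve behavior described in the comment above can be sketched in plain Python (a simplified model of what `find_executable_batch_size` does, not Accelerate's implementation, which is a decorator and inspects device-specific OOM errors):

```python
def run_with_executable_batch_size(train_step, starting_batch_size=256):
    batch_size = starting_batch_size
    while batch_size > 0:
        try:
            return train_step(batch_size)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            # OOM: halve the batch size and retry, mirroring the decorator
            batch_size //= 2
    raise RuntimeError("No executable batch size found.")

# Toy train_step that "OOMs" above batch size 64:
def train_step(batch_size):
    if batch_size > 64:
        raise RuntimeError("CUDA out of memory.")
    return batch_size

run_with_executable_batch_size(train_step, 256)  # settles at 64
```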
6 changes: 5 additions & 1 deletion examples/complete_cv_example.py
@@ -24,6 +24,7 @@
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor

from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.utils import is_xpu_available


########################################################################
@@ -125,7 +126,10 @@ def training_function(config, args):
# Set the seed before splitting the data.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
elif is_xpu_available():
torch.xpu.manual_seed_all(seed)

# Split our filenames between train and validation
random_perm = np.random.permutation(len(file_names))
15 changes: 9 additions & 6 deletions examples/inference/pippy/t5.py
@@ -19,9 +19,12 @@

from accelerate import PartialState, prepare_pippy
from accelerate import __version__ as accelerate_version
from accelerate.test_utils import torch_device
from accelerate.utils import set_seed


synchronize_func = getattr(torch, torch_device, torch.cuda).synchronize

if version.parse(accelerate_version) > version.parse("0.33.0"):
raise RuntimeError(
"Using encoder/decoder models is not supported with the `torch.pipelining` integration or accelerate>=0.34.0. "
@@ -70,25 +73,25 @@

# The model expects a tuple during real inference
# with the data on the first device
args = (example_inputs["input_ids"].to("cuda:0"), example_inputs["decoder_input_ids"].to("cuda:0"))
args = (example_inputs["input_ids"].to(0), example_inputs["decoder_input_ids"].to(0))

# Take an average of 5 times
# Measure first batch
torch.cuda.synchronize()
synchronize_func()
start_time = time.time()
with torch.no_grad():
output = model(*args)
torch.cuda.synchronize()
synchronize_func()
end_time = time.time()
first_batch = end_time - start_time

# Now that CUDA is init, measure after
torch.cuda.synchronize()
# Now that the device is initialized, measure after
synchronize_func()
start_time = time.time()
for i in range(5):
with torch.no_grad():
output = model(*args)
torch.cuda.synchronize()
synchronize_func()
end_time = time.time()

# The outputs are only on the final process by default
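The `getattr`-with-fallback dispatch used for `synchronize_func` above can be illustrated without `torch` (the namespaces below are stand-ins for `torch.cuda`/`torch.xpu`, not real torch modules):

```python
from types import SimpleNamespace

# Stand-ins for torch's per-device namespaces
fake_torch = SimpleNamespace(
    cuda=SimpleNamespace(synchronize=lambda: "cuda-sync"),
    xpu=SimpleNamespace(synchronize=lambda: "xpu-sync"),
)

def get_synchronize(torch_mod, device_type):
    # Fall back to the cuda namespace when the requested device
    # namespace does not exist, as the benchmark above does.
    return getattr(torch_mod, device_type, torch_mod.cuda).synchronize

get_synchronize(fake_torch, "xpu")()  # -> "xpu-sync"
get_synchronize(fake_torch, "tpu")()  # -> "cuda-sync" (fallback)
```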
2 changes: 1 addition & 1 deletion src/accelerate/accelerator.py
@@ -1836,7 +1836,7 @@ def prepare_model(
if self.multi_device and not (self.parallelism_config and self.parallelism_config.tp_enabled):
if model_has_dtensor(model):
raise ValueError(
"Your model contains `DTensor` parameters, which is incompatible with DDP. Maybe you loaded your model with `device_map='auto'`? Specify `device_map='cuda'` or 'cpu' instead."
"Your model contains `DTensor` parameters, which is incompatible with DDP. Maybe you loaded your model with `device_map='auto'`? Specify `device_map='cuda'` or 'xpu' or 'cpu' instead."
)
if any(p.requires_grad for p in model.parameters()):
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
13 changes: 7 additions & 6 deletions src/accelerate/big_modeling.py
@@ -110,8 +110,9 @@ def init_on_device(device: torch.device, include_buffers: Optional[bool] = None)
import torch.nn as nn
from accelerate import init_on_device

# init model on the specified device (e.g. "cuda", "xpu", and so on)
with init_on_device(device=torch.device("cuda")):
tst = nn.Linear(100, 100) # on `cuda` device
tst = nn.Linear(100, 100) # on specified device
```
"""
if include_buffers is None:
@@ -231,17 +232,17 @@ def cpu_offload_with_hook(
The model to offload.
execution_device(`str`, `int` or `torch.device`, *optional*):
The device on which the model should be executed. Will default to the MPS device if it's available, then
GPU 0 if there is a GPU, and finally to the CPU.
device 0 if there is an accelerator device, and finally to the CPU.
prev_module_hook (`UserCpuOffloadHook`, *optional*):
The hook sent back by this function for a previous model in the pipeline you are running. If passed, its
offload method will be called just before the forward of the model to which this hook is attached.

Example:

```py
model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device)
model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, prev_module_hook=hook_1)
model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, prev_module_hook=hook_2)
model_1, hook_1 = cpu_offload_with_hook(model_1, device)
model_2, hook_2 = cpu_offload_with_hook(model_2, device, prev_module_hook=hook_1)
model_3, hook_3 = cpu_offload_with_hook(model_3, device, prev_module_hook=hook_2)

hid_1 = model_1(input)
for i in range(50):
@@ -446,7 +447,7 @@ def dispatch_model(
# Attaching the hook may break tied weights, so we retie them
retie_parameters(model, tied_params)

# add warning to cuda and to method
# add warning on `to` method
def add_warning(fn, model):
@wraps(fn)
def wrapper(*args, **kwargs):
4 changes: 3 additions & 1 deletion src/accelerate/commands/estimate.py
@@ -188,7 +188,9 @@ def estimate_command_parser(subparsers=None):
if subparsers is not None:
parser = subparsers.add_parser("estimate-memory")
else:
parser = CustomArgumentParser(description="Model size estimator for fitting a model onto CUDA memory.")
parser = CustomArgumentParser(
description="Model size estimator for fitting a model onto device memory (e.g. CUDA, XPU)."
)

parser.add_argument("model_name", type=str, help="The model name on the Hugging Face Hub.")
parser.add_argument(
5 changes: 4 additions & 1 deletion src/accelerate/utils/ao.py
@@ -124,9 +124,12 @@ def convert_model_to_fp8_ao(

```python
from accelerate.utils.ao import convert_model_to_fp8_ao
from accelerate import Accelerator

accelerator = Accelerator()

model = MyModel()
model.to("cuda")
model.to(accelerator.device)
convert_to_float8_training(model)

model.train()
4 changes: 2 additions & 2 deletions src/accelerate/utils/bnb.py
@@ -316,7 +316,7 @@ def _replace_with_bnb_layers(

Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
# bitsandbytes will initialize the device (e.g. CUDA, XPU) on import, so it needs to be imported lazily
import bitsandbytes as bnb

has_been_replaced = False
@@ -425,7 +425,7 @@ def get_keys_to_not_convert(model):

def has_4bit_bnb_layers(model):
"""Check if we have `bnb.nn.Linear4bit` or `bnb.nn.Linear8bitLt` layers inside our model"""
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
# bitsandbytes will initialize the device (e.g. CUDA, XPU) on import, so it needs to be imported lazily
import bitsandbytes as bnb

for m in model.modules():
22 changes: 11 additions & 11 deletions src/accelerate/utils/modeling.py
@@ -81,8 +81,8 @@ def is_peft_model(model):

def check_device_same(first_device, second_device):
"""
Utility method to check if two `torch` devices are similar. When dealing with CUDA devices, torch throws `False`
for `torch.device("cuda") == torch.device("cuda:0")` whereas they should be the same
Utility method to check if two `torch` devices are similar. When dealing with torch accelerator devices (e.g. cuda, xpu),
torch throws `False` for `torch.device("cuda") == torch.device("cuda:0")` whereas they should be the same

Args:
first_device (`torch.device`):
@@ -94,12 +94,12 @@ def check_device_same(first_device, second_device):
return False

if first_device.type != "cpu" and first_device.index is None:
# In case the first_device is a cuda device and have
# In case the first_device is a torch accelerator device (e.g. cuda, xpu) and has
# the index attribute set to `None`, default it to `0`
first_device = torch.device(first_device.type, index=0)

if second_device.type != "cpu" and second_device.index is None:
# In case the second_device is a cuda device and have
# In case the second_device is a torch accelerator device (e.g. cuda, xpu) and has
# the index attribute set to `None`, default it to `0`
second_device = torch.device(second_device.type, index=0)
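The normalization performed in `check_device_same` can be mimicked on plain device strings (a sketch for intuition; real code should use `torch.device` objects as above):

```python
def normalize(device: str) -> str:
    # "cuda" -> "cuda:0", "xpu" -> "xpu:0"; cpu and already-indexed
    # devices pass through unchanged.
    if device != "cpu" and ":" not in device:
        return f"{device}:0"
    return device

def devices_same(first: str, second: str) -> bool:
    return normalize(first) == normalize(second)

devices_same("cuda", "cuda:0")    # True
devices_same("cuda:0", "cuda:1")  # False
```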

@@ -307,7 +307,7 @@ def set_module_tensor_to_device(

device_quantization = None
with torch.no_grad():
# leave it on cpu first before moving them to cuda
# leave it on cpu first before moving them to the device
# # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0
if (
param is not None
@@ -385,23 +385,23 @@ def set_module_tensor_to_device(
and str(module.weight.device) != "meta"
):
# quantize only if necessary
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
device_index = torch.device(device).index if torch.device(device).type in ["cuda", "xpu"] else None
if not getattr(module.weight, "SCB", None) and device_index is not None:
if module.bias is not None and module.bias.device.type != "meta":
# if a bias exists, we need to wait until the bias is set on the correct device
module = module.cuda(device_index)
module = module.to(device_index)
elif module.bias is None:
# if no bias exists, we can quantize right away
module = module.cuda(device_index)
module = module.to(device_index)
elif (
module.__class__.__name__ == "Linear4bit"
and getattr(module.weight, "quant_state", None) is None
and str(module.weight.device) != "meta"
):
# quantize only if necessary
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
device_index = torch.device(device).index if torch.device(device).type in ["cuda", "xpu"] else None
if not getattr(module.weight, "quant_state", None) and device_index is not None:
module.weight = module.weight.cuda(device_index)
module.weight = module.weight.to(device_index)

# clean pre and post forward hook
if clear_cache and device not in ("cpu", "meta"):
@@ -749,7 +749,7 @@ def get_max_memory(max_memory: Optional[dict[Union[int, str], Union[int, str]]]

if max_memory is None:
max_memory = {}
# Make sure CUDA is initialized on each GPU to have the right memory info.
# Make sure each device is initialized to have the right memory info.
if is_npu_available():
for i in range(torch.npu.device_count()):
try:
11 changes: 7 additions & 4 deletions tests/test_accelerator.py
@@ -244,21 +244,24 @@ def test_free_memory_dereferences_prepared_components(self):
assert len(accelerator._schedulers) == 0
assert len(accelerator._dataloaders) == 0

# The less-than comes *specifically* from CUDA CPU things/won't be present on CPU builds
# The less-than comes *specifically* from device-related CPU allocations and won't be present on CPU-only builds
assert free_cpu_ram_after <= free_cpu_ram_before

@require_non_torch_xla
def test_env_var_device(self):
"""Tests that setting the torch device with ACCELERATE_TORCH_DEVICE overrides default device."""
PartialState._reset_state()

# Mock torch.cuda.set_device to avoid an exception as the device doesn't exist
# Mock torch's set_device call to avoid an exception as the device doesn't exist
def noop(*args, **kwargs):
pass

with patch("torch.cuda.set_device", noop), patch_environment(ACCELERATE_TORCH_DEVICE="cuda:64"):
with (
patch(f"torch.{torch_device}.set_device", noop),
patch_environment(ACCELERATE_TORCH_DEVICE=f"{torch_device}:64"),
):
accelerator = Accelerator()
assert str(accelerator.state.device) == "cuda:64"
assert str(accelerator.state.device) == f"{torch_device}:64"

@parameterized.expand([(True, True), (True, False), (False, False)], name_func=parameterized_custom_name_func)
def test_save_load_model(self, use_safetensors, tied_weights):
2 changes: 1 addition & 1 deletion tests/test_big_modeling.py
@@ -411,7 +411,7 @@ def test_dispatch_model_tied_weights_memory(self):
"linear4": device_0,
}

# Just to initialize CUDA context.
# Just to initialize device context.
a = torch.rand(5).to(device_0) # noqa: F841

free_memory_bytes = torch_accelerator_module.mem_get_info(device_0)[0]
2 changes: 1 addition & 1 deletion tests/test_memory_utils.py
@@ -26,7 +26,7 @@


def raise_fake_out_of_memory():
raise RuntimeError("CUDA out of memory.")
raise RuntimeError(f"{torch_device.upper()} out of memory.")


class ModelForTest(nn.Module):