Skip to content

Commit 3bffed5

Browse files
Neuron integration (#3935)
* feat: add `is_neuron_available` utility function * feat: add `DistributedType.MULTI_NEURON` * feat: adapt state classes for Neuron * feat: add support for Neuron in Accelerator * feat: adapted local sgd to fail with Neuron cores * feat: add proper constant for Neuron * feat: adapt dataclasses for Neuron * feat: adapt environment for Neuron * feat: adapt mixed precision utils for Neuron * feat: integrate randomness features for Neuron * feat: save and load Neuron RNG state * feat: add clear cache method for Neuron * feat: add support for Neuron device with launch * feat: add support for Neuron devices in commands * feat: adapt big modeling for Neuron * feat: adapt test utils for Neuron * feat: add "neuron" handling in set_device
1 parent ae36c74 commit 3bffed5

File tree

22 files changed

+136
-3
lines changed

22 files changed

+136
-3
lines changed

src/accelerate/accelerator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,7 @@ def multi_device(self):
664664
DistributedType.MULTI_NPU,
665665
DistributedType.MULTI_XPU,
666666
DistributedType.MULTI_HPU,
667+
DistributedType.MULTI_NEURON,
667668
)
668669

669670
@property

src/accelerate/big_modeling.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
is_bnb_available,
4343
is_mlu_available,
4444
is_musa_available,
45+
is_neuron_available,
4546
is_npu_available,
4647
is_sdaa_available,
4748
is_xpu_available,
@@ -477,6 +478,8 @@ def wrapper(*args, **kwargs):
477478
model.musa = add_warning(model.musa, model)
478479
elif is_xpu_available():
479480
model.xpu = add_warning(model.xpu, model)
481+
elif is_neuron_available():
482+
model.neuron = add_warning(model.neuron, model)
480483
else:
481484
model.cuda = add_warning(model.cuda, model)
482485

@@ -499,6 +502,8 @@ def wrapper(*args, **kwargs):
499502
device = f"sdaa:{device}"
500503
elif is_musa_available() and isinstance(device, int):
501504
device = f"musa:{device}"
505+
elif is_neuron_available() and isinstance(device, int):
506+
device = f"neuron:{device}"
502507
if device != "disk":
503508
model.to(device)
504509
else:

src/accelerate/checkpointing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
is_hpu_available,
3636
is_mlu_available,
3737
is_musa_available,
38+
is_neuron_available,
3839
is_sdaa_available,
3940
is_torch_version,
4041
is_torch_xla_available,
@@ -167,6 +168,8 @@ def save_accelerator_state(
167168
states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
168169
if is_hpu_available():
169170
states["torch_hpu_manual_seed"] = torch.hpu.get_rng_state_all()
171+
if is_neuron_available():
172+
states["torch_neuron_manual_seed"] = torch.neuron.get_rng_state_all()
170173
if is_cuda_available():
171174
states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
172175
if is_torch_xla_available():
@@ -302,6 +305,8 @@ def load_accelerator_state(
302305
torch.musa.set_rng_state_all(states["torch_musa_manual_seed"])
303306
elif is_hpu_available():
304307
torch.hpu.set_rng_state_all(states["torch_hpu_manual_seed"])
308+
elif is_neuron_available():
309+
torch.neuron.set_rng_state_all(states["torch_neuron_manual_seed"])
305310
else:
306311
torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
307312
if is_torch_xla_available():

src/accelerate/commands/config/cluster.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
is_mps_available,
2727
is_msamp_available,
2828
is_musa_available,
29+
is_neuron_available,
2930
is_npu_available,
3031
is_sdaa_available,
3132
is_torchao_available,
@@ -68,6 +69,7 @@ def get_cluster_input():
6869
"multi-MLU",
6970
"multi-SDAA",
7071
"multi-MUSA",
72+
"multi-NEURON",
7173
"TPU",
7274
],
7375
_convert_distributed_mode,
@@ -92,6 +94,7 @@ def get_cluster_input():
9294
DistributedType.MULTI_XPU,
9395
DistributedType.MULTI_CPU,
9496
DistributedType.MULTI_HPU,
97+
DistributedType.MULTI_NEURON,
9598
]:
9699
num_machines = _ask_field(
97100
"How many different machines will you use (use more than 1 for multi-node training)? [1]: ",
@@ -218,6 +221,7 @@ def get_cluster_input():
218221
DistributedType.MULTI_MLU,
219222
DistributedType.MULTI_SDAA,
220223
DistributedType.MULTI_MUSA,
224+
DistributedType.MULTI_NEURON,
221225
DistributedType.NO,
222226
]
223227
and not use_mps
@@ -229,6 +233,9 @@ def get_cluster_input():
229233
error_message="Please enter yes or no.",
230234
)
231235
if use_deepspeed:
236+
if distributed_type is DistributedType.MULTI_NEURON:
237+
raise RuntimeError("DeepSpeed is not supported on Neuron devices.")
238+
232239
distributed_type = DistributedType.DEEPSPEED
233240
assert is_deepspeed_available(), (
234241
"DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source"
@@ -376,6 +383,7 @@ def get_cluster_input():
376383
DistributedType.MULTI_MUSA,
377384
DistributedType.MULTI_XPU,
378385
DistributedType.MULTI_HPU,
386+
DistributedType.MULTI_NEURON,
379387
]:
380388
use_fsdp = _ask_field(
381389
"Do you want to use FullyShardedDataParallel? [yes/NO]: ",
@@ -384,7 +392,10 @@ def get_cluster_input():
384392
error_message="Please enter yes or no.",
385393
)
386394
if use_fsdp:
395+
if distributed_type is DistributedType.MULTI_NEURON:
396+
raise NotImplementedError("FSDP is not currently supported on Neuron devices.")
387397
distributed_type = DistributedType.FSDP
398+
388399
if distributed_type == DistributedType.FSDP:
389400
fsdp_config["fsdp_version"] = _ask_options(
390401
"What should be your FSDP version? [2]: ",
@@ -624,10 +635,11 @@ def get_cluster_input():
624635
DistributedType.MULTI_SDAA,
625636
DistributedType.MULTI_MUSA,
626637
DistributedType.MULTI_NPU,
638+
DistributedType.MULTI_NEURON,
627639
DistributedType.XLA,
628640
]:
629641
machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "")
630-
if machine_type == "TPU":
642+
if machine_type in ["TPU", "NEURON"]:
631643
machine_type += " cores"
632644
elif machine_type == "CPU":
633645
machine_type = "processes"
@@ -664,6 +676,7 @@ def get_cluster_input():
664676
DistributedType.MULTI_NPU,
665677
DistributedType.MULTI_XPU,
666678
DistributedType.MULTI_HPU,
679+
DistributedType.MULTI_NEURON,
667680
DistributedType.NO,
668681
]
669682
and not use_cpu
@@ -681,6 +694,8 @@ def get_cluster_input():
681694
machine_type = "XPU(s)"
682695
elif is_hpu_available():
683696
machine_type = "HPU(s)"
697+
elif is_neuron_available():
698+
machine_type = "Neuron cores"
684699
else:
685700
machine_type = "GPU(s)"
686701
gpu_ids = _ask_field(

src/accelerate/commands/config/config_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def _convert_distributed_mode(value):
8181
"MULTI_MLU",
8282
"MULTI_SDAA",
8383
"MULTI_MUSA",
84+
"MULTI_NEURON",
8485
"XLA",
8586
][value]
8687
)

src/accelerate/commands/config/default.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
is_hpu_available,
2323
is_mlu_available,
2424
is_musa_available,
25+
is_neuron_available,
2526
is_npu_available,
2627
is_sdaa_available,
2728
is_xpu_available,
@@ -118,6 +119,14 @@ def write_basic_config(mixed_precision="no", save_location: str = default_json_c
118119
config["distributed_type"] = "MULTI_NPU"
119120
else:
120121
config["distributed_type"] = "NO"
122+
elif is_neuron_available():
123+
num_neuron_cores = torch.neuron.device_count()
124+
config["num_processes"] = num_neuron_cores
125+
config["use_cpu"] = False
126+
if num_neuron_cores > 1:
127+
config["distributed_type"] = "MULTI_NEURON"
128+
else:
129+
config["distributed_type"] = "NO"
121130
else:
122131
num_xpus = 0
123132
config["use_cpu"] = True

src/accelerate/commands/env.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,14 @@
2626
from accelerate import __version__ as version
2727
from accelerate.commands.config import default_config_file, load_config_from_file
2828

29-
from ..utils import is_mlu_available, is_musa_available, is_npu_available, is_sdaa_available, is_xpu_available
29+
from ..utils import (
30+
is_mlu_available,
31+
is_musa_available,
32+
is_neuron_available,
33+
is_npu_available,
34+
is_sdaa_available,
35+
is_xpu_available,
36+
)
3037

3138

3239
def env_command_parser(subparsers=None):
@@ -52,6 +59,7 @@ def env_command(args):
5259
pt_sdaa_available = is_sdaa_available()
5360
pt_musa_available = is_musa_available()
5461
pt_npu_available = is_npu_available()
62+
pt_neuron_available = is_neuron_available()
5563

5664
accelerator = "N/A"
5765
if pt_cuda_available:
@@ -66,6 +74,8 @@ def env_command(args):
6674
accelerator = "MUSA"
6775
elif pt_npu_available:
6876
accelerator = "NPU"
77+
elif pt_neuron_available:
78+
accelerator = "NEURON"
6979

7080
accelerate_config = "Not found"
7181
# Get the default from the config file.
@@ -101,6 +111,8 @@ def env_command(args):
101111
info["SDAA type"] = torch.sdaa.get_device_name()
102112
elif pt_musa_available:
103113
info["MUSA type"] = torch.musa.get_device_name()
114+
elif pt_neuron_available:
115+
info["NEURON type"] = torch.neuron.get_device_name()
104116
elif pt_npu_available:
105117
info["CANN version"] = torch.version.cann
106118

src/accelerate/commands/launch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
is_hpu_available,
4343
is_mlu_available,
4444
is_musa_available,
45+
is_neuron_available,
4546
is_npu_available,
4647
is_rich_available,
4748
is_sagemaker_available,
@@ -1231,6 +1232,7 @@ def _validate_launch_command(args):
12311232
DistributedType.MULTI_MUSA,
12321233
DistributedType.MULTI_XPU,
12331234
DistributedType.MULTI_HPU,
1235+
DistributedType.MULTI_NEURON,
12341236
)
12351237
else False
12361238
)
@@ -1309,6 +1311,8 @@ def _validate_launch_command(args):
13091311
args.num_processes = torch.npu.device_count()
13101312
elif is_hpu_available():
13111313
args.num_processes = torch.hpu.device_count()
1314+
elif is_neuron_available():
1315+
args.num_processes = torch.neuron.device_count()
13121316
else:
13131317
args.num_processes = torch.cuda.device_count()
13141318
warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`")
@@ -1324,6 +1328,7 @@ def _validate_launch_command(args):
13241328
or (is_mlu_available() and torch.mlu.device_count() > 1)
13251329
or (is_sdaa_available() and torch.sdaa.device_count() > 1)
13261330
or (is_musa_available() and torch.musa.device_count() > 1)
1331+
or (is_neuron_available() and torch.neuron.device_count() > 1)
13271332
or (torch.cuda.is_available() and torch.cuda.device_count() > 1)
13281333
)
13291334
):

src/accelerate/local_sgd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_s
7575
DistributedType.MULTI_SDAA,
7676
DistributedType.MULTI_MUSA,
7777
DistributedType.MULTI_NPU,
78+
DistributedType.MULTI_NEURON,
7879
]:
7980
raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)")
8081
self.enabled = enabled and accelerator.distributed_type != DistributedType.NO

src/accelerate/state.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
is_mlu_available,
4343
is_mps_available,
4444
is_musa_available,
45+
is_neuron_available,
4546
is_npu_available,
4647
is_sdaa_available,
4748
is_torch_xla_available,
@@ -404,6 +405,7 @@ def wait_for_everyone(self):
404405
DistributedType.MULTI_XPU,
405406
DistributedType.MULTI_CPU,
406407
DistributedType.MULTI_HPU,
408+
DistributedType.MULTI_NEURON,
407409
DistributedType.DEEPSPEED,
408410
DistributedType.FSDP,
409411
):
@@ -726,6 +728,7 @@ def default_device(self) -> torch.device:
726728
- MUSA if `is_musa_available()`
727729
- NPU if `is_npu_available()`
728730
- HPU if `is_hpu_available()`
731+
- NEURON if `is_neuron_available()`
729732
- CPU otherwise
730733
"""
731734
if is_mps_available():
@@ -747,6 +750,8 @@ def default_device(self) -> torch.device:
747750
return torch.device("cuda")
748751
elif is_xpu_available():
749752
return torch.device("xpu")
753+
elif is_neuron_available():
754+
return torch.device("neuron")
750755
else:
751756
return torch.device("cpu")
752757

@@ -791,6 +796,9 @@ def _prepare_backend(
791796
if backend is None:
792797
backend = "xccl"
793798
distributed_type = DistributedType.MULTI_XPU
799+
elif is_neuron_available():
800+
backend = "neuron"
801+
distributed_type = DistributedType.MULTI_NEURON
794802

795803
if (
796804
distributed_type is None
@@ -821,7 +829,7 @@ def set_device(self):
821829
self.device = torch.device("cpu") if self._cpu else self.default_device
822830
return
823831
device = str(self.distributed_type).split(".")[-1].replace("MULTI_", "").lower()
824-
if device not in ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla", "hpu", "sdaa"):
832+
if device not in ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla", "hpu", "sdaa", "neuron"):
825833
raise ValueError(
826834
f"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!"
827835
)
@@ -984,6 +992,7 @@ def __init__(
984992
DistributedType.MULTI_NPU,
985993
DistributedType.MULTI_XPU,
986994
DistributedType.MULTI_HPU,
995+
DistributedType.MULTI_NEURON,
987996
]:
988997
# TODO: Siro - remove when axolotl fixes their side
989998
if not os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true":

0 commit comments

Comments
 (0)