Skip to content

Commit 0341149

Browse files
unnecessary dataclasses
1 parent c88d8ce commit 0341149

File tree

10 files changed

+36
-70
lines changed

10 files changed

+36
-70
lines changed

examples/stable-diffusion/training/train_dreambooth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import torch.nn.functional as F
4040
import torch.utils.checkpoint
4141
import transformers
42+
from accelerate import DistributedType
4243
from accelerate.logging import get_logger
4344
from accelerate.utils import DistributedDataParallelKwargs
4445
from diffusers import (
@@ -61,7 +62,6 @@
6162

6263
from optimum.habana import GaudiConfig
6364
from optimum.habana.accelerate import GaudiAccelerator
64-
from optimum.habana.accelerate.utils.dataclasses import GaudiDistributedType
6565
from optimum.habana.diffusers import GaudiStableDiffusionPipeline
6666
from optimum.habana.transformers.trainer import _is_peft_model
6767
from optimum.habana.utils import set_seed
@@ -1088,7 +1088,7 @@ def unwrap_model(model, training=False):
10881088
if not training:
10891089
return model
10901090
else:
1091-
if accelerator.distributed_type == GaudiDistributedType.MULTI_HPU:
1091+
if accelerator.distributed_type == DistributedType.MULTI_HPU:
10921092
kwargs = {}
10931093
kwargs["gradient_as_bucket_view"] = True
10941094
accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

examples/stable-diffusion/training/train_dreambooth_lora_flux.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import torch
3535
import torch.utils.checkpoint
3636
import transformers
37+
from accelerate import DistributedType
3738
from accelerate.logging import get_logger
3839
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
3940
from datasets import load_dataset
@@ -69,7 +70,6 @@
6970

7071
from optimum.habana import GaudiConfig
7172
from optimum.habana.accelerate import GaudiAccelerator
72-
from optimum.habana.accelerate.utils.dataclasses import GaudiDistributedType
7373
from optimum.habana.utils import set_seed
7474

7575

@@ -762,7 +762,7 @@ def save_model_hook(models, weights, output_dir):
762762
def load_model_hook(models, input_dir):
763763
transformer_ = None
764764

765-
if not accelerator.distributed_type == GaudiDistributedType.DEEPSPEED:
765+
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
766766
while len(models) > 0:
767767
model = models.pop()
768768

@@ -1075,7 +1075,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
10751075
progress_bar.update(1)
10761076
global_step += 1
10771077

1078-
if accelerator.is_main_process or accelerator.distributed_type == GaudiDistributedType.DEEPSPEED:
1078+
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
10791079
if global_step % args.checkpointing_steps == 0:
10801080
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
10811081
if args.checkpoints_total_limit is not None:

examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import torch.nn.functional as F
3636
import torch.utils.checkpoint
3737
import transformers
38+
from accelerate import DistributedType
3839
from accelerate.logging import get_logger
3940
from accelerate.utils import DistributedDataParallelKwargs
4041
from diffusers import (
@@ -68,7 +69,6 @@
6869

6970
from optimum.habana import GaudiConfig
7071
from optimum.habana.accelerate import GaudiAccelerator
71-
from optimum.habana.accelerate.utils.dataclasses import GaudiDistributedType
7272
from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline
7373
from optimum.habana.transformers.trainer import _is_peft_model
7474
from optimum.habana.utils import set_seed
@@ -1019,7 +1019,7 @@ def unwrap_model(model, training=False):
10191019
if not training:
10201020
return model
10211021
else:
1022-
if accelerator.distributed_type == GaudiDistributedType.MULTI_HPU:
1022+
if accelerator.distributed_type == DistributedType.MULTI_HPU:
10231023
kwargs = {}
10241024
kwargs["gradient_as_bucket_view"] = True
10251025
accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

examples/stable-diffusion/training/train_text_to_image_sdxl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import torch.nn.functional as F
4242
import torch.utils.checkpoint
4343
import transformers
44+
from accelerate import DistributedType
4445
from accelerate.logging import get_logger
4546
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
4647
from datasets import load_dataset
@@ -62,7 +63,6 @@
6263

6364
from optimum.habana import GaudiConfig
6465
from optimum.habana.accelerate import GaudiAccelerator
65-
from optimum.habana.accelerate.utils.dataclasses import GaudiDistributedType
6666
from optimum.habana.diffusers import (
6767
GaudiDDIMScheduler,
6868
GaudiEulerAncestralDiscreteScheduler,
@@ -896,7 +896,7 @@ def main(args):
896896
for idx, dt in enumerate(dataset["train"]):
897897
dt["image"].save(f"{args.mediapipe}/{idx}.jpg")
898898
f.write(dt["text"] + "\n")
899-
if accelerator.distributed_type != GaudiDistributedType.NO:
899+
if accelerator.distributed_type != DistributedType.NO:
900900
torch.distributed.barrier()
901901

902902
from media_pipe_imgdir import get_dataset_for_pipeline
@@ -1145,7 +1145,7 @@ def unwrap_model(model, training=False):
11451145
if not training:
11461146
return model
11471147
else:
1148-
if accelerator.distributed_type == GaudiDistributedType.MULTI_HPU:
1148+
if accelerator.distributed_type == DistributedType.MULTI_HPU:
11491149
kwargs = {}
11501150
kwargs["gradient_as_bucket_view"] = True
11511151
accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

optimum/habana/accelerate/accelerator.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161

6262
from ..distributed import parallel_state
6363
from .state import GaudiPartialState
64-
from .utils import GaudiDistributedType, GaudiDynamoBackend, convert_model
64+
from .utils import convert_model
6565

6666

6767
logger = get_logger(__name__)
@@ -162,7 +162,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
162162
"""
163163
if device_placement is None:
164164
device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP
165-
if not evaluation_mode and self.distributed_type == GaudiDistributedType.MULTI_HPU:
165+
if not evaluation_mode and self.distributed_type == DistributedType.MULTI_HPU:
166166
device_placement = None
167167
self._models.append(model)
168168

@@ -223,13 +223,15 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
223223
elif device_placement and not self.verify_device_map(model):
224224
model = model.to(self.device)
225225
if not evaluation_mode:
226-
if self.distributed_type == GaudiDistributedType.MULTI_HPU and self._distribution_strategy != "fast_ddp":
226+
###############################################################################################################
227+
if self.distributed_type == DistributedType.MULTI_HPU and self._distribution_strategy != "fast_ddp":
227228
if any(p.requires_grad for p in model.parameters()):
228229
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
229230
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
230231
if self.ddp_handler is not None:
231232
self.ddp_handler.register_comm_hook(model)
232-
elif self.distributed_type == GaudiDistributedType.FSDP:
233+
###############################################################################################################
234+
elif self.distributed_type == DistributedType.FSDP:
233235
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
234236

235237
# Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
@@ -353,7 +355,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
353355
del self._models[-2]
354356
self._models[-1] = model
355357
# torch.compile should be called last and only if the model isn't already compiled.
356-
if self.state.dynamo_plugin.backend != GaudiDynamoBackend.NO and not is_compiled_module(model):
358+
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
357359
compile_kwargs = self.state.dynamo_plugin.to_kwargs()
358360
############################################################################################################
359361
if self.use_regional_compilation:
@@ -567,7 +569,7 @@ def _prepare_deepspeed(self, *args):
567569
os.environ["DEEPSPEED_USE_HPU"] = "true"
568570
engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
569571
# torch.compile should be called if dynamo plugin backend is set and only if the model isn't already compiled.
570-
if self.state.dynamo_plugin.backend != GaudiDynamoBackend.NO and not is_compiled_module(model):
572+
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
571573
compile_kwargs = self.state.dynamo_plugin.to_kwargs()
572574
###############################################################################################################
573575
if self.use_regional_compilation:

optimum/habana/accelerate/state.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,20 @@
1717

1818
import accelerate
1919
import torch
20+
from accelerate import DistributedType
2021
from accelerate.state import PartialState
2122
from accelerate.utils import is_deepspeed_available, parse_flag_from_env
2223

2324
from optimum.utils import logging
2425

2526
from ..distributed import parallel_state
26-
from .utils import GaudiDistributedType
2727

2828

2929
logger = logging.get_logger()
3030

3131

32-
# TODO: Remove when minimize_memory is supported in upstream accelerate and sequence parallelism is managed in GaudiTrainer
32+
# TODO: Remove when minimize_memory is supported in upstream accelerate
33+
# and sequence/context parallelism is managed in GaudiTrainer or supported in upstream accelerate
3334
class GaudiPartialState(PartialState):
3435
"""
3536
Adapted from: https://github.com/huggingface/accelerate/blob/8514c35192ac9762920f1ab052e5cea4c0e46eeb/src/accelerate/state.py#L96
@@ -61,7 +62,7 @@ def __init__(self, cpu: bool = False, **kwargs):
6162
"DeepSpeed is not available, install it with: `pip install"
6263
" git+https://github.com/HabanaAI/DeepSpeed.git@1.20.0`."
6364
)
64-
self.distributed_type = GaudiDistributedType.DEEPSPEED
65+
self.distributed_type = DistributedType.DEEPSPEED
6566
import deepspeed
6667

6768
if world_size > 1:
@@ -74,12 +75,12 @@ def __init__(self, cpu: bool = False, **kwargs):
7475
logger.info("DeepSpeed is enabled.")
7576
self._mixed_precision = "no" # deepspeed handles mixed_precision using deepspeed_config
7677
elif os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
77-
self.distributed_type = GaudiDistributedType.FSDP
78+
self.distributed_type = DistributedType.FSDP
7879
if not torch.distributed.is_initialized():
7980
torch.distributed.init_process_group(backend=self.backend, rank=rank, world_size=world_size)
8081
logger.info("Enabled distributed run.")
8182
else:
82-
self.distributed_type = GaudiDistributedType.MULTI_HPU
83+
self.distributed_type = DistributedType.MULTI_HPU
8384
if not torch.distributed.is_initialized():
8485
torch.distributed.init_process_group(backend=self.backend, rank=rank, world_size=world_size)
8586
logger.info("Enabled distributed run.")
@@ -104,9 +105,9 @@ def __init__(self, cpu: bool = False, **kwargs):
104105
logger.info("FP8 amax reduction group is already initialized.")
105106
else:
106107
self.distributed_type = (
107-
GaudiDistributedType.NO
108+
DistributedType.NO
108109
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "false"
109-
else GaudiDistributedType.DEEPSPEED
110+
else DistributedType.DEEPSPEED
110111
)
111112
self.num_processes = 1
112113
self.process_index = self.local_process_index = 0

optimum/habana/accelerate/utils/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
from .dataclasses import (
2-
GaudiDistributedType,
3-
GaudiDynamoBackend,
4-
GaudiFP8RecipeKwargs,
5-
GaudiFullyShardedDataParallelPlugin,
6-
GaudiTorchDynamoPlugin,
7-
)
81
from .transformer_engine import (
92
FP8ContextWrapper,
103
convert_model,

optimum/habana/accelerate/utils/dataclasses.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

optimum/habana/transformers/trainer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import huggingface_hub.utils as hf_hub_utils
3636
import numpy as np
3737
import torch
38-
from accelerate import skip_first_batches
38+
from accelerate import DistributedType, skip_first_batches
3939
from accelerate.data_loader import SeedableRandomSampler
4040
from accelerate.utils import (
4141
DistributedDataParallelKwargs,
@@ -47,8 +47,6 @@
4747
)
4848
from huggingface_hub import upload_folder
4949
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler
50-
51-
from optimum.utils import logging
5250
from transformers import Trainer
5351
from transformers.data.data_collator import DataCollator
5452
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
@@ -106,8 +104,10 @@
106104
is_safetensors_available,
107105
)
108106

107+
from optimum.utils import logging
108+
109109
from ..accelerate import GaudiAccelerator
110-
from ..accelerate.utils import FP8ContextWrapper, GaudiDistributedType
110+
from ..accelerate.utils import FP8ContextWrapper
111111
from ..utils import (
112112
HabanaProfile,
113113
get_hpu_memory_stats,
@@ -903,7 +903,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
903903
self._globalstep_last_logged = self.state.global_step
904904
self._zero_model_grad(model)
905905
_grad_norm: Optional[float] = None
906-
_should_compute_grad_norm: bool = not self.accelerator.distributed_type == GaudiDistributedType.DEEPSPEED and (
906+
_should_compute_grad_norm: bool = not self.accelerator.distributed_type == DistributedType.DEEPSPEED and (
907907
# Gradient clipping
908908
args.max_grad_norm is not None and args.max_grad_norm > 0
909909
)
@@ -1280,15 +1280,15 @@ def _maybe_log_save_evaluate(self, tr_loss, _grad_norm, model, trial, epoch, ign
12801280

12811281
# This grad_norm block was outside of _maybe_log_save_evaluate method causing perf degradation.
12821282
# Moving it here so the grad tensor is only copied when it's needed.
1283-
if self.accelerator.distributed_type == GaudiDistributedType.DEEPSPEED:
1283+
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
12841284
grad_norm = model.get_global_grad_norm()
12851285
# In some cases the grad norm may not return a float
12861286
if hasattr(grad_norm, "item"):
12871287
grad_norm = grad_norm.item()
12881288
else:
12891289
if (
12901290
_grad_norm is not None
1291-
and self.accelerator.distributed_type != GaudiDistributedType.FSDP
1291+
and self.accelerator.distributed_type != DistributedType.FSDP
12921292
and _grad_norm.size() == torch.Size([1])
12931293
):
12941294
grad_norm = _grad_norm.detach().item()

optimum/habana/transformers/training_args.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pathlib import Path
2323
from typing import Optional, Union
2424

25+
from accelerate import DistributedType
2526
from accelerate.state import AcceleratorState
2627
from packaging import version
2728
from transformers.debug_utils import DebugOption
@@ -47,7 +48,6 @@
4748
from optimum.utils import logging
4849

4950
from ..accelerate.state import GaudiPartialState
50-
from ..accelerate.utils import GaudiDistributedType
5151
from ..utils import get_habana_frameworks_version
5252
from .gaudi_configuration import GaudiConfig
5353

@@ -922,7 +922,7 @@ def _setup_devices(self) -> "torch.device":
922922
)
923923
# We rely on `PartialState` to yell if there's issues here (which it will)
924924
self.distributed_state = GaudiPartialState(cpu=self.use_cpu)
925-
if self.deepspeed and self.distributed_state.distributed_type != GaudiDistributedType.DEEPSPEED:
925+
if self.deepspeed and self.distributed_state.distributed_type != DistributedType.DEEPSPEED:
926926
raise RuntimeError(
927927
"Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, "
928928
"but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set "
@@ -999,7 +999,7 @@ def _setup_devices(self) -> "torch.device":
999999
"In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
10001000
)
10011001

1002-
if self.distributed_state.distributed_type == GaudiDistributedType.NO:
1002+
if self.distributed_state.distributed_type == DistributedType.NO:
10031003
self._n_gpu = 0
10041004

10051005
return device

0 commit comments

Comments (0)