
Error using Lightning 2.0 when using DeepSpeed and torch.compile together #17549


Bug description

Hi, I'm trying to fine-tune a pretrained model from Hugging Face using PyTorch 2.0 and Lightning.

If I don't use torch.compile, everything works fine. But when I compile the model with torch.compile, I get the error below (the full traceback is repeated under Error messages and logs):

torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

While executing %setitem : [#users=0] = call_function[target=operator.setitem](args = (%add, %eq, 1), kwargs = {})
Original traceback:
File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/utilities/compute.py", line 52, in _safe_divide
denom[denom == 0.0] = 1
| File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/functional/classification/accuracy.py", line 71, in _accuracy_reduce
return _safe_divide(tp, tp + fn)
| File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/classification/accuracy.py", line 205, in compute
return _accuracy_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)
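
For context, the failing torchmetrics line (`denom[denom == 0.0] = 1` in `_safe_divide`) is an in-place boolean-mask assignment, which appears to be the data-dependent operation that the fake-tensor tracer rejects. A shape-preserving rewrite avoids the pattern entirely; the sketch below is only an illustration of the idea, not something torchmetrics actually ships:

```python
import torch

def _safe_divide_where(num: torch.Tensor, denom: torch.Tensor) -> torch.Tensor:
    """Hypothetical compile-friendly variant of a safe division.

    torch.where is shape-preserving, so nothing data dependent needs to be
    materialized while tracing with fake tensors.
    """
    denom = torch.where(denom == 0.0, torch.ones_like(denom), denom)
    return num / denom
```
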

Here is my training code:

import os
import pandas as pd
import torch
from models.pl_model_hf import PL_model
import lightning as L
from dataset_hf import *
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
import argparse
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, Wav2Vec2Processor
from omegaconf import OmegaConf as OC
from utils import str2bool

def define_argparser():
    p = argparse.ArgumentParser()
    p.add_argument("-t", '--train_config', default='./configs/train.yaml', type=str)
    p.add_argument('--exp_name', type=str, required=True)
    p.add_argument('--save_path', type=str, required=True)
    p.add_argument('--using_model', required=True, type=str)
    p.add_argument('--using_contra', required=True, type=str2bool, nargs='?', const=True, default=False)
    p.add_argument('--using_cma', required=True, type=str2bool, nargs='?', const=True, default=False,)
    p.add_argument('--batch_size', type=int, default=64)
    p.add_argument('--accumulate_grad', type=int, default=1)
    config = p.parse_args()

    return config


def main(args):
    L.seed_everything(1004)
    num_gpu = torch.cuda.device_count()
    train_config = OC.load(args.train_config)

    train_config['path']['log_path'] = os.path.join(args.save_path, "log")
    train_config['path']['ckpt_path'] = os.path.join(args.save_path, "ckpt")
    train_config['path']['exp_name'] = args.exp_name
    train_config['optimizer']['batch_size'] = args.batch_size
    train_config['trainer']['grad_acc'] = args.accumulate_grad
    train_config['model']['using_cma'] = args.using_cma
    train_config['model']['using_model'] = args.using_model
    train_config['model']['using_contra'] = args.using_contra
    # Load train and validation data
    train = pd.read_csv(train_config['path']['train_csv'])
    dev = pd.read_csv(train_config['path']['dev_csv'])
    
    text_tokenizer = AutoTokenizer.from_pretrained(train_config['model']['text_encoder'])
    audio_processor = Wav2Vec2Processor.from_pretrained(train_config['model']['audio_processor'])
    
    train_dataset = multimodal_dataset(train)
    val_dataset = multimodal_dataset(dev)

    print(
        '|train| =', len(train_dataset),
        '|valid| =', len(val_dataset),
    )

    total_batch_size = train_config['optimizer']['batch_size'] * torch.cuda.device_count()
    n_total_iterations = int(len(train_dataset) / (total_batch_size * train_config['trainer']['grad_acc']) * train_config['step']['max_epochs'])
    n_warmup_steps = int(n_total_iterations * train_config['step']['warmup_ratio'])
    
    train_config['step']['total_step'] = n_total_iterations
    train_config['step']['warm_up_step'] = n_warmup_steps
    
    print(
        '#total_iters =', n_total_iterations,
        '#warmup_iters =', n_warmup_steps,
    )
    
    train_loader = DataLoader(
        train_dataset, train_config['optimizer']['batch_size'], num_workers=4,
        collate_fn=multimodal_collator(text_tokenizer, audio_processor), pin_memory=True,
        shuffle=True, drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset, train_config['optimizer']['batch_size'], num_workers=4,
        collate_fn=multimodal_collator(text_tokenizer, audio_processor), pin_memory=True, 
        drop_last=True, shuffle=False
    )
        
        
    # Build the LightningModule and compile the whole module with torch.compile.
    model = torch.compile(PL_model(train_config), mode='default')

    checkpoint_callback = ModelCheckpoint(
        monitor="val/emotion_loss",
        dirpath=os.path.join(train_config['path']['ckpt_path'], train_config['path']['exp_name']),
        filename="step={step:06d}-val_emotion_loss={val/emotion_loss:.5f}",
        save_top_k=1,
        mode="min",
        auto_insert_metric_name=False,
        every_n_train_steps=train_config['step']['total_step'] // 10 
    )
    logger = TensorBoardLogger(
        train_config['path']['log_path'], name=train_config['path']['exp_name'])
    lr_monitor = LearningRateMonitor(logging_interval='step')
    
    trainer = L.Trainer(
        devices=num_gpu,
        strategy="deepspeed_stage_2",
        max_steps=train_config['step']['total_step'],
        enable_checkpointing=True,
        callbacks=[checkpoint_callback, lr_monitor],
        profiler="simple",
        accumulate_grad_batches=train_config['trainer']['grad_acc'],
        logger=logger,
        gradient_clip_val=train_config['trainer']['grad_clip_thresh'],
        precision=16,
    )
    
    trainer.fit(
        model,
        train_loader,
        val_loader
    )
    
if __name__ == '__main__':
    args = define_argparser()
    main(args)

Is it just not possible to use DeepSpeed and torch.compile together in Lightning?
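
If this turns out to be a torchmetrics/torch.compile limitation rather than a Lightning bug, one possible workaround would be to keep the metric update out of the compiled region. A rough sketch using the names from my code above (the decorator-based helper and the `pl_module.model` attribute are illustrative, not something Lightning requires):

```python
import torch
import torch._dynamo

# Option 1 (illustrative): keep the torchmetrics call outside the compiled
# graph so Dynamo never traces _safe_divide. This helper would be called
# from validation_step instead of calling self.valid_accuracy(...) directly.
@torch._dynamo.disable
def update_accuracy(metric, preds, target):
    metric(preds, target)

# Option 2 (illustrative): compile only the inner network rather than the
# whole LightningModule, leaving Lightning hooks and metrics in eager mode.
# `pl_module.model` is a hypothetical attribute name for the HF backbone.
pl_module = PL_model(train_config)
pl_module.model = torch.compile(pl_module.model)
```
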

Thanks for reading.

What version are you seeing the problem on?

master

How to reproduce the bug

```bash
CUDA_VISIBLE_DEVICES=1 python trainer_hf.py --exp_name audio --using_model audio --using_contra False --using_cma False --batch_size 12 --accumulate_grad 1 --save_path /data/research_data/model_weights/EmoNet2/MELD_out/
```

Error messages and logs

Failed to collect metadata on function, produced code may be suboptimal.  Known situations this can occur are inference mode only compilation involving resize_ or prims (!schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED); if your situation looks different please file a bug to PyTorch.
Traceback (most recent call last):
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1674, in aot_wrapper_dedupe
    fw_metadata, _out = run_functionalized_fw_and_collect_metadata(
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 606, in inner
    flat_f_outs = f(*flat_f_args)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2776, in functional_call
    out = Interpreter(mod).run(*args[params_len:], **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/fx/interpreter.py", line 249, in call_function
    return target(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_inductor/overrides.py", line 38, in __torch_function__
    return func(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 987, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1162, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 410, in local_scalar_dense
    raise DataDependentOutputException(func)
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

While executing %setitem : [#users=0] = call_function[target=operator.setitem](args = (%add, %eq, 1), kwargs = {})
Original traceback:
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/utilities/compute.py", line 52, in _safe_divide
    denom[denom == 0.0] = 1
 |   File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/functional/classification/accuracy.py", line 71, in _accuracy_reduce
    return _safe_divide(tp, tp + fn)
 |   File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/classification/accuracy.py", line 205, in compute
    return _accuracy_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)

Traceback (most recent call last):
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/__init__.py", line 1390, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
    return aot_autograd(
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1802, in aot_wrapper_dedupe
    compiled_fn = compiler_fn(wrapped_flat_fn, deduped_flat_args, aot_config)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1278, in aot_dispatch_base
    _fw_metadata, _out = run_functionalized_fw_and_collect_metadata(
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 606, in inner
    flat_f_outs = f(*flat_f_args)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1800, in wrapped_flat_fn
    return flat_fn(*add_dupe_args(args))
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2776, in functional_call
    out = Interpreter(mod).run(*args[params_len:], **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/fx/interpreter.py", line 249, in call_function
    return target(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_inductor/overrides.py", line 38, in __torch_function__
    return func(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 987, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1162, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 410, in local_scalar_dense
    raise DataDependentOutputException(func)
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

While executing %setitem : [#users=0] = call_function[target=operator.setitem](args = (%add, %eq, 1), kwargs = {})
Original traceback:
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/utilities/compute.py", line 52, in _safe_divide
    denom[denom == 0.0] = 1
 |   File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/functional/classification/accuracy.py", line 71, in _accuracy_reduce
    return _safe_divide(tp, tp + fn)
 |   File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/classification/accuracy.py", line 205, in compute
    return _accuracy_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/wonjong/codes/2023_Q2/CAN/trainer_hf.py", line 121, in <module>
    main(args)
  File "/home/wonjong/codes/2023_Q2/CAN/trainer_hf.py", line 113, in main
    trainer.fit(
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py", line 42, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 92, in launch
    return function(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 935, in _run
    results = self._run_stage()
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 976, in _run_stage
    self._run_sanity_check()
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1005, in _run_sanity_check
    val_loop.run()
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/loops/utilities.py", line 177, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 115, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 375, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py", line 288, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/strategies/deepspeed.py", line 906, in validation_step
    return self.model(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1695, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/lightning/pytorch/overrides/base.py", line 102, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/wonjong/codes/2023_Q2/CAN/models/pl_model_hf.py", line 42, in validation_step
    emo_out, sim, contrastive_label = self.forward(text_inputs, audio_inputs)
  File "/home/wonjong/codes/2023_Q2/CAN/models/pl_model_hf.py", line 43, in <graph break in validation_step>
    emo_loss, contra_loss= self.cal_loss(emo_out, labels['emotion'], sim, contrastive_label)
  File "/home/wonjong/codes/2023_Q2/CAN/models/pl_model_hf.py", line 45, in <graph break in validation_step>
    self.valid_accuracy(emo_out, labels['emotion'])
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/metric.py", line 236, in forward
    self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/metric.py", line 292, in _forward_reduce_state_update
    self.reset()
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/metric.py", line 295, in <graph break in _forward_reduce_state_update>
    self._to_sync = self.dist_sync_on_step
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/metric.py", line 296, in <graph break in _forward_reduce_state_update>
    self._should_unsync = False
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/metric.py", line 298, in <graph break in _forward_reduce_state_update>
    self.compute_on_cpu = False
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/metric.py", line 299, in <graph break in _forward_reduce_state_update>
    self._enable_grad = True  # allow grads for batch computation
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/metric.py", line 302, in <graph break in _forward_reduce_state_update>
    self.update(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/metric.py", line 303, in <graph break in _forward_reduce_state_update>
    batch_val = self.compute()
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/metric.py", line 527, in wrapped_func
    with self.sync_context(
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/metric.py", line 532, in <graph break in wrapped_func>
    value = compute(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised DataDependentOutputException: aten._local_scalar_dense.default

While executing %setitem : [#users=0] = call_function[target=operator.setitem](args = (%add, %eq, 1), kwargs = {})
Original traceback:
  File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/utilities/compute.py", line 52, in _safe_divide
    denom[denom == 0.0] = 1
 |   File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/functional/classification/accuracy.py", line 71, in _accuracy_reduce
    return _safe_divide(tp, tp + fn)
 |   File "/home/wonjong/anaconda3/envs/MER/lib/python3.9/site-packages/torchmetrics/classification/accuracy.py", line 205, in compute
    return _accuracy_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)


Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
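
As the message above suggests, a temporary stopgap is to let Dynamo fall back to eager execution when the backend fails to compile. This hides the error rather than fixing it, and the affected frames simply won't be compiled:

```python
import torch._dynamo

# Fall back to eager execution instead of raising on backend compile errors.
torch._dynamo.config.suppress_errors = True
```
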

Environment

Current environment

> Python==3.9.16
> CUDA==11.7
> PyTorch==2.0.0
> lightning==2.0.1
> deepspeed==0.9.0
> transformers==4.27.3
> OS: CentOS
> How you installed Lightning (`conda`, `pip`, source): pip
> Running environment of LightningApp (e.g. local, cloud): local

More info

No response

cc @awaelchli
