Description
Python script (please correct me if I'm using the torchdynamo API wrong):
dynamo-test.py
import contextlib
import traceback
import time

import torch
import torchvision
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

N_WARMUP = 100
N_BENCH = 100


def bench(batch_size, use_dynamo):
    model = torchvision.models.resnet50().cuda()
    x = torch.randn(batch_size, 3, 224, 224, dtype=torch.float, device='cuda')
    # with --use_dynamo, run the training loop under torchdynamo + AOT Autograd;
    # otherwise fall back to plain eager mode
    train_context = torchdynamo.optimize(aot_autograd_speedup_strategy) if use_dynamo else contextlib.nullcontext()

    torch.cuda.synchronize()
    t0 = time.time()
    with train_context:
        # warmup iterations (includes torchdynamo compilation overhead)
        for _ in range(N_WARMUP):
            out = model(x)
            out.sum().backward()
        torch.cuda.synchronize()
        t1 = time.time()
        # timed iterations
        for _ in range(N_BENCH):
            out = model(x)
            out.sum().backward()
        torch.cuda.synchronize()
        t2 = time.time()

    print('Training img/s (larger better):', batch_size / ((t2 - t1) / N_BENCH))
    print('Total time incl. overhead (smaller better):', t2 - t0)
    print()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--batch_size', type=int, default=16)
    parser.add_argument('--use_dynamo', action='store_true', default=False)
    args = parser.parse_args()
    print(args)
    bench(args.batch_size, args.use_dynamo)
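To quantify the memory gap directly (rather than only observing the OOM at batch size 128 below), the script could additionally record the CUDA allocator's high-water mark. A minimal sketch using the standard torch.cuda memory-stats API; the helper names reset_peak_memory / report_peak_memory are hypothetical and not part of the original script:

import torch

def reset_peak_memory():
    # clear the CUDA caching allocator's high-water mark (call right before the benchmark loops)
    torch.cuda.reset_peak_memory_stats()

def report_peak_memory(label):
    torch.cuda.synchronize()
    # peak bytes allocated by tensors on the current device since the last reset
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    print(f'{label} peak allocated memory (GiB): {peak_gib:.2f}')

Calling report_peak_memory('dynamo' if use_dynamo else 'eager') at the end of bench() would show how much extra memory the AOT Autograd path holds compared to eager mode, even for batch sizes that do not OOM.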
bash script
#!/bin/bash
# run each batch size first in eager mode, then with torchdynamo
python dynamo-test.py -b 16
python dynamo-test.py -b 16 --use_dynamo
python dynamo-test.py -b 64
python dynamo-test.py -b 64 --use_dynamo
python dynamo-test.py -b 128
python dynamo-test.py -b 128 --use_dynamo
Tested with a V100 16GB on pytorch @ f6bbecf, pytorch/vision @ 104073c, torchdynamo @ 0d59ce9, pytorch/functorch @ ac0fdf1, with CUDA 11.6 Update 1 and cuDNN 8.3.3.40.
results
00:00:15 Namespace(batch_size=16, use_dynamo=False)
00:00:15 Training img/s (larger better): 306.0273543197808
00:00:15 Total time incl. overhead (smaller better): 11.840148687362671
00:00:15
00:00:41 Namespace(batch_size=16, use_dynamo=True)
00:00:41 Training img/s (larger better): 307.23167670348334
00:00:41 Total time incl. overhead (smaller better): 21.702965021133423
00:00:41
00:01:25 Namespace(batch_size=64, use_dynamo=False)
00:01:25 Training img/s (larger better): 338.7971538235037
00:01:25 Total time incl. overhead (smaller better): 39.105363607406616
00:01:25
00:02:18 Namespace(batch_size=64, use_dynamo=True)
00:02:18 Training img/s (larger better): 344.55961021952044
00:02:18 Total time incl. overhead (smaller better): 48.27671003341675
00:02:18
00:03:37 Namespace(batch_size=128, use_dynamo=False)
00:03:37 Training img/s (larger better): 350.9537424835023
00:03:37 Total time incl. overhead (smaller better): 74.2812168598175
00:03:37
00:03:43 ERROR FROM offset=100 filename /opt/pytorch/vision/torchvision/models/resnet.py 156 RuntimeError
00:03:43 ERROR FROM offset=66 filename /opt/pytorch/vision/torchvision/models/resnet.py 273 RuntimeError
00:03:43 ERROR FROM offset=6 filename /opt/pytorch/vision/torchvision/models/resnet.py 283 RuntimeError
00:03:43 ========== TorchDynamo Stack Trace ==========
00:03:43 Traceback (most recent call last):
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/convert_frame.py", line 170, in _convert_frame_assert
00:03:43 code = transform_code_object(frame.f_code, transform)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/bytecode_transformation.py", line 338, in transform_code_object
00:03:43 transformations(instructions, code_options)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/convert_frame.py", line 146, in transform
00:03:43 tracer.run()
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 278, in run
00:03:43 and self.step()
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 256, in step
00:03:43 getattr(self, inst.opname)(inst)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 142, in wrapper
00:03:43 return inner_fn(self, inst)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 556, in CALL_FUNCTION
00:03:43 self.call_function(fn, args, {})
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 195, in call_function
00:03:43 self.push(fn.call_function(self, args, kwargs))
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/functions.py", line 182, in call_function
00:03:43 return super().call_function(tx, args, kwargs)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/functions.py", line 64, in call_function
00:03:43 return tx.inline_user_function_return(
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 227, in inline_user_function_return
00:03:43 result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 1204, in inline_call
00:03:43 return cls.inline_call_(parent, func, args, kwargs)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 1246, in inline_call_
00:03:43 tracer.run()
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 278, in run
00:03:43 and self.step()
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 256, in step
00:03:43 getattr(self, inst.opname)(inst)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 142, in wrapper
00:03:43 return inner_fn(self, inst)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 556, in CALL_FUNCTION
00:03:43 self.call_function(fn, args, {})
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 195, in call_function
00:03:43 self.push(fn.call_function(self, args, kwargs))
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/nn_module.py", line 158, in call_function
00:03:43 tx.call_function(
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 195, in call_function
00:03:43 self.push(fn.call_function(self, args, kwargs))
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/nn_module.py", line 185, in call_function
00:03:43 return tx.inline_user_function_return(
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 227, in inline_user_function_return
00:03:43 result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 1204, in inline_call
00:03:43 return cls.inline_call_(parent, func, args, kwargs)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 1246, in inline_call_
00:03:43 tracer.run()
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 278, in run
00:03:43 and self.step()
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 256, in step
00:03:43 getattr(self, inst.opname)(inst)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 142, in wrapper
00:03:43 return inner_fn(self, inst)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 556, in CALL_FUNCTION
00:03:43 self.call_function(fn, args, {})
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 195, in call_function
00:03:43 self.push(fn.call_function(self, args, kwargs))
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/nn_module.py", line 158, in call_function
00:03:43 tx.call_function(
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 195, in call_function
00:03:43 self.push(fn.call_function(self, args, kwargs))
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/nn_module.py", line 172, in call_function
00:03:43 return variables.TensorVariable.create(
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/tensor.py", line 94, in create
00:03:43 proxy.node.meta["example_value"] = clone_tensor(example_value)
00:03:43 File "/opt/conda/lib/python3.8/site-packages/torchdynamo/utils.py", line 151, in clone_tensor
00:03:43 y = x.clone().requires_grad_(x.requires_grad)
00:03:43 RuntimeError: CUDA out of memory. Tried to allocate 98.00 MiB (GPU 0; 15.78 GiB total capacity; 14.54 GiB already allocated; 19.75 MiB free; 14.62 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
00:03:43 ========== Exception (above) while processing ==========
00:03:43 File "dynamo-test.py", line 50, in <module>
00:03:43 bench(args.batch_size, args.use_dynamo)
00:03:43 File "dynamo-test.py", line 24, in bench
00:03:43 out = model(x)
00:03:43 File "/opt/pytorch/pytorch/torch/nn/modules/module.py", line 1129, in _call_impl
00:03:43 return forward_call(*input, **kwargs)
00:03:43 File "/opt/pytorch/vision/torchvision/models/resnet.py", line 282, in forward
00:03:43 def forward(self, x: Tensor) -> Tensor:
00:03:43 ========== End debug info ==========
00:03:43 Namespace(batch_size=128, use_dynamo=True)
00:03:43 Traceback (most recent call last):
00:03:43 File "dynamo-test.py", line 50, in <module>
00:03:43 bench(args.batch_size, args.use_dynamo)
00:03:43 File "dynamo-test.py", line 24, in bench
00:03:43 out = model(x)
00:03:43 File "/opt/pytorch/pytorch/torch/nn/modules/module.py", line 1129, in _call_impl
00:03:43 return forward_call(*input, **kwargs)
00:03:43 File "/opt/pytorch/vision/torchvision/models/resnet.py", line 283, in forward
00:03:43 return self._forward_impl(x)
00:03:43 File "/opt/pytorch/vision/torchvision/models/resnet.py", line 266, in _forward_impl
00:03:43 x = self.conv1(x)
00:03:43 File "/opt/pytorch/pytorch/torch/nn/modules/module.py", line 1129, in _call_impl
00:03:43 return forward_call(*input, **kwargs)
00:03:43 File "/opt/pytorch/pytorch/torch/nn/modules/conv.py", line 447, in forward
00:03:43 return self._conv_forward(input, self.weight, self.bias)
00:03:43 File "/opt/pytorch/pytorch/torch/nn/modules/conv.py", line 443, in _conv_forward
00:03:43 return F.conv2d(input, weight, bias, self.stride,
00:03:43 RuntimeError: CUDA out of memory. Tried to allocate 392.00 MiB (GPU 0; 15.78 GiB total capacity; 14.52 GiB already allocated; 19.75 MiB free; 14.62 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
From these results, torchdynamo with aot_autograd_speedup_strategy uses more GPU memory than eager mode on ResNet50 (it OOMs at batch size 128, which runs fine in eager mode) and has noticeably higher total overhead (warmup/compile time), while steady-state training throughput is essentially unchanged.
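For completeness, the allocator hint printed in the OOM message could be tried on the failing configuration, although it only addresses fragmentation and would not recover the gap if the extra memory comes from tensors torchdynamo actually keeps alive (for example the cloned example values visible in the stack trace above). A sketch of such a retry; max_split_size_mb:128 is an arbitrary value chosen for illustration, not a tuned recommendation:

#!/bin/bash
# hypothetical retry of the failing run with the allocator setting suggested in the error message
PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 python dynamo-test.py -b 128 --use_dynamo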
cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @csarofeen @ptrblck @jjsjann123 @kevinstephano @jansel