Description
Describe the issue
Unexpectedly, training with the SGD optimizer is slower than training with the AdamW optimizer. By profiling with Nsight Systems I found that the SGD optimizer copies approx. 500 MB from the GPU to pageable CPU memory before executing the SGD kernel. The copy takes about 90 ms and the SGD computation itself about 19 ms. The AdamW optimizer, on the other hand, starts its kernel directly, and that kernel takes approx. 37 ms.
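To put these numbers side by side, here is a rough back-of-the-envelope check. It uses only the approximate figures quoted above and assumes that the copy and the SGD kernel run strictly one after the other, and that the AdamW step consists of its kernel alone; the variable names are purely illustrative.

```python
# Back-of-the-envelope check using only the approximate figures from the profile.
copy_mb = 500.0          # approx. size of the GPU -> CPU copy, in MB
copy_ms = 90.0           # approx. duration of that copy
sgd_kernel_ms = 19.0     # SGD kernel time
adamw_kernel_ms = 37.0   # AdamW kernel time

# Assumption: the copy finishes before the SGD kernel starts (no overlap).
sgd_step_ms = copy_ms + sgd_kernel_ms

# Fraction of the SGD step spent in the copy.
copy_share = copy_ms / sgd_step_ms

# How much slower the SGD step is compared to the AdamW kernel.
slowdown = sgd_step_ms / adamw_kernel_ms

# Effective throughput of the copy, taking 1 GB = 1000 MB.
copy_gb_per_s = (copy_mb / 1000.0) / (copy_ms / 1000.0)

print(f"SGD step total: {sgd_step_ms:.0f} ms")
print(f"Share spent in copy: {copy_share:.0%}")
print(f"SGD step vs. AdamW kernel: {slowdown:.2f}x")
print(f"Effective copy throughput: {copy_gb_per_s:.2f} GB/s")
```

Under these assumptions the copy dominates the SGD step, while the SGD kernel alone is roughly half the cost of the AdamW kernel.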
I expected the SGD optimizer to be faster than AdamW, because it only computes one momentum term. That would be the case if the unexpected copy to the CPU did not take place.
The profiler report of the SGD optimizer.step():
The profiler report of the AdamW optimizer.step():
My training loop:
for data, target in train_data_loader:
    with nvtx.annotate(message="data_to_device", color="gray", domain="task"):
        ort_data.update_inplace(data.numpy())
        ort_target.update_inplace(target.numpy().astype(np.int64))
    with nvtx.annotate(message="train", color="green", domain="task"):
        _ = model(ort_data, ort_target)
    with nvtx.annotate(message="optimize", color="blue", domain="task"):
        optimizer.step()
        model.lazy_reset_grad()
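As a quick cross-check that does not require Nsight Systems, the optimizer step could also be timed on the host. The following is a minimal sketch using only the Python standard library; `timed`, `summarize`, and `step_times_ms` are hypothetical helper names, and `time.sleep` stands in for the real `optimizer.step()` call. Wall-clock timing only captures the time the Python call blocks, so asynchronous GPU work may be missed, whereas a blocking copy to host memory should show up.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

# label -> list of measured durations in milliseconds
step_times_ms = defaultdict(list)

@contextmanager
def timed(label):
    """Record the wall-clock duration of the enclosed block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        step_times_ms[label].append((time.perf_counter() - start) * 1000.0)

def summarize(times):
    """Return {label: (count, mean_ms)} for the recorded durations."""
    return {label: (len(v), sum(v) / len(v)) for label, v in times.items()}

# Stand-in workload; in the training loop this would wrap optimizer.step().
for _ in range(3):
    with timed("optimize"):
        time.sleep(0.01)

summary = summarize(step_times_ms)
print(summary)
```

Running the same loop once with SGD artifacts and once with AdamW artifacts would then give two mean step times to compare against the profiler figures.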
To reproduce
I am training the VGG16 model from torchvision using ONNX Runtime on-device training. The optimizer type is selected by specifying it in the artifact generation method call.
To reproduce, download the ZIP file and execute the scripts in a Python environment. The script vgg16_gen_artifacts.py generates the training artifacts (edit it to switch between the SGD and AdamW optimizers), and the script train_vgg16.py performs the training.
To profile with Nsight Systems, run:
nsys profile -w true --cuda-memory-usage true --cudabacktrace=true --gpu-metrics-device=all -t cuda,nvtx,cudnn,cublas --capture-range=cudaProfilerApi --capture-range-end=stop -f true -o vgg16 python train_vgg16.py
train_vgg.zip
The training is executed on an NVIDIA Jetson Orin running JetPack 5.1.2.
Urgency
Due to the project deadline, I would appreciate an answer by the end of this week.
Thanks for helping.
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
1.16.3
PyTorch Version
2.1.0
Execution Provider
CUDA
Execution Provider Library Version
CUDA 11.4