🐛 Describe the bug
I'm hitting a bug while training a YOLO-NAS model with YoloDarknetFormatDetectionDataset. The error is raised when invoking the train method, specifically in the _bbox_loss function of the PPYoloELoss module in the super-gradients library. The issue occurs only when using YoloDarknetFormatDetectionDataset (CODE_2); using the coco_detection_yolo_format_train dataloader factory (CODE_1) does not trigger it.
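For reference, CODE_1 uses the stock coco_detection_yolo_format_train dataloader factory. A minimal sketch of that working path is shown here; the paths and class names are placeholders mirroring the dataset layout used in CODE_2, not the exact CODE_1 contents:

from super_gradients.training.dataloaders.dataloaders import coco_detection_yolo_format_train

# Sketch only: placeholder paths and class names, same YOLOv5-style layout as CODE_2
train_data = coco_detection_yolo_format_train(
    dataset_params={
        "data_dir": "ComputerUse_original_resized",
        "images_dir": "train/images",
        "labels_dir": "train/labels",
        "classes": ["class_0", "class_1"],  # placeholder class names
    },
    dataloader_params={"batch_size": 8, "num_workers": 2},
)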
Details & Error Logs:
In CODE_2, which implements a class-balanced sampling strategy to mitigate class imbalance in the dataset, the following runtime error surfaces during training:
RuntimeError: numel: integer multiplication overflow
Steps to Reproduce:
- Use the YoloDarknetFormatDetectionDataset to load the datasets for training.
- Employ the ClassBalancedSampler and initialize the DataLoader.
- Execute the training process.
- Encounter the overflow error within the loss computation method _bbox_loss.
Additional Information:
The error appears to originate from the line torch.masked_select(pred_bboxes, bbox_mask).reshape([-1, 4]) in super_gradients/training/losses/ppyolo_loss.py; a minimal illustration of that pattern is sketched below.
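To make the failing pattern concrete, here is a minimal, self-contained illustration of the masked_select/reshape idiom from _bbox_loss. The tensor names and shapes are placeholders, not the library's exact internals; with shapes like these the call succeeds, so the overflow presumably means the mask or prediction tensors reach an element count that no longer fits in a 64-bit integer during my run:

import torch

# Placeholder shapes: [batch, anchors, 4] box predictions and a per-anchor positive mask
pred_bboxes = torch.rand(2, 100, 4)
mask_positive = torch.rand(2, 100) > 0.9
bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4])

# The same idiom as the failing line in ppyolo_loss.py
pred_bboxes_pos = torch.masked_select(pred_bboxes, bbox_mask).reshape([-1, 4])
print(pred_bboxes_pos.shape)  # (num_positive_anchors, 4)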
Code
import os
import logging
import yaml
import itertools
from super_gradients.training import models, Trainer
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.datasets.detection_datasets import YoloDarknetFormatDetectionDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import glob
from super_gradients.training.utils.collate_fn import DetectionCollateFN
from super_gradients.training.datasets.samplers.class_balanced_sampler import ClassBalancedSampler
import torch
DATASET_RELATIVE_PATH = 'ComputerUse_original_resized'  # YOLOv5-format dataset root
# Configure logging (if needed)
def get_class_frequency(dataloader):
class_frequency = {}
for batch_x, batch_y in tqdm(dataloader, total=len(dataloader), desc="Calculating class frequencies"):
for cls in batch_y[:, -1]:
cls = int(cls.item())
class_frequency[cls] = class_frequency.get(cls, 0) + 1
return class_frequency
# Configure the logging system
logging.basicConfig(level=logging.DEBUG)
def main():
HOME = os.getcwd()
logging.info("Working directory: %s", HOME)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info("Device: %s", DEVICE)
EXPERIMENT_NAME = 'test'
max_epochs = 100
initial_lr = 1e-4
if not DATASET_RELATIVE_PATH:
raise ValueError(
"For local datasets, 'DATASET_RELATIVE_PATH' must be specified.")
if not os.path.isdir(DATASET_RELATIVE_PATH):
raise FileNotFoundError(
f"Local dataset directory not found at: {DATASET_RELATIVE_PATH}")
# Path to data.yaml within the local dataset directory
data_yaml_path = os.path.join(DATASET_RELATIVE_PATH, 'data.yaml')
if not os.path.isfile(data_yaml_path):
raise FileNotFoundError(
f"data.yaml not found in the local dataset directory: {data_yaml_path}")
# Load and parse data.yaml
with open(data_yaml_path, 'r') as f:
try:
data_yaml = yaml.safe_load(f)
except yaml.YAMLError as e:
raise ValueError(f"Error parsing data.yaml: {e}")
# Extract class names
CLASSES = data_yaml.get('names')
if not CLASSES:
raise ValueError(
"No class names found in data.yaml under 'names' key.")
logging.info(f"Loaded classes from data.yaml: {CLASSES}")
# Extract data paths
train_images_dir = data_yaml.get('train')
val_images_dir = data_yaml.get('val')
test_images_dir = data_yaml.get('test') # Optional
if not all([train_images_dir, val_images_dir]):
raise ValueError("data.yaml must contain 'train' and 'val' paths.")
# Define label directories based on data.yaml paths
train_labels_dir = os.path.join(
os.path.dirname(train_images_dir), 'labels')
val_labels_dir = os.path.join(
os.path.dirname(val_images_dir), 'labels')
test_labels_dir = os.path.join(os.path.dirname(
test_images_dir), 'labels') if test_images_dir else None
# Preprocess to remove label files not associated with any image
def clean_label_files(images_dir, labels_dir):
image_files = set(os.path.splitext(os.path.basename(f))[0] for f in glob.glob(
os.path.join(images_dir, "*.*")) if os.path.isfile(f))
label_files = glob.glob(os.path.join(labels_dir, "*.txt"))
orphan_labels = []
for label_file in label_files:
label_name = os.path.splitext(os.path.basename(label_file))[0]
if label_name not in image_files:
orphan_labels.append(label_file)
if orphan_labels:
logging.warning(
f"Removing {len(orphan_labels)} orphan label files from {labels_dir}")
for label_file in orphan_labels:
os.remove(label_file)
clean_label_files(train_images_dir, train_labels_dir)
clean_label_files(val_images_dir, val_labels_dir)
if test_images_dir and test_labels_dir:
clean_label_files(test_images_dir, test_labels_dir)
CHECKPOINT_DIR = os.path.join(HOME, 'checkpoints')
# Initialize training dataset
train_dataset = YoloDarknetFormatDetectionDataset(
data_dir=DATASET_RELATIVE_PATH, images_dir="train/images", labels_dir="train/labels", classes=CLASSES
)
# Initialize validation dataset
val_dataset = YoloDarknetFormatDetectionDataset(
data_dir=DATASET_RELATIVE_PATH, images_dir="valid/images", labels_dir="valid/labels", classes=CLASSES
)
# Get initial class balance for training
initial_balance = train_dataset.get_dataset_classes_information()
initial_class_balance = initial_balance.sum(axis=0)
initial_discrepancy = initial_class_balance.max() / initial_class_balance.min()
print(
f" \n\nBEFORE BALANCE:"
f" Most frequent class (#{np.argmax(initial_class_balance)}) appears {initial_class_balance.max()} times, which is {initial_discrepancy:.2f}x"
f" more frequent than the least frequent class (#{np.argmin(initial_class_balance)}) that appears only {initial_class_balance.min()} times!"
)
# Create DataLoader for training without sampling
vanilla_dataloader = DataLoader(
train_dataset,
batch_size=8,
drop_last=False,
collate_fn=DetectionCollateFN()
)
vanilla_sampled_class_balance = np.zeros_like(initial_class_balance)
# Calculate class frequencies for vanilla DataLoader
for k, v in get_class_frequency(vanilla_dataloader).items():
vanilla_sampled_class_balance[k] = v
# Verify that vanilla sampling matches initial class balance
np.testing.assert_equal(vanilla_sampled_class_balance, initial_class_balance) # no special sampling
# Define parameter grid for oversampling
#oversample_thresholds = [None, 0.5, 1.0, 1.5, 2.0, 3.0, 3.5, 4.0, 4.5, 5.0]
#oversample_aggressivenesses = [0.5, 1.0, 1.5, 2.0, 3.0, 3.5, 4.0, 4.5, 5.0]
oversample_thresholds = [2.5]
oversample_aggressivenesses = [3.0]
parameter_grid = list(itertools.product(oversample_thresholds, oversample_aggressivenesses))
best_discrepancy = float('inf')
best_setting = None
# Iterate over all settings to find the best one
for oversample_threshold, oversample_aggressiveness in parameter_grid:
sampler = ClassBalancedSampler(
train_dataset,
oversample_threshold=oversample_threshold,
oversample_aggressiveness=oversample_aggressiveness
)
train_dataloader_balanced = DataLoader(
train_dataset,
batch_size=8,
drop_last=False,
collate_fn=DetectionCollateFN(),
sampler=sampler
)
balanced_sampled_class_balance = np.zeros_like(initial_class_balance)
# Calculate class frequencies for balanced DataLoader
for k, v in get_class_frequency(train_dataloader_balanced).items():
balanced_sampled_class_balance[k] = v
# Handle division by zero if any class has zero samples
min_count = balanced_sampled_class_balance.min()
if min_count == 0:
print(f"Skipping setting (oversample_threshold={oversample_threshold}, oversample_aggressiveness={oversample_aggressiveness}) due to zero samples in a class.")
continue # Skip this setting as it's invalid
balanced_discrepancy = balanced_sampled_class_balance.max() / min_count
print(
f"AFTER BALANCE (oversample_threshold={oversample_threshold}, oversample_aggressiveness={oversample_aggressiveness}):"
f" Most frequent class (#{np.argmax(balanced_sampled_class_balance)}) appears {balanced_sampled_class_balance.max()} times,"
f" which is {balanced_discrepancy:.2f}x more frequent than the least frequent class (#{np.argmin(balanced_sampled_class_balance)})"
f" that appears only {balanced_sampled_class_balance.min()} times!"
)
# Update best setting if current discrepancy is lower
if balanced_discrepancy < best_discrepancy:
best_discrepancy = balanced_discrepancy
best_setting = (oversample_threshold, oversample_aggressiveness)
if best_setting is None:
raise ValueError("No valid sampling settings found. Please check your parameter grid and dataset.")
print(
f"\nOptimal Sampling Settings:"
f" oversample_threshold={best_setting[0]}, oversample_aggressiveness={best_setting[1]}"
f" with balanced_discrepancy={best_discrepancy:.2f}"
)
optimal_sampler = ClassBalancedSampler(
train_dataset,
oversample_threshold=best_setting[0],
oversample_aggressiveness=best_setting[1]
)
train_dataloader_optimal = DataLoader(
train_dataset,
batch_size=8,
drop_last=False,
collate_fn=DetectionCollateFN(),
sampler=optimal_sampler
)
val_dataloader = DataLoader(
val_dataset,
batch_size=8,
shuffle=False, # No need to shuffle validation data
drop_last=False,
collate_fn=DetectionCollateFN()
)
model = models.get(
'yolo_nas_l',
num_classes=len(CLASSES),
pretrained_weights="coco"
)
trainer = Trainer(experiment_name=EXPERIMENT_NAME,
ckpt_root_dir=CHECKPOINT_DIR)
loss_fn = PPYoloELoss(
use_static_assigner=False,
use_varifocal_loss=False,
num_classes=len(CLASSES),
)
train_params = {
"max_epochs": max_epochs,
"loss": loss_fn,
"criterion_params": None,
"warmup_mode": "LinearEpochLRWarmup",
"lr_warmup_epochs": 5,
"lr_warmup_steps": 0,
"warmup_initial_lr": 1e-6,
"initial_lr": initial_lr,
"metric_to_watch": '[email protected]',
"batch_accumulate": 4,
}
train_params["loss_logging_items_names"] = ["Total Loss", "PPYoloE Loss"]
trainer.train(
model=model,
training_params=train_params,
train_loader=train_dataloader_optimal,
valid_loader=val_dataloader
)
if __name__ == "__main__":
    main()

Many thanks for your assistance!
Versions
PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39
Python version: 3.9.2rc1 (default, Jan 4 2025, 03:03:13) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070
Nvidia driver version: 566.14
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 24
On-line CPU(s) list: 0-23
Vendor ID: GenuineIntel
Model name: 12th Gen Intel(R) Core(TM) i9-12900
CPU family: 6
Model: 151
Thread(s) per core: 2
Core(s) per socket: 12
Socket(s): 1
Stepping: 2
BogoMIPS: 4838.40
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 576 KiB (12 instances)
L1i cache: 384 KiB (12 instances)
L2 cache: 15 MiB (12 instances)
L3 cache: 30 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Vulnerable: No microcode
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.23.0
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] onnx==1.15.0
[pip3] onnxruntime==1.15.0
[pip3] onnxsim==0.4.36
[pip3] torch==1.13.1+cu117
[pip3] torchaudio==0.13.1+cu117
[pip3] torchmetrics==0.8.0
[pip3] torchvision==0.14.1+cu117
[pip3] triton==3.1.0