
[BUG] Data parallel training freezes due to different number of batches #75

Open
@bschifferer


Bug description

In data parallel training, we start multiple workers, each initializing its own dataloader, and train with Horovod. After each batch the parameters are synced across workers. The Merlin dataloader yields a different number of batches depending on the selected rank, so some workers finish the training loop while others are still iterating and block on a parameter sync that the finished workers never join, which causes Horovod to freeze. A reproduction follows; a possible workaround is sketched after the output.

import cudf
import os

import merlin.models.tf.dataset as tf_dataloader
import nvtabular as nvt

os.system('mkdir ./test/')

# Write four parquet files with different row counts so the partitions are uneven
df = cudf.DataFrame({
    'col1': range(0,9000000)
})
df.to_parquet('./test/part_1.parquet')
df = cudf.DataFrame({
    'col1': range(0,10000000)
})
df.to_parquet('./test/part_2.parquet')
df = cudf.DataFrame({
    'col1': range(0,11000000)
})
df.to_parquet('./test/part_3.parquet')
df = cudf.DataFrame({
    'col1': range(0,12000000)
})
df.to_parquet('./test/part_4.parquet')

ds = nvt.Dataset('./test/*.parquet', part_size='100MB')
# Build one dataloader per rank and print how many batches each rank would see
for i in range(4):
    train_dl = tf_dataloader.BatchedDataset(
        ds,
        batch_size = 1024*16,
        shuffle=True,
        drop_last=True,
        cat_names=['col1'],
        global_size=4,
        global_rank=i,
    )
    print(len(train_dl))

Output:

549
610
671
732
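
One possible workaround (a minimal sketch, not an official fix): before training, have every rank agree on the smallest batch count, e.g. by all-gathering len(train_dl) with Horovod, and stop after that many steps so no rank waits on a sync that never comes. This assumes each rank builds its own train_dl with global_rank=hvd.rank(); hvd.allgather and tf.reduce_min are standard Horovod/TensorFlow calls, while the training loop itself is illustrative.

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# len(train_dl) differs per rank (549 / 610 / 671 / 732 above), so take the minimum across ranks
local_steps = tf.constant([len(train_dl)], dtype=tf.int32)
common_steps = int(tf.reduce_min(hvd.allgather(local_steps)))

for step, batch in enumerate(train_dl):
    if step >= common_steps:
        break  # every rank performs the same number of synchronized updates
    # ... usual training step with hvd.DistributedGradientTape goes here ...

This only truncates the larger ranks to match the smallest one; ideally the dataloader itself would give each rank the same number of batches (or expose the global minimum) so user code does not have to do this.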

Labels: P1, bug (Something isn't working)