Bug description
In data parallel training, we start multiple workers, each with a differently initialized dataloader, and train with Horovod. After each batch update, the parameters are synced. The Merlin dataloader yields a different number of batches depending on the selected rank. Therefore, some workers finish the training loop while other workers are still training, which causes Horovod to freeze.
import cudf
import os
import merlin.models.tf.dataset as tf_dataloader
import nvtabular as nvt

os.system('mkdir ./test/')

df = cudf.DataFrame({
    'col1': range(0, 9000000)
})
df.to_parquet('./test/part_1.parquet')

df = cudf.DataFrame({
    'col1': range(0, 10000000)
})
df.to_parquet('./test/part_2.parquet')

df = cudf.DataFrame({
    'col1': range(0, 11000000)
})
df.to_parquet('./test/part_3.parquet')

df = cudf.DataFrame({
    'col1': range(0, 12000000)
})
df.to_parquet('./test/part_4.parquet')

ds = nvt.Dataset('./test/*.parquet', part_size='100MB')

for i in range(4):
    train_dl = tf_dataloader.BatchedDataset(
        ds,
        batch_size=1024 * 16,
        shuffle=True,
        drop_last=True,
        cat_names=['col1'],
        global_size=4,
        global_rank=i,
    )
    print(len(train_dl))
Output:
549
610
671
732
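The reported lengths are consistent with each rank iterating a shard the size of one of the four parquet files: with drop_last=True the per-rank batch count is floor(rows / batch_size). A minimal arithmetic check follows; the shard-to-rank mapping is an assumption inferred from the output, not verified against the dataloader internals:

```python
batch_size = 1024 * 16  # 16384, as in the repro

# Row counts of the four parquet files; assuming each rank ends up
# iterating a shard of roughly one file's size (inferred from the output).
shard_rows = [9_000_000, 10_000_000, 11_000_000, 12_000_000]

# With drop_last=True, the incomplete final batch is discarded,
# so each rank sees floor(rows / batch_size) batches.
lengths = [rows // batch_size for rows in shard_rows]
print(lengths)       # [549, 610, 671, 732] -- matches the reported output

# Until the shards are balanced, one workaround is to run every rank
# for the same number of steps, e.g. min over all ranks' lengths.
print(min(lengths))  # 549
```

This suggests the freeze will occur on any dataset whose partitions do not divide evenly across ranks, since the faster ranks exit the loop while the slower ones block on the next allreduce.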