Description
Which component has the problem?
CUTLASS C++
Bug Report
Describe the bug
I believe there is a remaining ScatterD bug in include/cutlass/epilogue/threadblock/predicated_tile_iterator.h.
Issue #965 identified a ScatterD bug in the epilogue iterator store path, and commit e066ced fixed the corresponding pointer updates in store_with_byte_offset(). However, a similar issue appears to remain in the load path.
In load_with_byte_offset(), the row increment is already guarded by if (!ScatterD):

```cpp
if (row + 1 < ThreadMap::Iterations::kRow) {
  if (!ScatterD) {
    byte_pointer += params_.increment_row;
  }
}
```

But the group increment and cluster increment are still unconditional:
```cpp
if (group + 1 < ThreadMap::Iterations::kGroup) {
  byte_pointer += params_.increment_group;
}

if (cluster + 1 < ThreadMap::Iterations::kCluster) {
  byte_pointer += params_.increment_cluster;
}
```

This looks like the same class of issue that was previously fixed in the store path by commit e066ced. I believe the load path should be updated in the same way to avoid potential invalid memory accesses in gather/scatter epilogue usage. Suggested change:
```cpp
if (group + 1 < ThreadMap::Iterations::kGroup) {
  if (!ScatterD) {
    byte_pointer += params_.increment_group;
  }
}

if (cluster + 1 < ThreadMap::Iterations::kCluster) {
  if (!ScatterD) {
    byte_pointer += params_.increment_cluster;
  }
}
```

There also appears to be a second issue in the same file: operator+=() does not implement the same ScatterD / PermuteD pointer-advance logic as operator++(). As a result, operator+=() and repeated operator++() calls do not appear to have consistent semantics under ScatterD. A possible fix would be:
```cpp
CUTLASS_HOST_DEVICE
PredicatedTileIterator &operator+=(int increment)
{
  // Row
  state_[0] += increment;
  int increment_row = state_[0] / ThreadMap::Count::kRow;
  state_[0] = state_[0] % ThreadMap::Count::kRow;

  if (!ScatterD) {
    byte_pointer_ += (params_.advance_row * increment);
  }
  if (!ScatterD && !PermuteD) {
    store_byte_pointer_ += (params_.advance_row * increment);
  }

  thread_start_row_ += (ThreadMap::Shape::kRow * increment);

  // Group
  state_[1] += increment_row;
  int increment_group = state_[1] / ThreadMap::Count::kGroup;
  state_[1] = state_[1] % ThreadMap::Count::kGroup;

  if (!ScatterD) {
    byte_pointer_ += (params_.advance_group * increment_row);
  }
  if (!ScatterD && !PermuteD) {
    store_byte_pointer_ += (params_.advance_group * increment_row);
  }

  thread_start_row_ +=
      (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow *
      ThreadMap::Count::kRow * increment_row;

  // Cluster
  state_[2] += increment_group;
  int increment_cluster = state_[2] / ThreadMap::Count::kCluster;
  state_[2] = state_[2] % ThreadMap::Count::kCluster;

  if (!ScatterD) {
    byte_pointer_ += (params_.advance_cluster * increment_group);
  }
  if (!ScatterD && !PermuteD) {
    store_byte_pointer_ += (params_.advance_cluster * increment_group);
  }

  thread_start_row_ +=
      ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup *
      ThreadMap::Count::kRow * ThreadMap::Shape::kRow * increment_group;

  // Tile
  if (!ScatterD) {
    byte_pointer_ += (params_.advance_tile * increment_cluster);
  }
  if (!ScatterD && !PermuteD) {
    store_byte_pointer_ += (params_.advance_tile * increment_cluster);
  }

  thread_start_row_ +=
      ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow *
      ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile * increment_cluster;

  return *this;
}
```

Steps/Code to reproduce bug
I have not included a minimal reproducer yet, but I can provide one if needed.
Expected behavior
I would expect:
- load_with_byte_offset() to guard increment_group and increment_cluster with if (!ScatterD), consistent with the previous fix in the store path.
- operator+=() to follow the same ScatterD / PermuteD pointer-advance semantics as operator++().