[BUG] ScatterD issue in predicated_tile_iterator with unguarded pointer updates in load_with_byte_offset and operator+= #3101

@pengpeng-yu

Which component has the problem?

CUTLASS C++

Bug Report

Describe the bug

I believe there is a remaining ScatterD bug in include/cutlass/epilogue/threadblock/predicated_tile_iterator.h.

Issue #965 identified a ScatterD bug in the epilogue iterator store path, and commit e066ced fixed the corresponding pointer updates in store_with_byte_offset(). However, a similar issue appears to remain in the load path.

In load_with_byte_offset(), the row increment is already guarded by if (!ScatterD):

if (row + 1 < ThreadMap::Iterations::kRow) {
  if (!ScatterD) {
    byte_pointer += params_.increment_row;
  }
}

But the group increment and cluster increment are still unconditional:

if (group + 1 < ThreadMap::Iterations::kGroup) {
  byte_pointer += params_.increment_group;
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
  byte_pointer += params_.increment_cluster;
}

This looks like the same class of issue that was previously fixed in the store path by commit e066ced. I believe the load path should be updated in the same way to avoid potential invalid memory accesses in gather/scatter epilogue usage. Suggested change:

if (group + 1 < ThreadMap::Iterations::kGroup) {
  if (!ScatterD) {
    byte_pointer += params_.increment_group;
  }
}

if (cluster + 1 < ThreadMap::Iterations::kCluster) {
  if (!ScatterD) {
    byte_pointer += params_.increment_cluster;
  }
}
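
To illustrate why the unguarded increments matter, here is a minimal host-side sketch. It is not CUTLASS code: the function name, the constants, and the two-level loop are hypothetical. It also assumes that under ScatterD each row's address is formed as an unchanged base plus a per-row scatter index times the stride, which is my reading of the scatter path rather than something shown above. Under that assumption, any group increment added to the base shifts every later row away from the location its index selects.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for ThreadMap::Iterations::kGroup / kRow and the
// iterator parameters; the values are arbitrary and chosen for readability.
constexpr int kGroups = 2;
constexpr int kRows = 2;
constexpr std::int64_t kStride = 100;          // bytes per output row
constexpr std::int64_t kIncrementGroup = 1000; // stand-in for params_.increment_group

// Returns the byte offset accessed for each (group, row) pair in scatter mode.
// guard_group == true models wrapping the group increment in if (!ScatterD).
std::vector<std::int64_t> scatter_row_offsets(const std::vector<int>& indices,
                                              bool guard_group) {
  std::vector<std::int64_t> offsets;
  std::int64_t base = 0;  // models byte_pointer
  for (int group = 0; group < kGroups; ++group) {
    for (int row = 0; row < kRows; ++row) {
      int logical_row = group * kRows + row;
      // Assumed scatter addressing: base + indices[row] * stride.
      offsets.push_back(base + std::int64_t(indices[logical_row]) * kStride);
      // The row increment is already guarded in scatter mode, so base is unchanged here.
    }
    if (group + 1 < kGroups) {
      if (!guard_group) {
        base += kIncrementGroup;  // the unguarded behavior described in this report
      }
    }
  }
  return offsets;
}
```

With indices {3, 0, 2, 1}, the guarded variant yields offsets that are exactly index * stride for every row, while the unguarded variant adds kIncrementGroup to every row in the second group.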

There also appears to be a second issue in the same file. operator+=() does not implement the same ScatterD / PermuteD pointer-advance logic as operator++(). As a result, operator+=() and repeated operator++() calls do not appear to have consistent semantics under ScatterD. A possible fix would be:

  CUTLASS_HOST_DEVICE
  PredicatedTileIterator &operator+=(int increment)
  {
    // Row
    state_[0] += increment;
    int increment_row = state_[0] / ThreadMap::Count::kRow;
    state_[0] = state_[0] % ThreadMap::Count::kRow;

    if (!ScatterD) {
      byte_pointer_ += (params_.advance_row * increment);
    }

    if (!ScatterD && !PermuteD) {
      store_byte_pointer_ += (params_.advance_row * increment);
    }

    thread_start_row_ += (ThreadMap::Shape::kRow * increment);

    // Group
    state_[1] += increment_row;
    int increment_group = state_[1] / ThreadMap::Count::kGroup;
    state_[1] = state_[1] % ThreadMap::Count::kGroup;

    if (!ScatterD) {
      byte_pointer_ += (params_.advance_group * increment_row);
    }

    if (!ScatterD && !PermuteD) {
      store_byte_pointer_ += (params_.advance_group * increment_row);
    }

    thread_start_row_ +=
        (ThreadMap::Shape::kGroup - 1) *
        ThreadMap::Shape::kRow *
        ThreadMap::Count::kRow *
        increment_row;

    // Cluster
    state_[2] += increment_group;
    int increment_cluster = state_[2] / ThreadMap::Count::kCluster;
    state_[2] = state_[2] % ThreadMap::Count::kCluster;

    if (!ScatterD) {
      byte_pointer_ += (params_.advance_cluster * increment_group);
    }

    if (!ScatterD && !PermuteD) {
      store_byte_pointer_ += (params_.advance_cluster * increment_group);
    }

    thread_start_row_ +=
        ThreadMap::Count::kGroup *
        ThreadMap::Shape::kGroup *
        ThreadMap::Count::kRow *
        ThreadMap::Shape::kRow *
        increment_group;

    // Tile
    if (!ScatterD) {
      byte_pointer_ += (params_.advance_tile * increment_cluster);
    }

    if (!ScatterD && !PermuteD) {
      store_byte_pointer_ += (params_.advance_tile * increment_cluster);
    }

    thread_start_row_ +=
        ThreadMap::Shape::kGroup *
        ThreadMap::Shape::kRow *
        ThreadMap::Shape::kCluster *
        ThreadMap::Shape::kTile *
        increment_cluster;

    return *this;
  }
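
The consistency property I am asking for can be checked on a reduced model. The sketch below is a hypothetical two-level (row, group) toy iterator, not the CUTLASS class; ToyIterator, its constants, and the helper function are made up for illustration. It applies the same guard in both operators and compares it += n against n applications of ++it.

```cpp
#include <cstdint>

// Hypothetical stand-ins for ThreadMap::Count::kRow and the params_ advances.
constexpr int kRowCount = 4;
constexpr std::int64_t kAdvanceRow = 64;
constexpr std::int64_t kAdvanceGroup = 512;

template <bool Scatter>
struct ToyIterator {
  int row_state = 0;             // models state_[0]
  int group_state = 0;           // models state_[1] (left unbounded in this toy)
  std::int64_t byte_offset = 0;  // models byte_pointer_

  // Single-step advance with the pointer update guarded by !Scatter.
  ToyIterator& operator++() {
    ++row_state;
    if (!Scatter) { byte_offset += kAdvanceRow; }
    if (row_state == kRowCount) {
      row_state = 0;
      ++group_state;
      if (!Scatter) { byte_offset += kAdvanceGroup; }
    }
    return *this;
  }

  // Multi-step advance using the same guard, mirroring the fix proposed above.
  ToyIterator& operator+=(int increment) {
    row_state += increment;
    int increment_row = row_state / kRowCount;  // number of row wrap-arounds
    row_state = row_state % kRowCount;
    if (!Scatter) { byte_offset += kAdvanceRow * increment; }

    group_state += increment_row;
    if (!Scatter) { byte_offset += kAdvanceGroup * increment_row; }
    return *this;
  }
};

// True when it += n leaves the iterator in the same state as n calls to ++it.
template <bool Scatter>
bool plus_equals_matches_increment(int n) {
  ToyIterator<Scatter> a, b;
  a += n;
  for (int i = 0; i < n; ++i) { ++b; }
  return a.row_state == b.row_state &&
         a.group_state == b.group_state &&
         a.byte_offset == b.byte_offset;
}
```

In this toy, the two operators agree for both Scatter = false and Scatter = true, and the scatter variant never moves byte_offset. Dropping the guards from operator+= alone would break the agreement in the scatter case, which is the inconsistency described above.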

Steps/Code to reproduce bug

I have not included a minimal reproducer yet, but I can provide one if needed.

Expected behavior

I would expect:

  1. load_with_byte_offset() to guard increment_group and increment_cluster with if (!ScatterD), consistent with the previous fix in the store path.
  2. operator+=() to follow the same ScatterD / PermuteD pointer-advance semantics as operator++().
