@@ -126,6 +126,16 @@ enum MemRangeAttribute : uint32_t {
126126constexpr int CpuDeviceId = static_cast <int >(-1 );
127127constexpr int InvalidDeviceId = static_cast <int >(-2 );
128128
129+ // Max scratch size is device dependent.
130+ constexpr size_t kWave32 = 32 ;
131+ constexpr size_t kWave64 = 64 ;
132+ constexpr size_t kScratchBits12X = 18 ;
133+ constexpr size_t kScratchBits9X = 15 ;
134+ constexpr size_t kCompilerRequired = 64 ;
135+ constexpr size_t kMaxStackSize12X = (((1 << kScratchBits12X ) - 1 ) * 256 / kWave32 ) - kCompilerRequired ;
136+ constexpr size_t kMaxStackSize11X = (((1 << kScratchBits9X ) - 1 ) * 256 / kWave32 ) - kCompilerRequired ;
137+ constexpr size_t kMaxStackSize9X = (((1 << kScratchBits9X ) - 1 ) * 256 / kWave64 ) - kCompilerRequired ;
138+
129139enum class ExternalSemaphoreHandleType : uint32_t {
130140 OpaqueFd = 1 , // Handle is an opaque file descriptor
131141 OpaqueWin32 = 2 , // Handle is an opaque shared NT handle
@@ -1653,11 +1663,9 @@ class Device : public RuntimeObject {
16531663 static constexpr size_t kMGInfoSizePerDevice = kMGSyncDataSize + sizeof (MGSyncInfo);
16541664 static constexpr size_t kSGInfoSize = kMGSyncDataSize ;
16551665
1656- // Amount of space used by each wave is in units of 256 dwords.
1657- // As per COMPUTE_TMPRING_SIZE.WAVE_SIZE 24:12
1658- // The field size supports a range of 0->(2M-256) dwords per wave64.
1659- // Per lane this works out to 131056 bytes or 128K - 16
1660- static constexpr size_t kMaxStackSize = ((128 * Ki) - 16 );
1666+ // Max Scratch size is based on ISA and thus per device.
1667+ // Def value is as per GFX9 being the least among supported devices.
1668+ size_t maxStackSize_ = kMaxStackSize9X ;
16611669
16621670 typedef std::list<CommandQueue*> CommandQueues;
16631671
@@ -2132,6 +2140,9 @@ class Device : public RuntimeObject {
21322140 return nullptr ;
21332141 }
21342142
2143+ // ! Returns stack size set for the device
2144+ size_t MaxStackSize () const { return maxStackSize_; }
2145+
21352146#if defined(__clang__)
21362147#if __has_feature(address_sanitizer)
21372148 virtual device::UriLocator* createUriLocator () const = 0;
0 commit comments