Skip to content

Commit f3a985c

Browse files
authored
[TRTLLM-10296][fix] Fix the potential misaligned access due to vectorized ld/st instructions in NVLinkOneSided A2A. (NVIDIA#10539)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
1 parent dbb858a commit f3a985c

File tree

3 files changed

+74
-56
lines changed

3 files changed

+74
-56
lines changed

cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ namespace torch_ext
3333
namespace moe_comm
3434
{
3535

36+
static constexpr size_t CACHELINE_ALIGNMENT = 128;
37+
3638
// TODO: Is Alignment necessary?
3739
// Helper function to align offset to specified byte boundary
3840
inline size_t alignOffset(size_t offset, size_t alignment)
@@ -46,7 +48,6 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
4648
// TODO: Use lambdas to encapsulate offset and alignment for each entry, which is less error prone and easier to
4749
// read.
4850
constexpr size_t SIZEOF_INT32 = 4;
49-
constexpr size_t CACHELINE_ALIGNMENT = 128;
5051

5152
MoeA2ADataOffsets offsets;
5253
size_t offset = 0;
@@ -203,29 +204,43 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
203204
TORCH_CHECK(payload.is_contiguous(), "All payloads must be contiguous");
204205
}
205206

206-
// Calculate buffer sizes for all payloads
207-
// Each payload buffer needs space for data from ALL ranks: epSize * maxTokensPerRank * elementsPerToken
208-
int64_t totalBytesNeeded = 0;
209-
std::vector<int64_t> payloadByteSizes;
207+
// Record the cacheline aligned start offset for each payload's recv buffer.
208+
// 1. We assume the base workspace ptr of each rank is aligned (checked in this OP)
209+
// 2. offsets[PAYLOAD_DATA_OFFSET_INDEX] is aligned (ensured in calculateOffsets)
210+
// 3. We align the currentOffset during update.
211+
// In this way, it is guaranteed that the recv buffer is (over-)aligned, sufficient for 128bit vectorized ld/st.
212+
210213
std::vector<int> payloadElementSizes;
211214
std::vector<int> payloadElementsPerToken;
215+
std::vector<size_t> payloadRecvBufferOffsets;
216+
217+
// Start offset for the first payload
218+
size_t currentOffset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
212219
for (auto const& payload : inputPayloads)
213220
{
214221
CHECK_CONTIGUOUS(payload);
215222
CHECK_TH_CUDA(payload);
216223
TORCH_CHECK(payload.dim() == 2, "payload must be a 2D tensor");
217224
TORCH_CHECK(
218225
payload.size(0) == localNumTokens, "payload must have the same first dimension as tokenSelectedExperts");
226+
// Unlike recv buffer for payloads, payload itself is not allocated by us and we cannot control its alignment.
227+
// We only make sure the payload start offset is 16-byte aligned, while the actual vectorized ld/st width is
228+
// dynamically determined based on bytes per token of this payload.
229+
TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
219230

220231
int elementsPerToken = static_cast<int>(payload.size(1));
221232
int elementSize = static_cast<int>(payload.dtype().itemsize());
222233
// Each payload buffer stores data from ALL ranks
223234
int64_t bytesPerPayload = epSize * runtimeMaxTokensPerRank * elementsPerToken * elementSize;
224235

225-
payloadByteSizes.push_back(bytesPerPayload);
226236
payloadElementSizes.push_back(elementSize);
227237
payloadElementsPerToken.push_back(elementsPerToken);
228-
totalBytesNeeded += bytesPerPayload;
238+
239+
payloadRecvBufferOffsets.push_back(currentOffset);
240+
241+
// Update offset and align to cacheline boundary for the next payload recv buffer.
242+
currentOffset += bytesPerPayload;
243+
currentOffset = alignOffset(currentOffset, CACHELINE_ALIGNMENT);
229244
}
230245

231246
CHECK_TH_CUDA(workspace);
@@ -236,16 +251,18 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
236251

237252
// Validate workspace size - must include space for auxiliary data + payloads
238253
int64_t sizePerRank = workspace.size(1);
239-
int64_t requiredSize = offsets[PAYLOAD_DATA_OFFSET_INDEX] + totalBytesNeeded;
254+
int64_t requiredSize = static_cast<int64_t>(currentOffset);
240255
TORCH_CHECK(sizePerRank >= requiredSize,
241256
"Workspace size per rank insufficient for dispatch. "
242257
"Need at least ",
243-
requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + ", totalBytesNeeded,
244-
" for payloads), but got ", sizePerRank);
258+
requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + payloads), but got ",
259+
sizePerRank);
245260

246261
// Get base workspace pointer
247262
uint8_t* workspacePtr = workspace.data_ptr<uint8_t>();
248263
uint8_t* rankWorkSpacePtr = workspacePtr + epRank * workspace.stride(0);
264+
TORCH_CHECK(reinterpret_cast<uintptr_t>(rankWorkSpacePtr) % CACHELINE_ALIGNMENT == 0,
265+
"rankWorkSpacePtr must be ", CACHELINE_ALIGNMENT, "-byte aligned");
249266

250267
// Setup payload descriptors for source data
251268
int num_payloads = static_cast<int>(inputPayloads.size());
@@ -288,13 +305,10 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
288305
params.completion_flags[target_rank]
289306
= reinterpret_cast<uint32_t*>(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]);
290307

291-
size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
292308
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
293309
{
294-
// Store pointer for current payload
295-
params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + offset;
296-
// Update offset for next payload
297-
offset += payloadByteSizes[payload_idx];
310+
// Store pointer for current payload using pre-calculated aligned offset
311+
params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + payloadRecvBufferOffsets[payload_idx];
298312
}
299313
}
300314

@@ -310,22 +324,17 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
310324

311325
// Create tensor views for the current rank's receive buffers only
312326
std::vector<torch::Tensor> recvTensors;
313-
size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
314327
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
315328
{
316329
auto const& payload = inputPayloads[payload_idx];
317-
// Create tensor view for this payload
318-
auto recvTensor = torch::from_blob(rankWorkSpacePtr + offset,
330+
// Create tensor view for this payload using pre-calculated aligned offset
331+
auto recvTensor = torch::from_blob(rankWorkSpacePtr + payloadRecvBufferOffsets[payload_idx],
319332
{epSize, runtimeMaxTokensPerRank, payloadElementsPerToken[payload_idx]}, payload.options());
320333
recvTensors.push_back(recvTensor);
321-
322-
// Update offset for next payload
323-
offset += payloadByteSizes[payload_idx];
324334
}
325335

326336
// Compute aligned offset after dispatch payloads for combine payload region
327-
constexpr size_t CACHELINE_ALIGNMENT = 128;
328-
int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(static_cast<size_t>(offset), CACHELINE_ALIGNMENT));
337+
int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(currentOffset, CACHELINE_ALIGNMENT));
329338

330339
return std::make_tuple(std::move(recvTensors), combinePayloadOffset);
331340
}
@@ -356,6 +365,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
356365
TORCH_CHECK(payload.size(0) == epSize, "payload first dimension must equal epSize");
357366
TORCH_CHECK(
358367
payload.size(1) == runtimeMaxTokensPerRank, "payload second dimension must equal runtimeMaxTokensPerRank");
368+
// We only make sure the payload start offset is 16-byte aligned, while the actual vectorized ld/st width is
369+
// dynamically determined based on bytes per token of this payload.
370+
TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
359371
int64_t elementsPerToken = payload.size(2);
360372
TORCH_CHECK(elementsPerToken > 0, "elementsPerToken must be positive");
361373
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
@@ -411,6 +423,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
411423
" for payload), but got ", sizePerRank);
412424

413425
// Create output tensor (local on current rank), no need for initialization
426+
// Typically, newly allocated GPU torch tensors are at least 16-byte aligned.
414427
torch::Tensor output = torch::empty({localNumTokens, elementsPerToken}, payload.options());
415428

416429
// Setup combine parameters

tensorrt_llm/_torch/distributed/moe_alltoall.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,25 +54,31 @@ def calculate_required_workspace_size(
5454
dtype: torch.dtype,
5555
extra_payload_bytes_per_token: int = 0) -> int:
5656
element_size = dtype.itemsize
57+
5758
# Auxiliary data size
58-
aux_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)
59+
workspace_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)
5960

6061
# Dispatch needs workspace for [ep_size, max_tokens] tokens,
61-
# but due to the variety of quantization recipes, we cannot know the exact size,
62-
# so we conservatively estimate assuming no quantization.
63-
payload_size_dispatch = ep_size * max_num_tokens * (
64-
hidden_size * element_size # (Unquantized) token hidden states
65-
+ top_k * 4 # token_selected_experts
66-
+ top_k * 4 # token_final_scales
67-
+ extra_payload_bytes_per_token # extra payload bytes per token
68-
)
62+
# but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
63+
# Meanwhile, we consider the alignment requirement as in moeA2ADispatchOp and moeA2ACombineOp.
64+
# (Unquantized) token hidden states
65+
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
66+
workspace_size = pad_up(workspace_size, 128)
67+
# token_selected_experts
68+
workspace_size += ep_size * max_num_tokens * top_k * 4
69+
workspace_size = pad_up(workspace_size, 128)
70+
# token_final_scales
71+
workspace_size += ep_size * max_num_tokens * top_k * 4
72+
workspace_size = pad_up(workspace_size, 128)
73+
# extra payload bytes per token
74+
workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
75+
workspace_size = pad_up(workspace_size, 128)
6976

7077
# Required workspace for combine [ep_size, max_tokens] tokens
71-
payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size
78+
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
79+
workspace_size = pad_up(workspace_size, 128)
7280

73-
# Pad to 128 bytes to ensure alignment. This matches the implementation of C++ torch OP code.
74-
return pad_up(aux_size, 128) + pad_up(
75-
payload_size_dispatch, 128) + pad_up(payload_size_combine, 128)
81+
return workspace_size
7682

7783
@classmethod
7884
def _init_constants(cls):

tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,32 +81,31 @@ def calculate_required_workspace_size(
8181
extra_payload_bytes_per_token: int = 0,
8282
) -> int:
8383
element_size = dtype.itemsize
84+
8485
# Auxiliary data size
85-
aux_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens)
86+
workspace_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens)
8687

8788
# Dispatch needs workspace for [ep_size, max_tokens] tokens,
88-
# but due to the variety of quantization recipes, we cannot know the exact size,
89-
# so we conservatively estimate assuming no quantization.
90-
payload_size_dispatch = (
91-
ep_size
92-
* max_num_tokens
93-
* (
94-
hidden_size * element_size # (Unquantized) token hidden states
95-
+ top_k * 4 # token_selected_experts
96-
+ top_k * 4 # token_final_scales
97-
+ extra_payload_bytes_per_token # extra payload bytes per token
98-
)
99-
)
89+
# but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
90+
# Meanwhile, we consider the alignment requirement as in moeA2ADispatchOp and moeA2ACombineOp.
91+
# (Unquantized) token hidden states
92+
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
93+
workspace_size = pad_up(workspace_size, 128)
94+
# token_selected_experts
95+
workspace_size += ep_size * max_num_tokens * top_k * 4
96+
workspace_size = pad_up(workspace_size, 128)
97+
# token_final_scales
98+
workspace_size += ep_size * max_num_tokens * top_k * 4
99+
workspace_size = pad_up(workspace_size, 128)
100+
# extra payload bytes per token
101+
workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
102+
workspace_size = pad_up(workspace_size, 128)
100103

101104
# Required workspace for combine [ep_size, max_tokens] tokens
102-
payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size
105+
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
106+
workspace_size = pad_up(workspace_size, 128)
103107

104-
# Pad to 128 bytes to ensure alignment. This matches the implementation of C++ torch OP code.
105-
return (
106-
pad_up(aux_size, 128)
107-
+ pad_up(payload_size_dispatch, 128)
108-
+ pad_up(payload_size_combine, 128)
109-
)
108+
return workspace_size
110109

111110
@classmethod
112111
def _init_constants(cls):

0 commit comments

Comments
 (0)