@@ -33,6 +33,8 @@ namespace torch_ext
 namespace moe_comm
 {
 
+static constexpr size_t CACHELINE_ALIGNMENT = 128;
+
 // TODO: Is alignment necessary?
 // Helper function to align an offset to the specified byte boundary
 inline size_t alignOffset(size_t offset, size_t alignment)
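The body of alignOffset falls outside this hunk. For a power-of-two alignment, the usual round-up implementation (a sketch of what it presumably does, not code confirmed by this diff) is:

    inline size_t alignOffset(size_t offset, size_t alignment)
    {
        // Round offset up to the next multiple of alignment (alignment must be a power of two).
        return (offset + alignment - 1) & ~(alignment - 1);
    }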
@@ -46,7 +48,6 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
     // TODO: Use lambdas to encapsulate offset and alignment for each entry, which is less error-prone and easier to
     // read.
     constexpr size_t SIZEOF_INT32 = 4;
-    constexpr size_t CACHELINE_ALIGNMENT = 128;
 
     MoeA2ADataOffsets offsets;
     size_t offset = 0;
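The TODO above asks for a lambda that bundles the align-then-advance bookkeeping. A minimal sketch of that idea (the `reserve` helper and the usage line are hypothetical, not part of this commit):

    size_t offset = 0;
    auto reserve = [&offset](size_t bytes, size_t alignment)
    {
        offset = alignOffset(offset, alignment);
        size_t start = offset;
        offset += bytes;
        return start;
    };
    // Hypothetical usage: offsets.someEntry = reserve(epSize * SIZEOF_INT32, CACHELINE_ALIGNMENT);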
@@ -203,29 +204,43 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
         TORCH_CHECK(payload.is_contiguous(), "All payloads must be contiguous");
     }
 
-    // Calculate buffer sizes for all payloads
-    // Each payload buffer needs space for data from ALL ranks: epSize * maxTokensPerRank * elementsPerToken
-    int64_t totalBytesNeeded = 0;
-    std::vector<int64_t> payloadByteSizes;
+    // Record the cacheline-aligned start offset of each payload's recv buffer.
+    // 1. We assume the base workspace pointer of each rank is aligned (checked in this op).
+    // 2. offsets[PAYLOAD_DATA_OFFSET_INDEX] is aligned (ensured in calculateOffsets).
+    // 3. currentOffset is re-aligned after each update.
+    // Together these guarantee that every recv buffer is (over-)aligned, sufficient for 128-bit vectorized ld/st.
+
     std::vector<int> payloadElementSizes;
     std::vector<int> payloadElementsPerToken;
+    std::vector<size_t> payloadRecvBufferOffsets;
+
+    // Start offset of the first payload's recv buffer
+    size_t currentOffset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
     for (auto const& payload : inputPayloads)
     {
         CHECK_CONTIGUOUS(payload);
         CHECK_TH_CUDA(payload);
         TORCH_CHECK(payload.dim() == 2, "payload must be a 2D tensor");
         TORCH_CHECK(
             payload.size(0) == localNumTokens, "payload must have the same first dimension as tokenSelectedExperts");
+        // Unlike the recv buffers for payloads, the payload itself is not allocated by us, so we cannot control its
+        // alignment. We only require that its start address is 16-byte aligned; the actual vectorized ld/st width is
+        // determined dynamically from this payload's bytes per token.
+        TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
 
         int elementsPerToken = static_cast<int>(payload.size(1));
         int elementSize = static_cast<int>(payload.dtype().itemsize());
         // Each payload buffer stores data from ALL ranks
         int64_t bytesPerPayload = epSize * runtimeMaxTokensPerRank * elementsPerToken * elementSize;
 
-        payloadByteSizes.push_back(bytesPerPayload);
         payloadElementSizes.push_back(elementSize);
         payloadElementsPerToken.push_back(elementsPerToken);
-        totalBytesNeeded += bytesPerPayload;
+
+        payloadRecvBufferOffsets.push_back(currentOffset);
+
+        // Update the offset and align it to a cacheline boundary for the next payload's recv buffer.
+        currentOffset += bytesPerPayload;
+        currentOffset = alignOffset(currentOffset, CACHELINE_ALIGNMENT);
     }
 
     CHECK_TH_CUDA(workspace);
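To make the align-then-advance loop concrete, here is a worked example with made-up sizes (epSize = 4, runtimeMaxTokensPerRank = 3; none of these numbers come from the commit):

    // payload 0: 25 fp32 elements/token -> 4 * 3 * 25 * 4 = 1200 bytes.
    //   Starts at the aligned payload-data base; 1200 rounds up to 1280 (next multiple of 128).
    // payload 1: 64 bf16 elements/token -> 4 * 3 * 64 * 2 = 1536 bytes.
    //   Starts 1280 bytes past the base; 1280 + 1536 = 2816 is already a cacheline multiple.
    // currentOffset, and hence requiredSize, ends up as base + 2816 bytes.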
@@ -236,16 +251,18 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
 
     // Validate workspace size - must include space for auxiliary data + payloads
     int64_t sizePerRank = workspace.size(1);
-    int64_t requiredSize = offsets[PAYLOAD_DATA_OFFSET_INDEX] + totalBytesNeeded;
+    int64_t requiredSize = static_cast<int64_t>(currentOffset);
     TORCH_CHECK(sizePerRank >= requiredSize,
         "Workspace size per rank insufficient for dispatch. "
         "Need at least ",
-        requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + ", totalBytesNeeded,
-        " for payloads), but got ", sizePerRank);
+        requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data, the rest for payloads), "
+        "but got ", sizePerRank);
 
     // Get base workspace pointer
     uint8_t* workspacePtr = workspace.data_ptr<uint8_t>();
     uint8_t* rankWorkSpacePtr = workspacePtr + epRank * workspace.stride(0);
+    TORCH_CHECK(reinterpret_cast<uintptr_t>(rankWorkSpacePtr) % CACHELINE_ALIGNMENT == 0,
+        "rankWorkSpacePtr must be ", CACHELINE_ALIGNMENT, "-byte aligned");
 
     // Setup payload descriptors for source data
     int num_payloads = static_cast<int>(inputPayloads.size());
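One detail of the alignment check above: TORCH_CHECK does not apply printf-style formatting; everything after the condition is stream-concatenated into the error message, so a "%d" placeholder would be printed literally. The value is therefore passed as its own argument:

    // Renders as "rankWorkSpacePtr must be 128-byte aligned" on failure.
    TORCH_CHECK(isAligned, "rankWorkSpacePtr must be ", CACHELINE_ALIGNMENT, "-byte aligned");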
@@ -288,13 +305,10 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
         params.completion_flags[target_rank]
             = reinterpret_cast<uint32_t*>(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]);
 
-        size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
         for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
         {
-            // Store pointer for current payload
-            params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + offset;
-            // Update offset for next payload
-            offset += payloadByteSizes[payload_idx];
+            // Store the pointer for the current payload using its pre-calculated aligned offset
+            params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + payloadRecvBufferOffsets[payload_idx];
         }
     }
 
@@ -310,22 +324,17 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
 
     // Create tensor views for the current rank's receive buffers only
     std::vector<torch::Tensor> recvTensors;
-    size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
     for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
     {
         auto const& payload = inputPayloads[payload_idx];
-        // Create tensor view for this payload
-        auto recvTensor = torch::from_blob(rankWorkSpacePtr + offset,
+        // Create a tensor view for this payload using its pre-calculated aligned offset
+        auto recvTensor = torch::from_blob(rankWorkSpacePtr + payloadRecvBufferOffsets[payload_idx],
             {epSize, runtimeMaxTokensPerRank, payloadElementsPerToken[payload_idx]}, payload.options());
         recvTensors.push_back(recvTensor);
-
-        // Update offset for next payload
-        offset += payloadByteSizes[payload_idx];
     }
 
     // Compute aligned offset after dispatch payloads for combine payload region
-    constexpr size_t CACHELINE_ALIGNMENT = 128;
-    int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(static_cast<size_t>(offset), CACHELINE_ALIGNMENT));
+    int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(currentOffset, CACHELINE_ALIGNMENT));
 
     return std::make_tuple(std::move(recvTensors), combinePayloadOffset);
 }
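A caveat for callers of this op: torch::from_blob produces non-owning views, so the returned recvTensors alias the workspace memory directly and are only valid while the workspace tensor stays alive. A standalone illustration of that aliasing behavior (assumed semantics, not code from this commit):

    #include <torch/torch.h>

    float staging[8] = {};
    // Non-owning view over existing memory; no copy is made.
    auto view = torch::from_blob(staging, {2, 4}, torch::kFloat32);
    view.fill_(1.0f); // staging now holds all ones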
@@ -356,6 +365,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
     TORCH_CHECK(payload.size(0) == epSize, "payload first dimension must equal epSize");
     TORCH_CHECK(
         payload.size(1) == runtimeMaxTokensPerRank, "payload second dimension must equal runtimeMaxTokensPerRank");
+    // We only require that the payload's start address is 16-byte aligned; the actual vectorized ld/st width is
+    // determined dynamically from this payload's bytes per token.
+    TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
     int64_t elementsPerToken = payload.size(2);
     TORCH_CHECK(elementsPerToken > 0, "elementsPerToken must be positive");
     TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
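Both alignment comments defer to a vector width "determined dynamically" from the bytes per token, but the kernel side is not part of this diff. A plausible sketch of such a selection, purely as an assumption about what the kernel does:

    // Pick the widest vector (16, 8, 4, or 2 bytes) that evenly divides the per-token byte count.
    inline int selectVectorWidthBytes(int bytesPerToken)
    {
        for (int width = 16; width > 1; width /= 2)
        {
            if (bytesPerToken % width == 0)
            {
                return width;
            }
        }
        return 1;
    }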
@@ -411,6 +423,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
         " for payload), but got ", sizePerRank);
 
     // Create output tensor (local on current rank), no need for initialization
+    // Typically, newly allocated GPU torch tensors are at least 16-byte aligned.
     torch::Tensor output = torch::empty({localNumTokens, elementsPerToken}, payload.options());
 
     // Setup combine parameters