@@ -18,18 +18,17 @@ limitations under the License.
1818#include < cstdint>
1919#include < vector>
2020
21- #include " absl/status/status.h"
22- #include " absl/status/statusor.h"
23- #include " absl/strings/str_cat.h"
24- #include " absl/strings/string_view.h"
2521#include " llvm/ADT/STLExtras.h"
2622#include " llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
2723#include " llvm/Support/Casting.h"
24+ #include " llvm/Support/FormatVariadic.h"
25+ #include " llvm/Support/LogicalResult.h"
2826#include " mlir/Conversion/LLVMCommon/MemRefBuilder.h"
2927#include " mlir/Dialect/Arith/IR/Arith.h"
3028#include " mlir/Dialect/Func/IR/FuncOps.h"
3129#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
3230#include " mlir/Dialect/MemRef/IR/MemRef.h"
31+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
3332#include " mlir/Dialect/SCF/Utils/Utils.h"
3433#include " mlir/IR/Builders.h"
3534#include " mlir/IR/BuiltinAttributes.h"
@@ -44,6 +43,12 @@ limitations under the License.
4443#include " mlir/IR/Value.h"
4544#include " mlir/IR/ValueRange.h"
4645#include " mlir/Support/LLVM.h"
46+ #include " absl/algorithm/container.h"
47+ #include " absl/status/status.h"
48+ #include " absl/status/statusor.h"
49+ #include " absl/strings/str_cat.h"
50+ #include " absl/strings/string_view.h"
51+ #include " mlir/include/mlir/IR/Diagnostics.h"
4752#include " tsl/platform/statusor.h"
4853
4954// Generated definitions.
@@ -232,11 +237,89 @@ void DeclareRuntimeFunctions(mlir::OpBuilder& builder) {
232237 .setVisibility (mlir::func::FuncOp::Visibility::Private);
233238}
234239
240+ bool IsContiguous (mlir::MemRefType type) {
241+ return type.getLayout ().isIdentity () ||
242+ (type.hasStaticShape () && type.getNumElements () > 0 &&
243+ mlir::memref::isStaticShapeAndContiguousRowMajor (type));
244+ }
245+
246+ namespace {
247+ llvm::LogicalResult VerifyCommonLoadStoreOp (
248+ mlir::Location loc, mlir::MemRefType gmem_type, absl::string_view gmem_name,
249+ mlir::MemRefType smem_type, absl::string_view smem_name,
250+ mlir::ArrayRef<int64_t > slice_lengths, int num_indices) {
251+ auto error = [loc](auto ... params) {
252+ return emitError (loc, llvm::formatv (params...));
253+ };
254+
255+ if (!IsContiguous (smem_type)) {
256+ return error (" The `{0}` memref must be contiguous." , smem_name);
257+ }
258+ if (gmem_type.getElementType () != smem_type.getElementType ()) {
259+ return error (
260+ " The `source` and `destination` memrefs must have the same element "
261+ " type." );
262+ }
263+ if (absl::c_any_of (slice_lengths, [](int64_t s) { return s < -1 ; })) {
264+ return error (
265+ " The `slice_lengths` attribute must not contain values less than -1." );
266+ }
267+ if (gmem_type.getRank () !=
268+ smem_type.getRank () + absl::c_count (slice_lengths, -1 )) {
269+ return error (
270+ " The rank of the `{0}` must be equal to the rank of the "
271+ " `{1}` plus the number of collapsed dimensions as indicated "
272+ " by -1 values in `slice_lengths`." ,
273+ gmem_name, smem_name);
274+ }
275+ if (num_indices != gmem_type.getRank ()) {
276+ return error (" The size of `indices` must be equal to the rank of `{0}`." ,
277+ gmem_name);
278+ }
279+ if (slice_lengths.size () != gmem_type.getRank ()) {
280+ return error (
281+ " The size of `slice_lengths` must be equal to the rank of `{0}`." ,
282+ gmem_name);
283+ }
284+ return llvm::success ();
285+ }
286+ } // namespace
287+
288+ llvm::LogicalResult AsyncLoadOp::verify () {
289+ auto r = VerifyCommonLoadStoreOp (getLoc (), getSource ().getType (), " source" ,
290+ getDestination ().getType (), " destination" ,
291+ getSliceLengths (), getIndices ().size ());
292+ if (failed (r)) {
293+ return r;
294+ }
295+
296+ for (int i = 0 ; i < getCollective ().size (); ++i) {
297+ for (int k = i + 1 ; k < getCollective ().size (); ++k)
298+ if (getCollective ()[i] == getCollective ()[k]) {
299+ return emitError (
300+ " The `collective` attribute must not contain duplicate "
301+ " dimensions." );
302+ }
303+ }
304+
305+ return llvm::success ();
306+ }
307+
308+ llvm::LogicalResult AsyncStoreOp::verify () {
309+ return VerifyCommonLoadStoreOp (getLoc (), getDestination ().getType (),
310+ " destination" , getSource ().getType (), " source" ,
311+ getSliceLengths (), getIndices ().size ());
312+ }
313+
235314void MosaicGPUDialect::initialize () {
236315 addTypes<
237316#define GET_TYPEDEF_LIST
238317#include " jaxlib/mosaic/dialect/gpu/mosaic_gpu_types.cc.inc"
239318 >();
319+ addAttributes<
320+ #define GET_ATTRDEF_LIST
321+ #include " jaxlib/mosaic/dialect/gpu/mosaic_gpu_attrdefs.cc.inc"
322+ >();
240323 addOperations<
241324#define GET_OP_LIST
242325#include " jaxlib/mosaic/dialect/gpu/mosaic_gpu_ops.cc.inc"
0 commit comments