@@ -620,6 +620,95 @@ class CIRYieldOpLowering
620
620
}
621
621
};
622
622
623
+ class CIRGlobalOpLowering
624
+ : public mlir::OpConversionPattern<mlir::cir::GlobalOp> {
625
+ public:
626
+ using OpConversionPattern<mlir::cir::GlobalOp>::OpConversionPattern;
627
+ mlir::LogicalResult
628
+ matchAndRewrite (mlir::cir::GlobalOp op, OpAdaptor adaptor,
629
+ mlir::ConversionPatternRewriter &rewriter) const override {
630
+ auto moduleOp = op->getParentOfType <mlir::ModuleOp>();
631
+ if (!moduleOp)
632
+ return mlir::failure ();
633
+
634
+ mlir::OpBuilder b (moduleOp.getContext ());
635
+
636
+ const auto CIRSymType = op.getSymType ();
637
+ auto convertedType = getTypeConverter ()->convertType (CIRSymType);
638
+ if (!convertedType)
639
+ return mlir::failure ();
640
+ auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
641
+ if (!memrefType)
642
+ memrefType = mlir::MemRefType::get ({}, convertedType);
643
+ // Add an optional alignment to the global memref.
644
+ mlir::IntegerAttr memrefAlignment =
645
+ op.getAlignment ()
646
+ ? mlir::IntegerAttr::get (b.getI64Type (), op.getAlignment ().value ())
647
+ : mlir::IntegerAttr ();
648
+ // Add an optional initial value to the global memref.
649
+ mlir::Attribute initialValue = mlir::Attribute ();
650
+ std::optional<mlir::Attribute> init = op.getInitialValue ();
651
+ if (init.has_value ()) {
652
+ if (auto constArr = init.value ().dyn_cast <mlir::cir::ZeroAttr>()) {
653
+ if (memrefType.getShape ().size ()) {
654
+ auto rtt = mlir::RankedTensorType::get (memrefType.getShape (),
655
+ memrefType.getElementType ());
656
+ initialValue = mlir::DenseIntElementsAttr::get (rtt, 0 );
657
+ } else {
658
+ auto rtt = mlir::RankedTensorType::get ({}, convertedType);
659
+ initialValue = mlir::DenseIntElementsAttr::get (rtt, 0 );
660
+ }
661
+ } else if (auto intAttr = init.value ().dyn_cast <mlir::cir::IntAttr>()) {
662
+ auto rtt = mlir::RankedTensorType::get ({}, convertedType);
663
+ initialValue = mlir::DenseIntElementsAttr::get (rtt, intAttr.getValue ());
664
+ } else if (auto fltAttr = init.value ().dyn_cast <mlir::cir::FPAttr>()) {
665
+ auto rtt = mlir::RankedTensorType::get ({}, convertedType);
666
+ initialValue = mlir::DenseFPElementsAttr::get (rtt, fltAttr.getValue ());
667
+ } else if (auto boolAttr = init.value ().dyn_cast <mlir::cir::BoolAttr>()) {
668
+ auto rtt = mlir::RankedTensorType::get ({}, convertedType);
669
+ initialValue =
670
+ mlir::DenseIntElementsAttr::get (rtt, (char )boolAttr.getValue ());
671
+ } else
672
+ llvm_unreachable (
673
+ " GlobalOp lowering with initial value is not fully supported yet" );
674
+ }
675
+
676
+ // Add symbol visibility
677
+ std::string sym_visibility = op.isPrivate () ? " private" : " public" ;
678
+
679
+ rewriter.replaceOpWithNewOp <mlir::memref::GlobalOp>(
680
+ op, b.getStringAttr (op.getSymName ()),
681
+ /* sym_visibility=*/ b.getStringAttr (sym_visibility),
682
+ /* type=*/ memrefType, initialValue,
683
+ /* constant=*/ op.getConstant (),
684
+ /* alignment=*/ memrefAlignment);
685
+
686
+ return mlir::success ();
687
+ }
688
+ };
689
+
690
+ class CIRGetGlobalOpLowering
691
+ : public mlir::OpConversionPattern<mlir::cir::GetGlobalOp> {
692
+ public:
693
+ using OpConversionPattern<mlir::cir::GetGlobalOp>::OpConversionPattern;
694
+
695
+ mlir::LogicalResult
696
+ matchAndRewrite (mlir::cir::GetGlobalOp op, OpAdaptor adaptor,
697
+ mlir::ConversionPatternRewriter &rewriter) const override {
698
+ // FIXME(cir): Premature DCE to avoid lowering stuff we're not using.
699
+ // CIRGen should mitigate this and not emit the get_global.
700
+ if (op->getUses ().empty ()) {
701
+ rewriter.eraseOp (op);
702
+ return mlir::success ();
703
+ }
704
+
705
+ auto type = getTypeConverter ()->convertType (op.getType ());
706
+ auto symbol = op.getName ();
707
+ rewriter.replaceOpWithNewOp <mlir::memref::GetGlobalOp>(op, type, symbol);
708
+ return mlir::success ();
709
+ }
710
+ };
711
+
623
712
void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
624
713
mlir::TypeConverter &converter) {
625
714
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
@@ -628,8 +717,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
628
717
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
629
718
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
630
719
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
631
- CIRYieldOpLowering, CIRCosOpLowering>(converter ,
632
- patterns.getContext ());
720
+ CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering ,
721
+ CIRGetGlobalOpLowering>(converter, patterns.getContext ());
633
722
}
634
723
635
724
static mlir::TypeConverter prepareTypeConverter () {
@@ -639,6 +728,8 @@ static mlir::TypeConverter prepareTypeConverter() {
639
728
// FIXME: The pointee type might not be converted (e.g. struct)
640
729
if (!ty)
641
730
return nullptr ;
731
+ if (isa<mlir::cir::ArrayType>(type.getPointee ()))
732
+ return ty;
642
733
return mlir::MemRefType::get ({}, ty);
643
734
});
644
735
converter.addConversion (
@@ -669,8 +760,17 @@ static mlir::TypeConverter prepareTypeConverter() {
669
760
return converter.convertType (type.getUnderlying ());
670
761
});
671
762
converter.addConversion([&](mlir::cir::ArrayType type) -> mlir::Type {
  // Collapse arbitrarily nested cir.array levels into one shaped memref:
  // each array level contributes a single dimension, and the walk stops at
  // the innermost non-array element type.
  SmallVector<int64_t> dims;
  mlir::Type inner = type;
  for (auto arrTy = dyn_cast<mlir::cir::ArrayType>(inner); arrTy;
       arrTy = dyn_cast<mlir::cir::ArrayType>(inner)) {
    dims.push_back(arrTy.getSize());
    inner = arrTy.getEltType();
  }
  auto eltTy = converter.convertType(inner);
  // FIXME: The element type might not be converted (e.g. struct)
  if (!eltTy)
    return nullptr;
  return mlir::MemRefType::get(dims, eltTy);
});
675
775
676
776
return converter;
0 commit comments