@@ -43,9 +43,8 @@ using namespace mlir;
4343
4444nlohmann::json readPropertiesFromFile (const std::string &filename) {
4545 std::ifstream file (filename);
46- if (!file.is_open ()) {
47- throw std::runtime_error (" Failed to open file: " + filename);
48- }
46+ if (!file.is_open ())
47+ throw std::runtime_error (" failed to open file: '" + filename + " '" );
4948
5049 nlohmann::json properties;
5150 file >> properties;
@@ -60,6 +59,10 @@ struct PropertyEntry {
6059};
6160
6261namespace {
62+ /* *
63+ * @brief Inserts the true data properties of each matrix-typed intermediate result (with certain exceptions) recorded
64+ * in a previous run of DAPHNE into the IR by inserting `mlir::daphne::CastOp`s that add these data properties.
65+ */
6366struct InsertPropertiesPass : public impl ::InsertPropertiesPassBase<InsertPropertiesPass> {
6467 InsertPropertiesPass () = default ;
6568
@@ -89,7 +92,7 @@ void InsertPropertiesPass::runOnOperation() {
8992
9093 auto insertRecordedProperties = [&](Operation *op) {
9194 size_t numResults = op->getNumResults ();
92- for (unsigned i = 0 ; i < numResults && propertyIndex < properties.size (); ++i) {
95+ for (size_t i = 0 ; i < numResults && propertyIndex < properties.size (); ++i) {
9396 Value res = op->getResult (i);
9497 nlohmann::json &prop = properties[propertyIndex].properties ;
9598 auto it = prop.begin ();
@@ -98,49 +101,42 @@ void InsertPropertiesPass::runOnOperation() {
98101 const nlohmann::json &value = it.value ();
99102 if (key == " sparsity" ) {
100103 if (value.is_null ()) {
101- llvm::errs () << " Error : 'sparsity' is null for property index " << propertyIndex << " \n " ;
104+ llvm::errs () << " error : 'sparsity' is null for property index " << propertyIndex << " \n " ;
102105 ++it;
103106 continue ;
104107 } else if (!value.is_number ()) {
105- llvm::errs () << " Error : 'sparsity' is not a number for property index " << propertyIndex
108+ llvm::errs () << " error : 'sparsity' is not a number for property index " << propertyIndex
106109 << " \n " ;
107110 ++it;
108111 continue ;
109112 }
110113
111- if (res.getType ().isa <daphne::MatrixType>()) {
112- auto mt = res.getType ().dyn_cast <daphne::MatrixType>();
114+ if (auto mt = res.getType ().dyn_cast <daphne::MatrixType>()) {
113115 double sparsity = value.get <double >();
114- if (mt) {
115- if ((llvm::isa<scf::ForOp>(op) || llvm::isa<scf::WhileOp>(op) ||
116- llvm::isa<scf::IfOp>(op))) {
117- builder.setInsertionPointAfter (op);
118- builder.create <daphne::CastOp>(op->getLoc (), mt.withSparsity (sparsity), res);
119- }
120-
121- else {
122- for (auto &use : res.getUses ()) {
123- Operation *userOp = use.getOwner ();
124- if (isa<scf::ForOp>(userOp) || isa<scf::IfOp>(userOp) ||
125- isa<scf::WhileOp>(userOp)) {
126- auto key = std::make_pair (res, userOp);
127-
128- Value castOpValue;
129- if (castOpMap.count (key)) {
130- castOpValue = castOpMap[key];
131- } else {
132- builder.setInsertionPoint (userOp);
133- castOpValue = builder.create <daphne::CastOp>(op->getLoc (), mt, res);
134- castOpMap[key] = castOpValue;
135- }
136-
137- userOp->setOperand (use.getOperandNumber (), castOpValue);
116+ if ((llvm::isa<scf::ForOp>(op) || llvm::isa<scf::WhileOp>(op) || llvm::isa<scf::IfOp>(op))) {
117+ builder.setInsertionPointAfter (op);
118+ builder.create <daphne::CastOp>(op->getLoc (), mt.withSparsity (sparsity), res);
119+ } else {
120+ for (auto &use : res.getUses ()) {
121+ Operation *userOp = use.getOwner ();
122+ if (isa<scf::ForOp>(userOp) || isa<scf::IfOp>(userOp) || isa<scf::WhileOp>(userOp)) {
123+ auto key = std::make_pair (res, userOp);
124+
125+ Value castOpValue;
126+ if (castOpMap.count (key)) {
127+ castOpValue = castOpMap[key];
128+ } else {
129+ builder.setInsertionPoint (userOp);
130+ castOpValue = builder.create <daphne::CastOp>(op->getLoc (), mt, res);
131+ castOpMap[key] = castOpValue;
138132 }
133+
134+ userOp->setOperand (use.getOperandNumber (), castOpValue);
139135 }
140136 }
141-
142- ++propertyIndex;
143137 }
138+
139+ ++propertyIndex;
144140 }
145141 }
146142 ++it;
@@ -152,40 +148,33 @@ void InsertPropertiesPass::runOnOperation() {
152148 if (propertyIndex >= properties.size ())
153149 return WalkResult::advance ();
154150
155- // Skip specific ops that should not be processed
151+ // Skip specific ops that should not be processed.
156152 if (isa<daphne::RecordPropertiesOp>(op) || op->hasAttr (" daphne.value_ids" ))
157153 return WalkResult::advance ();
158-
159- if (auto castOp = dyn_cast<daphne::CastOp>(op)) {
160- if (castOp.isRemovePropertyCast ()) {
154+ if (auto castOp = dyn_cast<daphne::CastOp>(op))
155+ if (castOp.isRemovePropertyCast ())
161156 return WalkResult::advance ();
162- }
163- }
164157
165- // Handle loops (scf.for and scf.while) and If blocks as black boxes
158+ // Handle loops (scf.for and scf.while) and if- blocks as black boxes.
166159 if (isa<scf::ForOp>(op) || isa<scf::WhileOp>(op) || isa<scf::IfOp>(op)) {
167160 insertRecordedProperties (op);
168161 return WalkResult::skip ();
169162 }
170163
164+ // Check if this is the @main function or a UDF.
171165 if (auto funcOp = llvm::dyn_cast<func::FuncOp>(op)) {
172- // Check if this is the @main function or a UDF
173- if (funcOp.getName () == " main" ) {
166+ if (funcOp.getName () == " main" )
174167 return WalkResult::advance ();
175- } else {
176- return WalkResult::skip ();
177- }
168+ return WalkResult::skip ();
178169 }
179170
180- // Process all other operations that output matrix types
171+ // Process other operations with matrix-typed results.
181172 insertRecordedProperties (op);
182173 return WalkResult::advance ();
183174 });
184175
185- if (propertyIndex < properties.size ()) {
186- llvm::errs () << " Warning: Not all properties were applied."
187- << " \n " ;
188- }
176+ if (propertyIndex < properties.size ())
177+ llvm::errs () << " warning: not all properties were applied\n " ;
189178}
190179
191180std::unique_ptr<OperationPass<func::FuncOp>> mlir::daphne::createInsertPropertiesPass (std::string propertiesFilePath) {
0 commit comments