Skip to content

Commit 8a2e1e0

Browse files
Added implementation and test cases for MIRMetadata
1 parent 08c3ae8 commit 8a2e1e0

File tree

3 files changed

+148
-1
lines changed

3 files changed

+148
-1
lines changed

include/graphit/midend/mir.h

+29
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <iostream>
1212
#include <unordered_set>
1313
#include <graphit/midend/mir_visitor.h>
14+
#include <graphit/midend/mir_metadata.h>
1415
#include <graphit/midend/var.h>
1516
#include <assert.h>
1617
#include <graphit/midend/field_vector_property.h>
@@ -54,6 +55,8 @@ namespace graphit {
5455
return to<T>(cloneNode());
5556
}
5657

58+
// We use a single map to hold all metadata on the MIR Node
59+
std::unordered_map<std::string, std::shared_ptr<MIRMetadata>> metadata_map;
5760
protected:
5861
template<typename T = MIRNode>
5962
std::shared_ptr<T> self() {
@@ -68,6 +71,32 @@ namespace graphit {
6871
// as I slowly add in support for copy functionalities
6972
return nullptr;
7073
};
74+
public:
75+
// Functions to set and retrieve metadata of different types
76+
template<typename T>
77+
void setMetadata(std::string mdname, T val) {
78+
typename MIRMetadataImpl<T>::Ptr mdnode = std::make_shared<MIRMetadataImpl<T>>(val);
79+
metadata_map[mdname] = mdnode;
80+
}
81+
// This function is safe to be called even if the metadata with
82+
// the specified name doesn't exist
83+
template<typename T>
84+
bool hasMetadata(std::string mdname) {
85+
if (metadata_map.find(mdname) == metadata_map.end())
86+
return false;
87+
typename MIRMetadata::Ptr mdnode = metadata_map[mdname];
88+
if (!mdnode->isa<T>())
89+
return false;
90+
return true;
91+
}
92+
// This function should be called only after confirming that the
93+
// metadata with the given name exists
94+
template <typename T>
95+
T getMetadata(std::string mdname) {
96+
assert(hasMetadata<T>(mdname));
97+
typename MIRMetadata::Ptr mdnode = metadata_map[mdname];
98+
return mdnode->to<T>()->val;
99+
}
71100
};
72101

73102
struct Expr : public MIRNode {

include/graphit/midend/mir_metadata.h

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#ifndef MIR_METADATA_H
2+
#define MIR_METADATA_H
3+
4+
#include <memory>
5+
#include <cassert>
6+
namespace graphit {
7+
namespace mir {
8+
9+
template<typename T>
10+
class MIRMetadataImpl;
11+
12+
// The abstract class for the mir metadata
13+
// Different templated metadata types inherit from this type
14+
class MIRMetadata: public std::enable_shared_from_this<MIRMetadata> {
15+
public:
16+
typedef std::shared_ptr<MIRMetadata> Ptr;
17+
virtual ~MIRMetadata() = default;
18+
19+
20+
template <typename T>
21+
bool isa (void) {
22+
if(std::dynamic_pointer_cast<MIRMetadataImpl<T>>(shared_from_this()))
23+
return true;
24+
return false;
25+
}
26+
template <typename T>
27+
std::shared_ptr<MIRMetadataImpl<T>> to(void) {
28+
std::shared_ptr<MIRMetadataImpl<T>> ret = std::dynamic_pointer_cast<MIRMetadataImpl<T>>(shared_from_this());
29+
assert(ret != nullptr);
30+
return ret;
31+
}
32+
};
33+
34+
// Templated metadata class for each type
35+
template<typename T>
36+
class MIRMetadataImpl: public MIRMetadata {
37+
public:
38+
typedef std::shared_ptr<MIRMetadataImpl<T>> Ptr;
39+
T val;
40+
MIRMetadataImpl(T _val): val(_val) {
41+
}
42+
};
43+
44+
}
45+
}
46+
#endif

test/c++/midend_test.cpp

+73-1
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,76 @@ TEST_F(MidendTest, SimpleVertexSetDeclAllocWithMain) {
110110
"const vertices : vertexset{Vertex} = new vertexset{Vertex}(5);\n"
111111
"func main() print 4; end");
112112
EXPECT_EQ (0, basicTest(is));
113-
}
113+
}
114+
115+
// Test cases for the MIRMetadata API
116+
TEST_F(MidendTest, SimpleMetadataTest) {
117+
istringstream is("func main() print 4; end");
118+
EXPECT_EQ(0, basicTest(is));
119+
EXPECT_EQ(true, mir_context_->isFunction("main"));
120+
121+
mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main");
122+
123+
main_func->setMetadata<bool>("basic_boolean_md", true);
124+
main_func->setMetadata<int>("basic_int_md", 42);
125+
EXPECT_EQ(true, main_func->hasMetadata<bool>("basic_boolean_md"));
126+
EXPECT_EQ(true, main_func->getMetadata<bool>("basic_boolean_md"));
127+
128+
EXPECT_EQ(true, main_func->hasMetadata<int>("basic_int_md"));
129+
EXPECT_EQ(42, main_func->getMetadata<int>("basic_int_md"));
130+
131+
}
132+
TEST_F(MidendTest, SimpleMetadataTestNoExist) {
133+
istringstream is("func main() print 4; end");
134+
EXPECT_EQ(0, basicTest(is));
135+
EXPECT_EQ(true, mir_context_->isFunction("main"));
136+
137+
mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main");
138+
139+
main_func->setMetadata<int>("basic_int_md", 42);
140+
EXPECT_EQ(false, main_func->hasMetadata<int>("other_int_md"));
141+
EXPECT_EQ(false, main_func->hasMetadata<bool>("basic_int_md"));
142+
}
143+
144+
TEST_F(MidendTest, SimpleMetadataTestString) {
145+
istringstream is("func main() print 4; end");
146+
EXPECT_EQ(0, basicTest(is));
147+
EXPECT_EQ(true, mir_context_->isFunction("main"));
148+
149+
mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main");
150+
151+
main_func->setMetadata<std::string>("basic_str_md", "md value");
152+
EXPECT_EQ(true, main_func->hasMetadata<std::string>("basic_str_md"));
153+
EXPECT_EQ("md value", main_func->getMetadata<std::string>("basic_str_md"));
154+
}
155+
156+
TEST_F(MidendTest, SimpleMetadataTestMIRNodeAsMD) {
157+
istringstream is("const val:int = 42;\nfunc main() print val; end");
158+
EXPECT_EQ(0, basicTest(is));
159+
EXPECT_EQ(true, mir_context_->isFunction("main"));
160+
EXPECT_EQ(1, mir_context_->getConstants().size());
161+
162+
mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main");
163+
mir::VarDecl::Ptr decl = mir_context_->getConstants()[0];
164+
165+
main_func->setMetadata<mir::MIRNode::Ptr>("used_var_md", decl);
166+
167+
EXPECT_EQ(true, main_func->hasMetadata<mir::MIRNode::Ptr>("used_var_md"));
168+
mir::MIRNode::Ptr mdnode = main_func->getMetadata<mir::MIRNode::Ptr>("used_var_md");
169+
EXPECT_EQ(true, mir::isa<mir::VarDecl>(mdnode));
170+
}
171+
172+
TEST_F(MidendTest, SimpleMetadataTestMIRNodeVectorAsMD) {
173+
istringstream is("const val:int = 42;\nconst val2: int = 55;\nfunc main() print val + val2; end");
174+
EXPECT_EQ(0, basicTest(is));
175+
EXPECT_EQ(true, mir_context_->isFunction("main"));
176+
EXPECT_EQ(2, mir_context_->getConstants().size());
177+
178+
mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main");
179+
std::vector<mir::VarDecl::Ptr> decls = mir_context_->getConstants();
180+
181+
main_func->setMetadata<std::vector<mir::VarDecl::Ptr>>("used_vars_md", decls);
182+
183+
EXPECT_EQ(true, main_func->hasMetadata<std::vector<mir::VarDecl::Ptr>>("used_vars_md"));
184+
EXPECT_EQ(2, main_func->getMetadata<std::vector<mir::VarDecl::Ptr>>("used_vars_md").size());
185+
}

0 commit comments

Comments
 (0)