Skip to content

Commit be8c245

Browse files
authored
[RTG] Add map type and attribute (#9706)
1 parent c625f62 commit be8c245

File tree

10 files changed

+396
-3
lines changed

10 files changed

+396
-3
lines changed

include/circt-c/Dialect/RTG.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,18 @@ MLIR_CAPI_EXPORTED bool rtgTypeIsAArray(MlirType type);
9999
/// Returns the element type of the RTG array.
100100
MLIR_CAPI_EXPORTED MlirType rtgArrayTypeGetElementType(MlirType type);
101101

102+
/// If the type is an RTG map.
103+
MLIR_CAPI_EXPORTED bool rtgTypeIsAMap(MlirType type);
104+
105+
/// Creates an RTG map type in the context.
106+
MLIR_CAPI_EXPORTED MlirType rtgMapTypeGet(MlirType keyType, MlirType valueType);
107+
108+
/// Return the key type of the RTG map.
109+
MLIR_CAPI_EXPORTED MlirType rtgMapTypeGetKeyType(MlirType type);
110+
111+
/// Return the value type of the RTG map.
112+
MLIR_CAPI_EXPORTED MlirType rtgMapTypeGetValueType(MlirType type);
113+
102114
/// Creates an RTG tuple type in the context.
103115
MLIR_CAPI_EXPORTED MlirType rtgTupleTypeGet(MlirContext ctxt,
104116
intptr_t numFields,
@@ -218,6 +230,21 @@ MLIR_CAPI_EXPORTED MlirAttribute rtgLabelAttrGet(MlirContext ctx,
218230
/// Returns the name of the RTG label attribute.
219231
MLIR_CAPI_EXPORTED MlirStringRef rtgLabelAttrGetName(MlirAttribute attr);
220232

233+
/// Checks if the attribute is an RTG map attribute.
234+
MLIR_CAPI_EXPORTED bool rtgAttrIsAMap(MlirAttribute attr);
235+
236+
/// Creates an RTG map attribute in the context with the given entries.
237+
MLIR_CAPI_EXPORTED MlirAttribute rtgMapAttrGet(MlirContext ctx,
238+
MlirType mapType,
239+
intptr_t numEntries,
240+
MlirAttribute const *keys,
241+
MlirAttribute const *values);
242+
243+
/// Looks up the value associated with the given key in the RTG map attribute.
244+
/// Returns a null attribute if the key is not found.
245+
MLIR_CAPI_EXPORTED MlirAttribute rtgMapAttrLookup(MlirAttribute attr,
246+
MlirAttribute key);
247+
221248
#ifdef __cplusplus
222249
}
223250
#endif

include/circt/Dialect/RTG/IR/RTGAttributes.td

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,47 @@ def SetAttr : RTGAttrDef<"Set", [
8585
let genVerifyDecl = true;
8686
}
8787

88+
// The accessor type is a pointer instead of a reference because references
89+
// cannot be reassigned (in the attribute 'construct' function).
90+
class DenseMapParameter<string keyOf, string valueOf, string desc = "">
91+
: AttrOrTypeParameter<
92+
"const ::llvm::DenseMap<" # keyOf # ", " # valueOf # "> *", desc> {
93+
let allocator =
94+
"$_dst = new ($_allocator.allocate<DenseMap<" # keyOf # ", " # valueOf #
95+
">>()) DenseMap<" # keyOf # ", " # valueOf #
96+
">($_self->begin(), $_self->end());";
97+
let cppStorageType = "::llvm::DenseMap<" # keyOf # ", " # valueOf # ">";
98+
}
99+
100+
def MapAttr : RTGAttrDef<"Map", [
101+
DeclareAttrInterfaceMethods<TypedAttrInterface>,
102+
]> {
103+
let summary = "a map from keys to values";
104+
105+
let parameters = (ins
106+
AttributeSelfTypeParameter<"", "rtg::MapType">:$type,
107+
DenseMapParameter<"TypedAttr", "TypedAttr", "map entries">:$entries);
108+
109+
let builders = [
110+
AttrBuilderWithInferredContext<
111+
(ins "rtg::MapType":$type,
112+
"const ::llvm::DenseMap<::mlir::TypedAttr, " #
113+
"::mlir::TypedAttr> *":$entries), [{
114+
return $_get(type.getContext(), type, entries);
115+
}]>,
116+
AttrBuilderWithInferredContext<
117+
(ins "rtg::MapType":$type), [{
118+
::llvm::DenseMap<::mlir::TypedAttr, ::mlir::TypedAttr> entries;
119+
return $_get(type.getContext(), type, &entries);
120+
}]>
121+
];
122+
123+
let mnemonic = "map";
124+
let hasCustomAssemblyFormat = true;
125+
126+
let genVerifyDecl = true;
127+
}
128+
88129
def TupleAttr : RTGAttrDef<"Tuple", [
89130
DeclareAttrInterfaceMethods<TypedAttrInterface>,
90131
]> {

include/circt/Dialect/RTG/IR/RTGTypes.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,38 @@ def ArrayType : RTGTypeDef<"Array"> {
148148
];
149149
}
150150

151+
def MapType : RTGTypeDef<"Map"> {
152+
let summary = "a map from keys to values";
153+
let description = [{
154+
This type represents a standard map/dictionary datastructure. It does not
155+
make any assumptions about the underlying implementation. Thus a hash map,
156+
tree map, etc. can be used in a backend. It does not guarentee deterministic
157+
iteration.
158+
}];
159+
160+
let parameters = (ins "::mlir::Type":$keyType, "::mlir::Type":$valueType);
161+
162+
let mnemonic = "map";
163+
let assemblyFormat = "`<` $keyType `->` $valueType `>`";
164+
165+
let builders = [
166+
TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
167+
"::mlir::Type":$valueType),
168+
"return $_get(keyType.getContext(), keyType, valueType);">,
169+
];
170+
}
171+
172+
class MapTypeOf<Type keyType, Type valueType> : Type<
173+
And<[
174+
MapType.predicate,
175+
SubstLeaves<"$_self", "llvm::cast<rtg::MapType>($_self).getKeyType()",
176+
keyType.predicate>,
177+
SubstLeaves<"$_self", "llvm::cast<rtg::MapType>($_self).getValueType()",
178+
valueType.predicate>
179+
]>,
180+
"map of " # keyType.summary # " keys to " # valueType.summary # " values",
181+
"::circt::rtg::MapType">;
182+
151183
def TupleType : RTGTypeDef<"Tuple"> {
152184
let summary = "a tuple of zero or more fields";
153185
let description = [{

include/circt/Dialect/RTG/IR/RTGVisitors.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ class RTGTypeVisitor {
169169
auto *thisCast = static_cast<ConcreteType *>(this);
170170
return TypeSwitch<Type, ResultType>(type)
171171
.template Case<ImmediateType, SequenceType, SetType, BagType, DictType,
172-
LabelType, IndexType, IntegerType>(
172+
MapType, LabelType, IndexType, IntegerType>(
173173
[&](auto expr) -> ResultType {
174174
return thisCast->visitType(expr, args...);
175175
})

integration_test/Bindings/Python/dialects/rtg.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import circt
55

66
from circt.dialects import rtg, rtgtest
7-
from circt.ir import Context, Location, Module, InsertionPoint, Block, StringAttr, TypeAttr, IndexType
7+
from circt.ir import Context, Location, Module, InsertionPoint, Block, StringAttr, TypeAttr, IndexType, IntegerType, IntegerAttr
88
from circt.passmanager import PassManager
99

1010
with Context() as ctx, Location.unknown():
@@ -269,3 +269,44 @@
269269
stringTy = rtg.StringType.get()
270270
# CHECK: !rtg.string
271271
print(stringTy)
272+
273+
with Context() as ctx, Location.unknown():
274+
circt.register_dialects(ctx)
275+
276+
# Test MapType
277+
keyTy = IntegerType.get_signless(32)
278+
valueTy = IntegerType.get_signless(64)
279+
mapTy = rtg.MapType.get(keyTy, valueTy)
280+
281+
# CHECK: key_type=i32
282+
print(f"key_type={mapTy.key_type}")
283+
# CHECK: value_type=i64
284+
print(f"value_type={mapTy.value_type}")
285+
# CHECK: !rtg.map<i32 -> i64>
286+
print(mapTy)
287+
288+
# Test MapAttr
289+
key0 = IntegerAttr.get(keyTy, 10)
290+
key1 = IntegerAttr.get(keyTy, 20)
291+
value0 = IntegerAttr.get(valueTy, 100)
292+
value1 = IntegerAttr.get(valueTy, 200)
293+
294+
mapAttr = rtg.MapAttr.get(mapTy, [(key0, value0), (key1, value1)])
295+
# CHECK: #rtg.map<10 : i32 -> 100 : i64, 20 : i32 -> 200 : i64>
296+
print(mapAttr)
297+
298+
# Test key-based lookup
299+
lookedUpValue = mapAttr.lookup(key0)
300+
# CHECK: 100 : i64
301+
print(lookedUpValue)
302+
303+
# Test lookup with non-existent key
304+
key2 = IntegerAttr.get(keyTy, 30)
305+
notFound = mapAttr.lookup(key2)
306+
# CHECK: key_not_found=True
307+
print(f"key_not_found={notFound is None}")
308+
309+
# Test empty map
310+
emptyMapAttr = rtg.MapAttr.get(mapTy)
311+
# CHECK: #rtg.map<>
312+
print(emptyMapAttr)

lib/Bindings/Python/RTGModule.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,19 @@ void circt::python::populateDialectRTGSubmodule(nb::module_ &m) {
111111
return rtgArrayTypeGetElementType(self);
112112
});
113113

114+
mlir_type_subclass(m, "MapType", rtgTypeIsAMap)
115+
.def_classmethod(
116+
"get",
117+
[](nb::object cls, MlirType keyType, MlirType valueType) {
118+
return cls(rtgMapTypeGet(keyType, valueType));
119+
},
120+
nb::arg("self"), nb::arg("key_type"), nb::arg("value_type"))
121+
.def_property_readonly(
122+
"key_type", [](MlirType self) { return rtgMapTypeGetKeyType(self); })
123+
.def_property_readonly("value_type", [](MlirType self) {
124+
return rtgMapTypeGetValueType(self);
125+
});
126+
114127
mlir_type_subclass(m, "TupleType", rtgTypeIsATuple)
115128
.def_classmethod(
116129
"get",
@@ -200,6 +213,38 @@ void circt::python::populateDialectRTGSubmodule(nb::module_ &m) {
200213
},
201214
nb::arg("self"), nb::arg("type"), nb::arg("ctxt") = nullptr);
202215

216+
mlir_attribute_subclass(m, "MapAttr", rtgAttrIsAMap)
217+
.def_classmethod(
218+
"get",
219+
[](nb::object cls, MlirType mapType,
220+
const std::vector<std::pair<MlirAttribute, MlirAttribute>>
221+
&entries,
222+
MlirContext ctxt) {
223+
std::vector<MlirAttribute> keys;
224+
std::vector<MlirAttribute> values;
225+
for (auto entry : entries) {
226+
keys.push_back(entry.first);
227+
values.push_back(entry.second);
228+
}
229+
return cls(rtgMapAttrGet(ctxt, mapType, keys.size(), keys.data(),
230+
values.data()));
231+
},
232+
nb::arg("self"), nb::arg("map_type"),
233+
nb::arg("entries") =
234+
std::vector<std::pair<MlirAttribute, MlirAttribute>>(),
235+
nb::arg("ctxt") = nullptr)
236+
.def(
237+
"lookup",
238+
[](MlirAttribute self, MlirAttribute key) {
239+
auto val = rtgMapAttrLookup(self, key);
240+
if (mlirAttributeIsNull(val))
241+
return nb::none();
242+
return nb::cast(val);
243+
},
244+
nb::arg("key"),
245+
"Look up the value associated with the given key. Returns None if "
246+
"the key is not found.");
247+
203248
// Attributes for ISA targets
204249
//===--------------------------------------------------------------------===//
205250

lib/CAPI/Dialect/RTG.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,23 @@ MlirType rtgArrayTypeGetElementType(MlirType type) {
134134
return wrap(cast<ArrayType>(unwrap(type)).getElementType());
135135
}
136136

137+
// MapType
138+
//===----------------------------------------------------------------------===//
139+
140+
bool rtgTypeIsAMap(MlirType type) { return isa<MapType>(unwrap(type)); }
141+
142+
MlirType rtgMapTypeGet(MlirType keyType, MlirType valueType) {
143+
return wrap(MapType::get(unwrap(keyType), unwrap(valueType)));
144+
}
145+
146+
MlirType rtgMapTypeGetKeyType(MlirType type) {
147+
return wrap(cast<MapType>(unwrap(type)).getKeyType());
148+
}
149+
150+
MlirType rtgMapTypeGetValueType(MlirType type) {
151+
return wrap(cast<MapType>(unwrap(type)).getValueType());
152+
}
153+
137154
// TupleType
138155
//===----------------------------------------------------------------------===//
139156

@@ -333,6 +350,35 @@ MlirStringRef rtgLabelAttrGetName(MlirAttribute attr) {
333350
return wrap(cast<LabelAttr>(unwrap(attr)).getName());
334351
}
335352

353+
// MapAttr
354+
//===----------------------------------------------------------------------===//
355+
356+
bool rtgAttrIsAMap(MlirAttribute attr) {
357+
return isa<rtg::MapAttr>(unwrap(attr));
358+
}
359+
360+
MlirAttribute rtgMapAttrGet(MlirContext ctx, MlirType mapType,
361+
intptr_t numEntries, MlirAttribute const *keys,
362+
MlirAttribute const *values) {
363+
DenseMap<TypedAttr, TypedAttr> entries;
364+
for (unsigned i = 0; i < numEntries; ++i) {
365+
entries.insert(
366+
{cast<TypedAttr>(unwrap(keys[i])), cast<TypedAttr>(unwrap(values[i]))});
367+
}
368+
return wrap(rtg::MapAttr::get(cast<rtg::MapType>(unwrap(mapType)), &entries));
369+
}
370+
371+
MlirAttribute rtgMapAttrLookup(MlirAttribute attr, MlirAttribute key) {
372+
auto mapAttr = cast<rtg::MapAttr>(unwrap(attr));
373+
auto keyAttr = cast<TypedAttr>(unwrap(key));
374+
375+
auto it = mapAttr.getEntries()->find(keyAttr);
376+
if (it == mapAttr.getEntries()->end())
377+
return wrap(Attribute()); // Return null attribute if key not found
378+
379+
return wrap(it->second);
380+
}
381+
336382
//===----------------------------------------------------------------------===//
337383
// Passes
338384
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)