@@ -42,13 +42,14 @@ class Typename:
4242 namespaces_name_rule = delimitedList (IDENT , "::" )
4343 rule = (
4444 namespaces_name_rule ("namespaces_and_name" ) #
45- ).setParseAction (lambda t : Typename (t ))
45+ ).setParseAction (lambda t : Typename . from_parse_result (t ))
4646
4747 def __init__ (self ,
48- t : ParseResults ,
48+ name : str ,
49+ namespaces : list [str ],
4950 instantiations : Sequence [ParseResults ] = ()):
50- self .name = t [ - 1 ] # the name is the last element in this list
51- self .namespaces = t [: - 1 ]
51+ self .name = name
52+ self .namespaces = namespaces
5253
5354 # If the first namespace is empty string, just get rid of it.
5455 if self .namespaces and self .namespaces [0 ] == '' :
@@ -63,12 +64,38 @@ def __init__(self,
6364 self .instantiations = []
6465
6566 @staticmethod
66- def from_parse_result (parse_result : Union [ str , list ] ):
67+ def from_parse_result (parse_result : list ):
6768 """Unpack the parsed result to get the Typename instance."""
68- return parse_result [0 ]
69+ name = parse_result [- 1 ] # the name is the last element in this list
70+ namespaces = parse_result [:- 1 ]
71+ return Typename (name , namespaces )
6972
7073 def __repr__ (self ) -> str :
71- return self .to_cpp ()
74+ if self .get_template_args ():
75+ templates = f"<{ self .get_template_args ()} >"
76+ else :
77+ templates = ""
78+
79+ if len (self .namespaces ) > 0 :
80+ namespaces = "::" .join (self .namespaces ) + "::"
81+ else :
82+ namespaces = ""
83+
84+ return f"{ namespaces } { self .name } { templates } "
85+
86+ def get_template_args (self ) -> str :
87+ """Return the template args as a string, e.g. <double, gtsam::Pose3>."""
88+ return ", " .join ([inst .to_cpp () for inst in self .instantiations ])
89+
90+ def templated_name (self ) -> str :
91+ """Return the name without namespace and with the template instantiations if any."""
92+ if self .instantiations :
93+ templates = self .get_template_args ()
94+ name = f"{ self .name } <{ templates } >"
95+ else :
96+ name = self .name
97+
98+ return name
7299
73100 def instantiated_name (self ) -> str :
74101 """Get the instantiated name of the type."""
@@ -84,8 +111,7 @@ def qualified_name(self):
84111 def to_cpp (self ) -> str :
85112 """Generate the C++ code for wrapping."""
86113 if self .instantiations :
87- cpp_name = self .name + "<{}>" .format (", " .join (
88- [inst .to_cpp () for inst in self .instantiations ]))
114+ cpp_name = self .name + f"<{ self .get_template_args ()} >"
89115 else :
90116 cpp_name = self .name
91117 return '{}{}{}' .format (
@@ -129,7 +155,7 @@ class BasicType:
129155 rule = (Or (BASIC_TYPES )("typename" )).setParseAction (lambda t : BasicType (t ))
130156
131157 def __init__ (self , t : ParseResults ):
132- self .typename = Typename (t )
158+ self .typename = Typename . from_parse_result (t )
133159
134160
135161class CustomType :
@@ -148,7 +174,7 @@ class CustomType:
148174 rule = (Typename .rule ("typename" )).setParseAction (lambda t : CustomType (t ))
149175
150176 def __init__ (self , t : ParseResults ):
151- self .typename = Typename (t )
177+ self .typename = Typename . from_parse_result (t )
152178
153179
154180class Type :
@@ -226,18 +252,16 @@ def to_cpp(self) -> str:
226252 """
227253
228254 if self .is_shared_ptr :
229- typename = "std::shared_ptr<{typename}>" .format (
230- typename = self .get_typename ())
255+ typename = f"std::shared_ptr<{ self .get_typename ()} >"
231256 elif self .is_ptr :
232- typename = "{typename}*" . format ( typename = self .typename .to_cpp ())
257+ typename = f" { self .typename .to_cpp ()} *"
233258 elif self .is_ref :
234- typename = typename = "{typename}&" .format (
235- typename = self .get_typename ())
259+ typename = f"{ self .get_typename ()} &"
236260 else :
237261 typename = self .get_typename ()
238262
239- return ( "{ const}{typename}" . format (
240- const = " const " if self . is_const else "" , typename = typename ))
263+ const = " const " if self . is_const else ""
264+ return f" { const } { typename } "
241265
242266
243267class TemplatedType :
@@ -265,7 +289,7 @@ def __init__(self, typename: Typename, template_params: List[Type],
265289 is_const : str , is_shared_ptr : str , is_ptr : str , is_ref : str ):
266290 instantiations = [param .typename for param in template_params ]
267291 # Recreate the typename but with the template params as instantiations.
268- self .typename = Typename (typename .namespaces + [ typename .name ] ,
292+ self .typename = Typename (typename .name , typename .namespaces ,
269293 instantiations )
270294
271295 self .template_params = template_params
@@ -278,22 +302,33 @@ def __init__(self, typename: Typename, template_params: List[Type],
278302 @staticmethod
279303 def from_parse_result (t : ParseResults ):
280304 """Get the TemplatedType from the parser results."""
281- return TemplatedType (t .typename , t .template_params , t . is_const ,
282- t .is_shared_ptr , t .is_ptr , t .is_ref )
305+ return TemplatedType (t .typename , t .template_params . as_list () ,
306+ t .is_const , t . is_shared_ptr , t .is_ptr , t .is_ref )
283307
284308 def __repr__ (self ):
285- return "TemplatedType({typename.namespaces}::{typename.name})" .format (
286- typename = self .typename )
309+ return "TemplatedType({typename.namespaces}::{typename.name}<{template_params}>)" .format (
310+ typename = self .typename , template_params = self .template_params )
311+
312+ def get_template_params (self ):
313+ """
314+ Get the template args for the type as a string.
315+ E.g. for
316+ ```
317+ template <T = {double}, U = {string}>
318+ class Random(){};
319+ ```
320+ it returns `<double, string>`.
321+
322+ """
323+ # Use Type.to_cpp to do the heavy lifting for the template parameters.
324+ return ", " .join ([t .to_cpp () for t in self .template_params ])
287325
288326 def get_typename (self ):
289327 """
290328 Get the typename of this type without any qualifiers.
291329 E.g. for `const std::vector<double>& indices` this will return `std::vector<double>`.
292330 """
293- # Use Type.to_cpp to do the heavy lifting for the template parameters.
294- template_args = ", " .join ([t .to_cpp () for t in self .template_params ])
295-
296- return f"{ self .typename .qualified_name ()} <{ template_args } >"
331+ return f"{ self .typename .qualified_name ()} <{ self .get_template_params ()} >"
297332
298333 def to_cpp (self ):
299334 """
0 commit comments