@@ -2,9 +2,9 @@ import { CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE, CodamaError } from '@codama/
22import {
33 arrayTypeNode ,
44 CountNode ,
5+ definedTypeNode ,
56 fixedCountNode ,
67 isNode ,
7- isScalarEnum ,
88 NumberTypeNode ,
99 numberTypeNode ,
1010 pascalCase ,
@@ -13,6 +13,7 @@ import {
1313 remainderCountNode ,
1414 resolveNestedTypeNode ,
1515 snakeCase ,
16+ structTypeNode ,
1617} from '@codama/nodes' ;
1718import { extendVisitor , mergeVisitor , pipe , visit } from '@codama/visitors-core' ;
1819
@@ -140,41 +141,25 @@ export function getTypeManifestVisitor(options: {
140141 visitDefinedType ( definedType , { self } ) {
141142 parentName = pascalCase ( definedType . name ) ;
142143 const manifest = visit ( definedType . type , self ) ;
144+ const traits = getTraitsFromNode ( definedType ) ;
145+ manifest . imports . mergeWith ( traits . imports ) ;
143146 parentName = null ;
144- const traits = [ 'BorshSerialize' , 'BorshDeserialize' , 'Clone' , 'Debug' , 'Eq' , 'PartialEq' ] ;
145-
146- if ( isNode ( definedType . type , 'enumTypeNode' ) && isScalarEnum ( definedType . type ) ) {
147- traits . push ( 'Copy' , 'PartialOrd' , 'Hash' , 'FromPrimitive' ) ;
148- manifest . imports . add ( [ 'num_derive::FromPrimitive' ] ) ;
149- }
150-
151- const nestedStructs = manifest . nestedStructs . map (
152- struct =>
153- `#[derive(${ traits . join ( ', ' ) } )]\n` +
154- '#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]\n' +
155- `${ struct } ` ,
156- ) ;
157-
158- if ( ! isNode ( definedType . type , [ 'enumTypeNode' , 'structTypeNode' ] ) ) {
159- if ( nestedStructs . length > 0 ) {
160- manifest . imports . add ( [ 'borsh::BorshSerialize' , 'borsh::BorshDeserialize' ] ) ;
161- }
162- return {
163- ...manifest ,
164- nestedStructs,
165- type : `pub type ${ pascalCase ( definedType . name ) } = ${ manifest . type } ;` ,
166- } ;
167- }
168147
169- manifest . imports . add ( [ 'borsh::BorshSerialize' , 'borsh::BorshDeserialize' ] ) ;
170- return {
171- ...manifest ,
172- nestedStructs,
173- type :
174- `#[derive(${ traits . join ( ', ' ) } )]\n` +
175- '#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]\n' +
176- `${ manifest . type } ` ,
177- } ;
148+ const nestedStructs = manifest . nestedStructs . map ( struct => {
149+ const nestedTraits = getTraitsFromNode (
150+ definedTypeNode ( {
151+ name : struct . match ( / ^ p u b s t r u c t ( \w + ) / ) ?. [ 1 ] ?? '' ,
152+ type : structTypeNode ( [ ] ) ,
153+ } ) ,
154+ ) ;
155+ manifest . imports . mergeWith ( nestedTraits . imports ) ;
156+ return `${ nestedTraits . render } ${ struct } ` ;
157+ } ) ;
158+ const renderedType = isNode ( definedType . type , [ 'enumTypeNode' , 'structTypeNode' ] )
159+ ? manifest . type
160+ : `pub type ${ pascalCase ( definedType . name ) } = ${ manifest . type } ;` ;
161+
162+ return { ...manifest , nestedStructs, type : `${ traits . render } ${ renderedType } ` } ;
178163 } ,
179164
180165 visitDefinedTypeLink ( node ) {
0 commit comments