Skip to content

Commit 33a297d

Browse files
authored
fix: make serde field attrs dynamic and optional (#13)
1 parent ac05970 commit 33a297d

File tree

5 files changed

+341
-28
lines changed

5 files changed

+341
-28
lines changed

.changeset/shaky-dryers-refuse.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@codama/renderers-rust': patch
3+
---
4+
5+
fix: make serde field attributes dynamic and optional

src/getRenderMapVisitor.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) {
5656
const dependencyMap = options.dependencyMap ?? {};
5757
const getImportFrom = getImportFromFactory(options.linkOverrides ?? {});
5858
const getTraitsFromNode = getTraitsFromNodeFactory(options.traitOptions);
59-
const typeManifestVisitor = getTypeManifestVisitor({ getImportFrom, getTraitsFromNode });
59+
const typeManifestVisitor = getTypeManifestVisitor({
60+
getImportFrom,
61+
getTraitsFromNode,
62+
traitOptions: options.traitOptions,
63+
});
6064
const anchorTraits = options.anchorTraits ?? true;
6165

6266
return pipe(

src/getTypeManifestVisitor.ts

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import { CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE, CodamaError } from '@codama/errors';
22
import {
3+
AccountNode,
34
arrayTypeNode,
45
CountNode,
6+
DefinedTypeNode,
57
definedTypeNode,
68
fixedCountNode,
9+
InstructionNode,
710
isNode,
811
NumberTypeNode,
912
numberTypeNode,
@@ -18,7 +21,13 @@ import {
1821
import { extendVisitor, mergeVisitor, pipe, visit } from '@codama/visitors-core';
1922

2023
import { ImportMap } from './ImportMap';
21-
import { GetImportFromFunction, GetTraitsFromNodeFunction, rustDocblock } from './utils';
24+
import {
25+
GetImportFromFunction,
26+
getSerdeFieldAttribute,
27+
GetTraitsFromNodeFunction,
28+
rustDocblock,
29+
TraitOptions,
30+
} from './utils';
2231

2332
export type TypeManifest = {
2433
imports: ImportMap;
@@ -31,12 +40,14 @@ export function getTypeManifestVisitor(options: {
3140
getTraitsFromNode: GetTraitsFromNodeFunction;
3241
nestedStruct?: boolean;
3342
parentName?: string | null;
43+
traitOptions?: TraitOptions;
3444
}) {
35-
const { getImportFrom, getTraitsFromNode } = options;
45+
const { getImportFrom, getTraitsFromNode, traitOptions } = options;
3646
let parentName: string | null = options.parentName ?? null;
3747
let nestedStruct: boolean = options.nestedStruct ?? false;
3848
let inlineStruct: boolean = false;
3949
let parentSize: NumberTypeNode | number | null = null;
50+
let parentNode: AccountNode | DefinedTypeNode | InstructionNode | null = null;
4051

4152
return pipe(
4253
mergeVisitor(
@@ -51,10 +62,12 @@ export function getTypeManifestVisitor(options: {
5162
extendVisitor(v, {
5263
visitAccount(account, { self }) {
5364
parentName = pascalCase(account.name);
65+
parentNode = account;
5466
const manifest = visit(account.data, self);
5567
const traits = getTraitsFromNode(account);
5668
manifest.imports.mergeWith(traits.imports);
5769
parentName = null;
70+
parentNode = null;
5871
return {
5972
...manifest,
6073
type: traits.render + manifest.type,
@@ -140,10 +153,12 @@ export function getTypeManifestVisitor(options: {
140153

141154
visitDefinedType(definedType, { self }) {
142155
parentName = pascalCase(definedType.name);
156+
parentNode = definedType;
143157
const manifest = visit(definedType.type, self);
144158
const traits = getTraitsFromNode(definedType);
145159
manifest.imports.mergeWith(traits.imports);
146160
parentName = null;
161+
parentNode = null;
147162

148163
const renderedType = isNode(definedType.type, ['enumTypeNode', 'structTypeNode'])
149164
? manifest.type
@@ -204,12 +219,18 @@ export function getTypeManifestVisitor(options: {
204219
parentName = originalParentName;
205220

206221
let derive = '';
207-
if (childManifest.type === '(Pubkey)') {
208-
derive =
209-
'#[cfg_attr(feature = "serde", serde(with = "serde_with::As::<serde_with::DisplayFromStr>"))]\n';
210-
} else if (childManifest.type === '(Vec<Pubkey>)') {
211-
derive =
212-
'#[cfg_attr(feature = "serde", serde(with = "serde_with::As::<Vec<serde_with::DisplayFromStr>>"))]\n';
222+
if (parentNode && childManifest.type === '(Pubkey)') {
223+
derive = getSerdeFieldAttribute(
224+
'serde_with::As::<serde_with::DisplayFromStr>',
225+
parentNode,
226+
traitOptions,
227+
);
228+
} else if (parentNode && childManifest.type === '(Vec<Pubkey>)') {
229+
derive = getSerdeFieldAttribute(
230+
'serde_with::As::<Vec<serde_with::DisplayFromStr>>',
231+
parentNode,
232+
traitOptions,
233+
);
213234
}
214235

215236
return {
@@ -385,25 +406,36 @@ export function getTypeManifestVisitor(options: {
385406
const resolvedNestedType = resolveNestedTypeNode(structFieldType.type);
386407

387408
let derive = '';
388-
if (fieldManifest.type === 'Pubkey') {
389-
derive =
390-
'#[cfg_attr(feature = "serde", serde(with = "serde_with::As::<serde_with::DisplayFromStr>"))]\n';
391-
} else if (fieldManifest.type === 'Vec<Pubkey>') {
392-
derive =
393-
'#[cfg_attr(feature = "serde", serde(with = "serde_with::As::<Vec<serde_with::DisplayFromStr>>"))]\n';
394-
} else if (
395-
isNode(resolvedNestedType, 'arrayTypeNode') &&
396-
isNode(resolvedNestedType.count, 'fixedCountNode') &&
397-
resolvedNestedType.count.value > 32
398-
) {
399-
derive = '#[cfg_attr(feature = "serde", serde(with = "serde_big_array::BigArray"))]\n';
400-
} else if (
401-
isNode(resolvedNestedType, ['bytesTypeNode', 'stringTypeNode']) &&
402-
isNode(structFieldType.type, 'fixedSizeTypeNode') &&
403-
structFieldType.type.size > 32
404-
) {
405-
derive =
406-
'#[cfg_attr(feature = "serde", serde(with = "serde_with::As::<serde_with::Bytes>"))]\n';
409+
if (parentNode) {
410+
if (fieldManifest.type === 'Pubkey') {
411+
derive = getSerdeFieldAttribute(
412+
'serde_with::As::<serde_with::DisplayFromStr>',
413+
parentNode,
414+
traitOptions,
415+
);
416+
} else if (fieldManifest.type === 'Vec<Pubkey>') {
417+
derive = getSerdeFieldAttribute(
418+
'serde_with::As::<Vec<serde_with::DisplayFromStr>>',
419+
parentNode,
420+
traitOptions,
421+
);
422+
} else if (
423+
isNode(resolvedNestedType, 'arrayTypeNode') &&
424+
isNode(resolvedNestedType.count, 'fixedCountNode') &&
425+
resolvedNestedType.count.value > 32
426+
) {
427+
derive = getSerdeFieldAttribute('serde_big_array::BigArray', parentNode, traitOptions);
428+
} else if (
429+
isNode(resolvedNestedType, ['bytesTypeNode', 'stringTypeNode']) &&
430+
isNode(structFieldType.type, 'fixedSizeTypeNode') &&
431+
structFieldType.type.size > 32
432+
) {
433+
derive = getSerdeFieldAttribute(
434+
'serde_with::As::<serde_with::Bytes>',
435+
parentNode,
436+
traitOptions,
437+
);
438+
}
407439
}
408440

409441
return {

src/utils/traitOptions.ts

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,60 @@ function extractFullyQualifiedNames(traits: string[], imports: ImportMap): strin
174174
return trait.slice(index + 2);
175175
});
176176
}
177+
178+
/**
179+
* Helper function to get the serde field attribute format based on trait configuration.
180+
* Returns the appropriate attribute string for serde field customization, or empty string if no serde traits.
181+
*/
182+
export function getSerdeFieldAttribute(
183+
serdeWith: string,
184+
node: AccountNode | DefinedTypeNode | InstructionNode,
185+
userOptions: TraitOptions = {},
186+
): string {
187+
assertIsNode(node, ['accountNode', 'definedTypeNode', 'instructionNode']);
188+
const options: Required<TraitOptions> = { ...DEFAULT_TRAIT_OPTIONS, ...userOptions };
189+
190+
// Get the node type and return early if it's a type alias.
191+
const nodeType = getNodeType(node);
192+
if (nodeType === 'alias') {
193+
return '';
194+
}
195+
196+
// Find all the traits for the node.
197+
const sanitizedOverrides = Object.fromEntries(
198+
Object.entries(options.overrides).map(([key, value]) => [camelCase(key), value]),
199+
);
200+
const nodeOverrides: string[] | undefined = sanitizedOverrides[node.name];
201+
const allTraits = nodeOverrides === undefined ? getDefaultTraits(nodeType, options) : nodeOverrides;
202+
203+
// Check if serde traits are present.
204+
const hasSerdeSerialize = allTraits.some(t => t === 'serde::Serialize' || t === 'Serialize');
205+
const hasSerdeDeserialize = allTraits.some(t => t === 'serde::Deserialize' || t === 'Deserialize');
206+
207+
if (!hasSerdeSerialize && !hasSerdeDeserialize) {
208+
return '';
209+
}
210+
211+
// Check if serde is feature-flagged.
212+
const partitionedTraits = partitionTraitsInFeatures(allTraits, options.featureFlags);
213+
const featuredTraits = partitionedTraits[1];
214+
215+
// Find which feature flag contains serde traits.
216+
let serdeFeatureName: string | undefined;
217+
for (const [feature, traits] of Object.entries(featuredTraits)) {
218+
if (
219+
traits.some(
220+
t => t === 'serde::Serialize' || t === 'serde::Deserialize' || t === 'Serialize' || t === 'Deserialize',
221+
)
222+
) {
223+
serdeFeatureName = feature;
224+
break;
225+
}
226+
}
227+
228+
if (serdeFeatureName) {
229+
return `#[cfg_attr(feature = "${serdeFeatureName}", serde(with = "${serdeWith}"))]\n`;
230+
} else {
231+
return `#[serde(with = "${serdeWith}")]\n`;
232+
}
233+
}

0 commit comments

Comments
 (0)