Skip to content

Commit ed43f34

Browse files
committed
Use traitOptions when visiting account nodes
1 parent 5ebb815 commit ed43f34

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

packages/renderers-rust/src/getRenderMapVisitor.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@ import {
2727
import { getTypeManifestVisitor } from './getTypeManifestVisitor';
2828
import { ImportMap } from './ImportMap';
2929
import { renderValueNode } from './renderValueNodeVisitor';
30-
import { getImportFromFactory, LinkOverrides, render, TraitOptions } from './utils';
30+
import { getImportFromFactory, getTraitsFromNodeFactory, LinkOverrides, render, TraitOptions } from './utils';
3131

3232
export type GetRenderMapOptions = {
3333
defaultTraitOverrides?: string[];
3434
dependencyMap?: Record<string, string>;
3535
linkOverrides?: LinkOverrides;
3636
renderParentInstructions?: boolean;
37-
traitOverrides?: TraitOptions;
37+
traitOptions?: TraitOptions;
3838
};
3939

4040
export function getRenderMapVisitor(options: GetRenderMapOptions = {}) {
@@ -44,7 +44,8 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) {
4444
const renderParentInstructions = options.renderParentInstructions ?? false;
4545
const dependencyMap = options.dependencyMap ?? {};
4646
const getImportFrom = getImportFromFactory(options.linkOverrides ?? {});
47-
const typeManifestVisitor = getTypeManifestVisitor({ getImportFrom });
47+
const getTraitsFromNode = getTraitsFromNodeFactory(options.traitOptions);
48+
const typeManifestVisitor = getTypeManifestVisitor({ getImportFrom, getTraitsFromNode });
4849

4950
return pipe(
5051
staticVisitor(
@@ -149,6 +150,7 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) {
149150
node.arguments.forEach(argument => {
150151
const argumentVisitor = getTypeManifestVisitor({
151152
getImportFrom,
153+
getTraitsFromNode,
152154
nestedStruct: true,
153155
parentName: `${pascalCase(node.name)}InstructionData`,
154156
});
@@ -189,6 +191,7 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) {
189191
const struct = structTypeNodeFromInstructionArgumentNodes(node.arguments);
190192
const structVisitor = getTypeManifestVisitor({
191193
getImportFrom,
194+
getTraitsFromNode,
192195
parentName: `${pascalCase(node.name)}InstructionData`,
193196
});
194197
const typeManifest = visit(struct, structVisitor);

packages/renderers-rust/src/getTypeManifestVisitor.ts

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import {
1717
import { extendVisitor, mergeVisitor, pipe, visit } from '@codama/visitors-core';
1818

1919
import { ImportMap } from './ImportMap';
20-
import { GetImportFromFunction, rustDocblock } from './utils';
20+
import { GetImportFromFunction, GetTraitsFromNodeFunction, rustDocblock } from './utils';
2121

2222
export type TypeManifest = {
2323
imports: ImportMap;
@@ -27,10 +27,11 @@ export type TypeManifest = {
2727

2828
export function getTypeManifestVisitor(options: {
2929
getImportFrom: GetImportFromFunction;
30+
getTraitsFromNode: GetTraitsFromNodeFunction;
3031
nestedStruct?: boolean;
3132
parentName?: string | null;
3233
}) {
33-
const { getImportFrom } = options;
34+
const { getImportFrom, getTraitsFromNode } = options;
3435
let parentName: string | null = options.parentName ?? null;
3536
let nestedStruct: boolean = options.nestedStruct ?? false;
3637
let inlineStruct: boolean = false;
@@ -50,14 +51,12 @@ export function getTypeManifestVisitor(options: {
5051
visitAccount(account, { self }) {
5152
parentName = pascalCase(account.name);
5253
const manifest = visit(account.data, self);
53-
manifest.imports.add(['borsh::BorshSerialize', 'borsh::BorshDeserialize']);
54+
const traits = getTraitsFromNode(account);
55+
manifest.imports.mergeWith(traits.imports);
5456
parentName = null;
5557
return {
5658
...manifest,
57-
type:
58-
'#[derive(BorshSerialize, BorshDeserialize, Clone, Debug, Eq, PartialEq)]\n' +
59-
'#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]\n' +
60-
`${manifest.type}`,
59+
type: traits.render + manifest.type,
6160
};
6261
},
6362

0 commit comments

Comments
 (0)