From 1e753a8263ad109cf34eea74f5c53613a3eeae4d Mon Sep 17 00:00:00 2001 From: Loris Leiva Date: Sun, 3 Nov 2024 15:55:41 +0000 Subject: [PATCH] Allow passing NodeStacks to nested visitors --- .changeset/spicy-camels-tease.md | 6 + packages/library/test/index.test.ts | 2 +- .../src/getRenderMapVisitor.ts | 2 +- .../src/getTypeManifestVisitor.ts | 18 +-- .../renderers-js/src/getRenderMapVisitor.ts | 7 +- .../src/getTypeManifestVisitor.ts | 18 +-- .../renderers-rust/src/getRenderMapVisitor.ts | 7 +- .../src/getTypeManifestVisitor.ts | 2 +- .../src/bottomUpTransformerVisitor.ts | 6 +- .../visitors-core/src/deleteNodesVisitor.ts | 4 +- .../visitors-core/src/getByteSizeVisitor.ts | 22 ++-- packages/visitors-core/src/identityVisitor.ts | 122 +++++++++--------- packages/visitors-core/src/mergeVisitor.ts | 122 +++++++++--------- .../src/nonNullableIdentityVisitor.ts | 6 +- .../visitors-core/src/removeDocsVisitor.ts | 4 +- packages/visitors-core/src/staticVisitor.ts | 5 +- .../src/topDownTransformerVisitor.ts | 6 +- packages/visitors-core/src/voidVisitor.ts | 6 +- .../test/bottomUpTransformerVisitor.test.ts | 56 +++++++- .../test/deleteNodesVisitor.test.ts | 7 +- .../visitors-core/test/extendVisitor.test.ts | 2 +- .../test/getByteSizeVisitor.test.ts | 6 +- .../test/identityVisitor.test.ts | 2 +- .../visitors-core/test/mapVisitor.test.ts | 2 +- .../visitors-core/test/mergeVisitor.test.ts | 2 +- .../test/removeDocsVisitor.test.ts | 2 +- .../visitors-core/test/staticVisitor.test.ts | 2 +- .../test/topDownTransformerVisitor.test.ts | 50 ++++++- .../src/fillDefaultPdaSeedValuesVisitor.ts | 2 +- .../src/setFixedAccountSizesVisitor.ts | 4 +- ...tInstructionAccountDefaultValuesVisitor.ts | 2 +- ...ransformDefinedTypesIntoAccountsVisitor.ts | 2 +- 32 files changed, 303 insertions(+), 203 deletions(-) create mode 100644 .changeset/spicy-camels-tease.md diff --git a/.changeset/spicy-camels-tease.md b/.changeset/spicy-camels-tease.md new file mode 100644 index 000000000..01484179f --- /dev/null +++ b/.changeset/spicy-camels-tease.md @@ -0,0 +1,6 @@ +--- +'@codama/visitors-core': minor +'@codama/visitors': minor +--- + +Allow passing `NodeStacks` to nested visitors diff --git a/packages/library/test/index.test.ts b/packages/library/test/index.test.ts index ded3def98..891cb03f3 100644 --- a/packages/library/test/index.test.ts +++ b/packages/library/test/index.test.ts @@ -12,7 +12,7 @@ test('it exports visitors', () => { test('it accepts visitors', () => { const codama = createFromRoot(rootNode(programNode({ name: 'myProgram', publicKey: '1111' }))); - const visitor = voidVisitor(['rootNode']); + const visitor = voidVisitor({ keys: ['rootNode'] }); const result = codama.accept(visitor) satisfies void; expect(typeof result).toBe('undefined'); }); diff --git a/packages/renderers-js-umi/src/getRenderMapVisitor.ts b/packages/renderers-js-umi/src/getRenderMapVisitor.ts index 1d4bb4e96..77b6d3f50 100644 --- a/packages/renderers-js-umi/src/getRenderMapVisitor.ts +++ b/packages/renderers-js-umi/src/getRenderMapVisitor.ts @@ -100,7 +100,7 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor< }); const typeManifestVisitor = getTypeManifestVisitor(); const resolvedInstructionInputVisitor = getResolvedInstructionInputsVisitor(); - const byteSizeVisitor = getByteSizeVisitor(linkables, stack); + const byteSizeVisitor = getByteSizeVisitor(linkables, { stack }); function getInstructionAccountType(account: ResolvedInstructionAccount): string { if (account.isPda && account.isSigner === false) return 'Pda'; diff --git a/packages/renderers-js-umi/src/getTypeManifestVisitor.ts b/packages/renderers-js-umi/src/getTypeManifestVisitor.ts index 8e392eaa5..299647d00 100644 --- a/packages/renderers-js-umi/src/getTypeManifestVisitor.ts +++ b/packages/renderers-js-umi/src/getTypeManifestVisitor.ts @@ -85,14 +85,16 @@ export function getTypeManifestVisitor(input: { value: '', valueImports: new ImportMap(), }) as TypeManifest, - [ - ...REGISTERED_TYPE_NODE_KINDS, - ...REGISTERED_VALUE_NODE_KINDS, - 'definedTypeLinkNode', - 'definedTypeNode', - 'accountNode', - 'instructionNode', - ], + { + keys: [ + ...REGISTERED_TYPE_NODE_KINDS, + ...REGISTERED_VALUE_NODE_KINDS, + 'definedTypeLinkNode', + 'definedTypeNode', + 'accountNode', + 'instructionNode', + ], + }, ), v => extendVisitor(v, { diff --git a/packages/renderers-js/src/getRenderMapVisitor.ts b/packages/renderers-js/src/getRenderMapVisitor.ts index 14d9ed3e5..fd087b21c 100644 --- a/packages/renderers-js/src/getRenderMapVisitor.ts +++ b/packages/renderers-js/src/getRenderMapVisitor.ts @@ -136,10 +136,9 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) { }; return pipe( - staticVisitor( - () => new RenderMap(), - ['rootNode', 'programNode', 'pdaNode', 'accountNode', 'definedTypeNode', 'instructionNode'], - ), + staticVisitor(() => new RenderMap(), { + keys: ['rootNode', 'programNode', 'pdaNode', 'accountNode', 'definedTypeNode', 'instructionNode'], + }), v => extendVisitor(v, { visitAccount(node) { diff --git a/packages/renderers-js/src/getTypeManifestVisitor.ts b/packages/renderers-js/src/getTypeManifestVisitor.ts index d6f7e72c4..e5ed56ab9 100644 --- a/packages/renderers-js/src/getTypeManifestVisitor.ts +++ b/packages/renderers-js/src/getTypeManifestVisitor.ts @@ -58,14 +58,16 @@ export function getTypeManifestVisitor(input: { strictType: fragment(''), value: fragment(''), }) as TypeManifest, - [ - ...REGISTERED_TYPE_NODE_KINDS, - ...REGISTERED_VALUE_NODE_KINDS, - 'definedTypeLinkNode', - 'definedTypeNode', - 'accountNode', - 'instructionNode', - ], + { + keys: [ + ...REGISTERED_TYPE_NODE_KINDS, + ...REGISTERED_VALUE_NODE_KINDS, + 'definedTypeLinkNode', + 'definedTypeNode', + 'accountNode', + 'instructionNode', + ], + }, ), visitor => extendVisitor(visitor, { diff --git a/packages/renderers-rust/src/getRenderMapVisitor.ts b/packages/renderers-rust/src/getRenderMapVisitor.ts index c1e34a0b9..63b0cd2b0 100644 --- a/packages/renderers-rust/src/getRenderMapVisitor.ts +++ b/packages/renderers-rust/src/getRenderMapVisitor.ts @@ -53,10 +53,9 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) { const anchorTraits = options.anchorTraits ?? true; return pipe( - staticVisitor( - () => new RenderMap(), - ['rootNode', 'programNode', 'instructionNode', 'accountNode', 'definedTypeNode'], - ), + staticVisitor(() => new RenderMap(), { + keys: ['rootNode', 'programNode', 'instructionNode', 'accountNode', 'definedTypeNode'], + }), v => extendVisitor(v, { visitAccount(node) { diff --git a/packages/renderers-rust/src/getTypeManifestVisitor.ts b/packages/renderers-rust/src/getTypeManifestVisitor.ts index e157471d6..938df56a5 100644 --- a/packages/renderers-rust/src/getTypeManifestVisitor.ts +++ b/packages/renderers-rust/src/getTypeManifestVisitor.ts @@ -44,7 +44,7 @@ export function getTypeManifestVisitor(options: { ...mergeManifests(values), type: values.map(v => v.type).join('\n'), }), - [...REGISTERED_TYPE_NODE_KINDS, 'definedTypeLinkNode', 'definedTypeNode', 'accountNode'], + { keys: [...REGISTERED_TYPE_NODE_KINDS, 'definedTypeLinkNode', 'definedTypeNode', 'accountNode'] }, ), v => extendVisitor(v, { diff --git a/packages/visitors-core/src/bottomUpTransformerVisitor.ts b/packages/visitors-core/src/bottomUpTransformerVisitor.ts index 74f81773c..9a3aaf3fd 100644 --- a/packages/visitors-core/src/bottomUpTransformerVisitor.ts +++ b/packages/visitors-core/src/bottomUpTransformerVisitor.ts @@ -17,7 +17,7 @@ export type BottomUpNodeTransformerWithSelector = { export function bottomUpTransformerVisitor( transformers: (BottomUpNodeTransformer | BottomUpNodeTransformerWithSelector)[], - nodeKeys?: TNodeKind[], + options: { keys?: TNodeKind[]; stack?: NodeStack } = {}, ): Visitor { const transformerFunctions = transformers.map((transformer): BottomUpNodeTransformer => { if (typeof transformer === 'function') return transformer; @@ -27,9 +27,9 @@ export function bottomUpTransformerVisitor interceptVisitor(v, (node, next) => { return transformerFunctions.reduce( diff --git a/packages/visitors-core/src/deleteNodesVisitor.ts b/packages/visitors-core/src/deleteNodesVisitor.ts index 636d96f47..4472ee79b 100644 --- a/packages/visitors-core/src/deleteNodesVisitor.ts +++ b/packages/visitors-core/src/deleteNodesVisitor.ts @@ -5,7 +5,7 @@ import { TopDownNodeTransformerWithSelector, topDownTransformerVisitor } from '. export function deleteNodesVisitor( selectors: NodeSelector[], - nodeKeys?: TNodeKind[], + options?: Parameters>[1], ) { return topDownTransformerVisitor( selectors.map( @@ -14,6 +14,6 @@ export function deleteNodesVisitor( transform: () => null, }), ), - nodeKeys, + options, ); } diff --git a/packages/visitors-core/src/getByteSizeVisitor.ts b/packages/visitors-core/src/getByteSizeVisitor.ts index 743da04ca..d95963824 100644 --- a/packages/visitors-core/src/getByteSizeVisitor.ts +++ b/packages/visitors-core/src/getByteSizeVisitor.ts @@ -19,8 +19,10 @@ export type ByteSizeVisitorKeys = export function getByteSizeVisitor( linkables: LinkableDictionary, - stack: NodeStack, + options: { stack?: NodeStack } = {}, ): Visitor { + const stack = options.stack ?? new NodeStack(); + const visitedDefinedTypes = new Map(); const definedTypeStack: string[] = []; @@ -30,14 +32,16 @@ export function getByteSizeVisitor( const baseVisitor = mergeVisitor( () => null as number | null, (_, values) => sumSizes(values), - [ - ...REGISTERED_TYPE_NODE_KINDS, - 'definedTypeLinkNode', - 'definedTypeNode', - 'accountNode', - 'instructionNode', - 'instructionArgumentNode', - ], + { + keys: [ + ...REGISTERED_TYPE_NODE_KINDS, + 'definedTypeLinkNode', + 'definedTypeNode', + 'accountNode', + 'instructionNode', + 'instructionArgumentNode', + ], + }, ); return pipe( diff --git a/packages/visitors-core/src/identityVisitor.ts b/packages/visitors-core/src/identityVisitor.ts index b236e9216..9bd26f763 100644 --- a/packages/visitors-core/src/identityVisitor.ts +++ b/packages/visitors-core/src/identityVisitor.ts @@ -76,16 +76,16 @@ import { staticVisitor } from './staticVisitor'; import { visit as baseVisit, Visitor } from './visitor'; export function identityVisitor( - nodeKeys: TNodeKind[] = REGISTERED_NODE_KINDS as TNodeKind[], + options: { keys?: TNodeKind[] } = {}, ): Visitor { - const castedNodeKeys: NodeKind[] = nodeKeys; - const visitor = staticVisitor(node => Object.freeze({ ...node }), castedNodeKeys) as Visitor; + const keys: NodeKind[] = options.keys ?? (REGISTERED_NODE_KINDS as TNodeKind[]); + const visitor = staticVisitor(node => Object.freeze({ ...node }), { keys }) as Visitor; const visit = (v: Visitor) => (node: Node): Node | null => - castedNodeKeys.includes(node.kind) ? baseVisit(node, v) : Object.freeze({ ...node }); + keys.includes(node.kind) ? baseVisit(node, v) : Object.freeze({ ...node }); - if (castedNodeKeys.includes('rootNode')) { + if (keys.includes('rootNode')) { visitor.visitRoot = function visitRoot(node) { const program = visit(this)(node.program); if (program === null) return null; @@ -97,7 +97,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('programNode')) { + if (keys.includes('programNode')) { visitor.visitProgram = function visitProgram(node) { return programNode({ ...node, @@ -114,7 +114,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('pdaNode')) { + if (keys.includes('pdaNode')) { visitor.visitPda = function visitPda(node) { return pdaNode({ ...node, @@ -123,7 +123,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('accountNode')) { + if (keys.includes('accountNode')) { visitor.visitAccount = function visitAccount(node) { const data = visit(this)(node.data); if (data === null) return null; @@ -134,7 +134,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('instructionNode')) { + if (keys.includes('instructionNode')) { visitor.visitInstruction = function visitInstruction(node) { return instructionNode({ ...node, @@ -169,7 +169,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('instructionAccountNode')) { + if (keys.includes('instructionAccountNode')) { visitor.visitInstructionAccount = function visitInstructionAccount(node) { const defaultValue = node.defaultValue ? (visit(this)(node.defaultValue) ?? undefined) : undefined; if (defaultValue) assertIsNode(defaultValue, INSTRUCTION_INPUT_VALUE_NODES); @@ -177,7 +177,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('instructionArgumentNode')) { + if (keys.includes('instructionArgumentNode')) { visitor.visitInstructionArgument = function visitInstructionArgument(node) { const type = visit(this)(node.type); if (type === null) return null; @@ -188,7 +188,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('instructionRemainingAccountsNode')) { + if (keys.includes('instructionRemainingAccountsNode')) { visitor.visitInstructionRemainingAccounts = function visitInstructionRemainingAccounts(node) { const value = visit(this)(node.value); if (value === null) return null; @@ -197,7 +197,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('instructionByteDeltaNode')) { + if (keys.includes('instructionByteDeltaNode')) { visitor.visitInstructionByteDelta = function visitInstructionByteDelta(node) { const value = visit(this)(node.value); if (value === null) return null; @@ -206,7 +206,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('definedTypeNode')) { + if (keys.includes('definedTypeNode')) { visitor.visitDefinedType = function visitDefinedType(node) { const type = visit(this)(node.type); if (type === null) return null; @@ -215,7 +215,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('arrayTypeNode')) { + if (keys.includes('arrayTypeNode')) { visitor.visitArrayType = function visitArrayType(node) { const size = visit(this)(node.count); if (size === null) return null; @@ -227,7 +227,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('enumTypeNode')) { + if (keys.includes('enumTypeNode')) { visitor.visitEnumType = function visitEnumType(node) { return enumTypeNode( node.variants.map(visit(this)).filter(removeNullAndAssertIsNodeFilter(ENUM_VARIANT_TYPE_NODES)), @@ -236,7 +236,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('enumStructVariantTypeNode')) { + if (keys.includes('enumStructVariantTypeNode')) { visitor.visitEnumStructVariantType = function visitEnumStructVariantType(node) { const newStruct = visit(this)(node.struct); if (!newStruct) { @@ -250,7 +250,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('enumTupleVariantTypeNode')) { + if (keys.includes('enumTupleVariantTypeNode')) { visitor.visitEnumTupleVariantType = function visitEnumTupleVariantType(node) { const newTuple = visit(this)(node.tuple); if (!newTuple) { @@ -264,7 +264,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('mapTypeNode')) { + if (keys.includes('mapTypeNode')) { visitor.visitMapType = function visitMapType(node) { const size = visit(this)(node.count); if (size === null) return null; @@ -279,7 +279,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('optionTypeNode')) { + if (keys.includes('optionTypeNode')) { visitor.visitOptionType = function visitOptionType(node) { const prefix = visit(this)(node.prefix); if (prefix === null) return null; @@ -291,7 +291,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('zeroableOptionTypeNode')) { + if (keys.includes('zeroableOptionTypeNode')) { visitor.visitZeroableOptionType = function visitZeroableOptionType(node) { const item = visit(this)(node.item); if (item === null) return null; @@ -302,7 +302,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('remainderOptionTypeNode')) { + if (keys.includes('remainderOptionTypeNode')) { visitor.visitRemainderOptionType = function visitRemainderOptionType(node) { const item = visit(this)(node.item); if (item === null) return null; @@ -311,7 +311,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('booleanTypeNode')) { + if (keys.includes('booleanTypeNode')) { visitor.visitBooleanType = function visitBooleanType(node) { const size = visit(this)(node.size); if (size === null) return null; @@ -320,7 +320,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('setTypeNode')) { + if (keys.includes('setTypeNode')) { visitor.visitSetType = function visitSetType(node) { const size = visit(this)(node.count); if (size === null) return null; @@ -332,14 +332,14 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('structTypeNode')) { + if (keys.includes('structTypeNode')) { visitor.visitStructType = function visitStructType(node) { const fields = node.fields.map(visit(this)).filter(removeNullAndAssertIsNodeFilter('structFieldTypeNode')); return structTypeNode(fields); }; } - if (castedNodeKeys.includes('structFieldTypeNode')) { + if (keys.includes('structFieldTypeNode')) { visitor.visitStructFieldType = function visitStructFieldType(node) { const type = visit(this)(node.type); if (type === null) return null; @@ -350,14 +350,14 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('tupleTypeNode')) { + if (keys.includes('tupleTypeNode')) { visitor.visitTupleType = function visitTupleType(node) { const items = node.items.map(visit(this)).filter(removeNullAndAssertIsNodeFilter(TYPE_NODES)); return tupleTypeNode(items); }; } - if (castedNodeKeys.includes('amountTypeNode')) { + if (keys.includes('amountTypeNode')) { visitor.visitAmountType = function visitAmountType(node) { const number = visit(this)(node.number); if (number === null) return null; @@ -366,7 +366,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('dateTimeTypeNode')) { + if (keys.includes('dateTimeTypeNode')) { visitor.visitDateTimeType = function visitDateTimeType(node) { const number = visit(this)(node.number); if (number === null) return null; @@ -375,7 +375,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('solAmountTypeNode')) { + if (keys.includes('solAmountTypeNode')) { visitor.visitSolAmountType = function visitSolAmountType(node) { const number = visit(this)(node.number); if (number === null) return null; @@ -384,7 +384,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('prefixedCountNode')) { + if (keys.includes('prefixedCountNode')) { visitor.visitPrefixedCount = function visitPrefixedCount(node) { const prefix = visit(this)(node.prefix); if (prefix === null) return null; @@ -393,13 +393,13 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('arrayValueNode')) { + if (keys.includes('arrayValueNode')) { visitor.visitArrayValue = function visitArrayValue(node) { return arrayValueNode(node.items.map(visit(this)).filter(removeNullAndAssertIsNodeFilter(VALUE_NODES))); }; } - if (castedNodeKeys.includes('constantValueNode')) { + if (keys.includes('constantValueNode')) { visitor.visitConstantValue = function visitConstantValue(node) { const type = visit(this)(node.type); if (type === null) return null; @@ -411,7 +411,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('enumValueNode')) { + if (keys.includes('enumValueNode')) { visitor.visitEnumValue = function visitEnumValue(node) { const enumLink = visit(this)(node.enum); if (enumLink === null) return null; @@ -422,7 +422,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('mapValueNode')) { + if (keys.includes('mapValueNode')) { visitor.visitMapValue = function visitMapValue(node) { return mapValueNode( node.entries.map(visit(this)).filter(removeNullAndAssertIsNodeFilter('mapEntryValueNode')), @@ -430,7 +430,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('mapEntryValueNode')) { + if (keys.includes('mapEntryValueNode')) { visitor.visitMapEntryValue = function visitMapEntryValue(node) { const key = visit(this)(node.key); if (key === null) return null; @@ -442,13 +442,13 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('setValueNode')) { + if (keys.includes('setValueNode')) { visitor.visitSetValue = function visitSetValue(node) { return setValueNode(node.items.map(visit(this)).filter(removeNullAndAssertIsNodeFilter(VALUE_NODES))); }; } - if (castedNodeKeys.includes('someValueNode')) { + if (keys.includes('someValueNode')) { visitor.visitSomeValue = function visitSomeValue(node) { const value = visit(this)(node.value); if (value === null) return null; @@ -457,7 +457,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('structValueNode')) { + if (keys.includes('structValueNode')) { visitor.visitStructValue = function visitStructValue(node) { return structValueNode( node.fields.map(visit(this)).filter(removeNullAndAssertIsNodeFilter('structFieldValueNode')), @@ -465,7 +465,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('structFieldValueNode')) { + if (keys.includes('structFieldValueNode')) { visitor.visitStructFieldValue = function visitStructFieldValue(node) { const value = visit(this)(node.value); if (value === null) return null; @@ -474,13 +474,13 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('tupleValueNode')) { + if (keys.includes('tupleValueNode')) { visitor.visitTupleValue = function visitTupleValue(node) { return tupleValueNode(node.items.map(visit(this)).filter(removeNullAndAssertIsNodeFilter(VALUE_NODES))); }; } - if (castedNodeKeys.includes('constantPdaSeedNode')) { + if (keys.includes('constantPdaSeedNode')) { visitor.visitConstantPdaSeed = function visitConstantPdaSeed(node) { const type = visit(this)(node.type); if (type === null) return null; @@ -492,7 +492,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('variablePdaSeedNode')) { + if (keys.includes('variablePdaSeedNode')) { visitor.visitVariablePdaSeed = function visitVariablePdaSeed(node) { const type = visit(this)(node.type); if (type === null) return null; @@ -501,7 +501,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('resolverValueNode')) { + if (keys.includes('resolverValueNode')) { visitor.visitResolverValue = function visitResolverValue(node) { const dependsOn = (node.dependsOn ?? []) .map(visit(this)) @@ -513,7 +513,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('conditionalValueNode')) { + if (keys.includes('conditionalValueNode')) { visitor.visitConditionalValue = function visitConditionalValue(node) { const condition = visit(this)(node.condition); if (condition === null) return null; @@ -529,7 +529,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('pdaValueNode')) { + if (keys.includes('pdaValueNode')) { visitor.visitPdaValue = function visitPdaValue(node) { const pda = visit(this)(node.pda); if (pda === null) return null; @@ -539,7 +539,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('pdaSeedValueNode')) { + if (keys.includes('pdaSeedValueNode')) { visitor.visitPdaSeedValue = function visitPdaSeedValue(node) { const value = visit(this)(node.value); if (value === null) return null; @@ -548,7 +548,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('fixedSizeTypeNode')) { + if (keys.includes('fixedSizeTypeNode')) { visitor.visitFixedSizeType = function visitFixedSizeType(node) { const type = visit(this)(node.type); if (type === null) return null; @@ -557,7 +557,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('sizePrefixTypeNode')) { + if (keys.includes('sizePrefixTypeNode')) { visitor.visitSizePrefixType = function visitSizePrefixType(node) { const prefix = visit(this)(node.prefix); if (prefix === null) return null; @@ -569,7 +569,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('preOffsetTypeNode')) { + if (keys.includes('preOffsetTypeNode')) { visitor.visitPreOffsetType = function visitPreOffsetType(node) { const type = visit(this)(node.type); if (type === null) return null; @@ -578,7 +578,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('postOffsetTypeNode')) { + if (keys.includes('postOffsetTypeNode')) { visitor.visitPostOffsetType = function visitPostOffsetType(node) { const type = visit(this)(node.type); if (type === null) return null; @@ -587,7 +587,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('sentinelTypeNode')) { + if (keys.includes('sentinelTypeNode')) { visitor.visitSentinelType = function visitSentinelType(node) { const sentinel = visit(this)(node.sentinel); if (sentinel === null) return null; @@ -599,7 +599,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('hiddenPrefixTypeNode')) { + if (keys.includes('hiddenPrefixTypeNode')) { visitor.visitHiddenPrefixType = function visitHiddenPrefixType(node) { const type = visit(this)(node.type); if (type === null) return null; @@ -610,7 +610,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('hiddenSuffixTypeNode')) { + if (keys.includes('hiddenSuffixTypeNode')) { visitor.visitHiddenSuffixType = function visitHiddenSuffixType(node) { const type = visit(this)(node.type); if (type === null) return null; @@ -621,7 +621,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('constantDiscriminatorNode')) { + if (keys.includes('constantDiscriminatorNode')) { visitor.visitConstantDiscriminator = function visitConstantDiscriminator(node) { const constant = visit(this)(node.constant); if (constant === null) return null; @@ -630,7 +630,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('accountLinkNode')) { + if (keys.includes('accountLinkNode')) { visitor.visitAccountLink = function visitAccountLink(node) { const program = node.program ? (visit(this)(node.program) ?? undefined) : undefined; if (program) assertIsNode(program, 'programLinkNode'); @@ -638,7 +638,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('definedTypeLinkNode')) { + if (keys.includes('definedTypeLinkNode')) { visitor.visitDefinedTypeLink = function visitDefinedTypeLink(node) { const program = node.program ? (visit(this)(node.program) ?? undefined) : undefined; if (program) assertIsNode(program, 'programLinkNode'); @@ -646,7 +646,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('instructionLinkNode')) { + if (keys.includes('instructionLinkNode')) { visitor.visitInstructionLink = function visitInstructionLink(node) { const program = node.program ? (visit(this)(node.program) ?? undefined) : undefined; if (program) assertIsNode(program, 'programLinkNode'); @@ -654,7 +654,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('instructionAccountLinkNode')) { + if (keys.includes('instructionAccountLinkNode')) { visitor.visitInstructionAccountLink = function visitInstructionAccountLink(node) { const instruction = node.instruction ? (visit(this)(node.instruction) ?? undefined) : undefined; if (instruction) assertIsNode(instruction, 'instructionLinkNode'); @@ -662,7 +662,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('instructionArgumentLinkNode')) { + if (keys.includes('instructionArgumentLinkNode')) { visitor.visitInstructionArgumentLink = function visitInstructionArgumentLink(node) { const instruction = node.instruction ? (visit(this)(node.instruction) ?? undefined) : undefined; if (instruction) assertIsNode(instruction, 'instructionLinkNode'); @@ -670,7 +670,7 @@ export function identityVisitor( }; } - if (castedNodeKeys.includes('pdaLinkNode')) { + if (keys.includes('pdaLinkNode')) { visitor.visitPdaLink = function visitPdaLink(node) { const program = node.program ? (visit(this)(node.program) ?? undefined) : undefined; if (program) assertIsNode(program, 'programLinkNode'); diff --git a/packages/visitors-core/src/mergeVisitor.ts b/packages/visitors-core/src/mergeVisitor.ts index b88aae925..d4a7888bb 100644 --- a/packages/visitors-core/src/mergeVisitor.ts +++ b/packages/visitors-core/src/mergeVisitor.ts @@ -6,22 +6,22 @@ import { visit as baseVisit, Visitor } from './visitor'; export function mergeVisitor( leafValue: (node: Node) => TReturn, merge: (node: Node, values: TReturn[]) => TReturn, - nodeKeys: TNodeKind[] = REGISTERED_NODE_KINDS as TNodeKind[], + options: { keys?: TNodeKind[] } = {}, ): Visitor { - const castedNodeKeys: NodeKind[] = nodeKeys; - const visitor = staticVisitor(leafValue, castedNodeKeys) as Visitor; + const keys: NodeKind[] = options.keys ?? (REGISTERED_NODE_KINDS as NodeKind[]); + const visitor = staticVisitor(leafValue, { keys }) as Visitor; const visit = (v: Visitor) => (node: Node): TReturn[] => - castedNodeKeys.includes(node.kind) ? [baseVisit(node, v)] : []; + keys.includes(node.kind) ? [baseVisit(node, v)] : []; - if (castedNodeKeys.includes('rootNode')) { + if (keys.includes('rootNode')) { visitor.visitRoot = function visitRoot(node) { return merge(node, getAllPrograms(node).flatMap(visit(this))); }; } - if (castedNodeKeys.includes('programNode')) { + if (keys.includes('programNode')) { visitor.visitProgram = function visitProgram(node) { return merge(node, [ ...node.pdas.flatMap(visit(this)), @@ -33,13 +33,13 @@ export function mergeVisitor( }; } - if (castedNodeKeys.includes('pdaNode')) { + if (keys.includes('pdaNode')) { visitor.visitPda = function visitPda(node) { return merge(node, node.seeds.flatMap(visit(this))); }; } - if (castedNodeKeys.includes('accountNode')) { + if (keys.includes('accountNode')) { visitor.visitAccount = function visitAccount(node) { return merge(node, [ ...visit(this)(node.data), @@ -49,7 +49,7 @@ export function mergeVisitor( }; } - if (castedNodeKeys.includes('instructionNode')) { + if (keys.includes('instructionNode')) { visitor.visitInstruction = function visitInstruction(node) { return merge(node, [ ...node.accounts.flatMap(visit(this)), @@ -63,13 +63,13 @@ export function mergeVisitor( }; } - if (castedNodeKeys.includes('instructionAccountNode')) { + if (keys.includes('instructionAccountNode')) { visitor.visitInstructionAccount = function visitInstructionAccount(node) { return merge(node, [...(node.defaultValue ? visit(this)(node.defaultValue) : [])]); }; } - if (castedNodeKeys.includes('instructionArgumentNode')) { + if (keys.includes('instructionArgumentNode')) { visitor.visitInstructionArgument = function visitInstructionArgument(node) { return merge(node, [ ...visit(this)(node.type), @@ -78,91 +78,91 @@ export function mergeVisitor( }; } - if (castedNodeKeys.includes('instructionRemainingAccountsNode')) { + if (keys.includes('instructionRemainingAccountsNode')) { visitor.visitInstructionRemainingAccounts = function visitInstructionRemainingAccounts(node) { return merge(node, visit(this)(node.value)); }; } - if (castedNodeKeys.includes('instructionByteDeltaNode')) { + if (keys.includes('instructionByteDeltaNode')) { visitor.visitInstructionByteDelta = function visitInstructionByteDelta(node) { return merge(node, visit(this)(node.value)); }; } - if (castedNodeKeys.includes('definedTypeNode')) { + if (keys.includes('definedTypeNode')) { visitor.visitDefinedType = function visitDefinedType(node) { return merge(node, visit(this)(node.type)); }; } - if (castedNodeKeys.includes('arrayTypeNode')) { + if (keys.includes('arrayTypeNode')) { visitor.visitArrayType = function visitArrayType(node) { return merge(node, [...visit(this)(node.count), ...visit(this)(node.item)]); }; } - if (castedNodeKeys.includes('enumTypeNode')) { + if (keys.includes('enumTypeNode')) { visitor.visitEnumType = function visitEnumType(node) { return merge(node, [...visit(this)(node.size), ...node.variants.flatMap(visit(this))]); }; } - if (castedNodeKeys.includes('enumStructVariantTypeNode')) { + if (keys.includes('enumStructVariantTypeNode')) { visitor.visitEnumStructVariantType = function visitEnumStructVariantType(node) { return merge(node, visit(this)(node.struct)); }; } - if (castedNodeKeys.includes('enumTupleVariantTypeNode')) { + if (keys.includes('enumTupleVariantTypeNode')) { visitor.visitEnumTupleVariantType = function visitEnumTupleVariantType(node) { return merge(node, visit(this)(node.tuple)); }; } - if (castedNodeKeys.includes('mapTypeNode')) { + if (keys.includes('mapTypeNode')) { visitor.visitMapType = function visitMapType(node) { return merge(node, [...visit(this)(node.count), ...visit(this)(node.key), ...visit(this)(node.value)]); }; } - if (castedNodeKeys.includes('optionTypeNode')) { + if (keys.includes('optionTypeNode')) { visitor.visitOptionType = function visitOptionType(node) { return merge(node, [...visit(this)(node.prefix), ...visit(this)(node.item)]); }; } - if (castedNodeKeys.includes('zeroableOptionTypeNode')) { + if (keys.includes('zeroableOptionTypeNode')) { visitor.visitZeroableOptionType = function visitZeroableOptionType(node) { return merge(node, [...visit(this)(node.item), ...(node.zeroValue ? visit(this)(node.zeroValue) : [])]); }; } - if (castedNodeKeys.includes('remainderOptionTypeNode')) { + if (keys.includes('remainderOptionTypeNode')) { visitor.visitRemainderOptionType = function visitRemainderOptionType(node) { return merge(node, visit(this)(node.item)); }; } - if (castedNodeKeys.includes('booleanTypeNode')) { + if (keys.includes('booleanTypeNode')) { visitor.visitBooleanType = function visitBooleanType(node) { return merge(node, visit(this)(node.size)); }; } - if (castedNodeKeys.includes('setTypeNode')) { + if (keys.includes('setTypeNode')) { visitor.visitSetType = function visitSetType(node) { return merge(node, [...visit(this)(node.count), ...visit(this)(node.item)]); }; } - if (castedNodeKeys.includes('structTypeNode')) { + if (keys.includes('structTypeNode')) { visitor.visitStructType = function visitStructType(node) { return merge(node, node.fields.flatMap(visit(this))); }; } - if (castedNodeKeys.includes('structFieldTypeNode')) { + if (keys.includes('structFieldTypeNode')) { visitor.visitStructFieldType = function visitStructFieldType(node) { return merge(node, [ ...visit(this)(node.type), @@ -171,115 +171,115 @@ export function mergeVisitor( }; } - if (castedNodeKeys.includes('tupleTypeNode')) { + if (keys.includes('tupleTypeNode')) { visitor.visitTupleType = function visitTupleType(node) { return merge(node, node.items.flatMap(visit(this))); }; } - if (castedNodeKeys.includes('amountTypeNode')) { + if (keys.includes('amountTypeNode')) { visitor.visitAmountType = function visitAmountType(node) { return merge(node, visit(this)(node.number)); }; } - if (castedNodeKeys.includes('dateTimeTypeNode')) { + if (keys.includes('dateTimeTypeNode')) { visitor.visitDateTimeType = function visitDateTimeType(node) { return merge(node, visit(this)(node.number)); }; } - if (castedNodeKeys.includes('solAmountTypeNode')) { + if (keys.includes('solAmountTypeNode')) { visitor.visitSolAmountType = function visitSolAmountType(node) { return merge(node, visit(this)(node.number)); }; } - if (castedNodeKeys.includes('prefixedCountNode')) { + if (keys.includes('prefixedCountNode')) { visitor.visitPrefixedCount = function visitPrefixedCount(node) { return merge(node, visit(this)(node.prefix)); }; } - if (castedNodeKeys.includes('arrayValueNode')) { + if (keys.includes('arrayValueNode')) { visitor.visitArrayValue = function visitArrayValue(node) { return merge(node, node.items.flatMap(visit(this))); }; } - if (castedNodeKeys.includes('constantValueNode')) { + if (keys.includes('constantValueNode')) { visitor.visitConstantValue = function visitConstantValue(node) { return merge(node, [...visit(this)(node.type), ...visit(this)(node.value)]); }; } - if (castedNodeKeys.includes('enumValueNode')) { + if (keys.includes('enumValueNode')) { visitor.visitEnumValue = function visitEnumValue(node) { return merge(node, [...visit(this)(node.enum), ...(node.value ? visit(this)(node.value) : [])]); }; } - if (castedNodeKeys.includes('mapValueNode')) { + if (keys.includes('mapValueNode')) { visitor.visitMapValue = function visitMapValue(node) { return merge(node, node.entries.flatMap(visit(this))); }; } - if (castedNodeKeys.includes('mapEntryValueNode')) { + if (keys.includes('mapEntryValueNode')) { visitor.visitMapEntryValue = function visitMapEntryValue(node) { return merge(node, [...visit(this)(node.key), ...visit(this)(node.value)]); }; } - if (castedNodeKeys.includes('setValueNode')) { + if (keys.includes('setValueNode')) { visitor.visitSetValue = function visitSetValue(node) { return merge(node, node.items.flatMap(visit(this))); }; } - if (castedNodeKeys.includes('someValueNode')) { + if (keys.includes('someValueNode')) { visitor.visitSomeValue = function visitSomeValue(node) { return merge(node, visit(this)(node.value)); }; } - if (castedNodeKeys.includes('structValueNode')) { + if (keys.includes('structValueNode')) { visitor.visitStructValue = function visitStructValue(node) { return merge(node, node.fields.flatMap(visit(this))); }; } - if (castedNodeKeys.includes('structFieldValueNode')) { + if (keys.includes('structFieldValueNode')) { visitor.visitStructFieldValue = function visitStructFieldValue(node) { return merge(node, visit(this)(node.value)); }; } - if (castedNodeKeys.includes('tupleValueNode')) { + if (keys.includes('tupleValueNode')) { visitor.visitTupleValue = function visitTupleValue(node) { return merge(node, node.items.flatMap(visit(this))); }; } - if (castedNodeKeys.includes('constantPdaSeedNode')) { + if (keys.includes('constantPdaSeedNode')) { visitor.visitConstantPdaSeed = function visitConstantPdaSeed(node) { return merge(node, [...visit(this)(node.type), ...visit(this)(node.value)]); }; } - if (castedNodeKeys.includes('variablePdaSeedNode')) { + if (keys.includes('variablePdaSeedNode')) { visitor.visitVariablePdaSeed = function visitVariablePdaSeed(node) { return merge(node, visit(this)(node.type)); }; } - if (castedNodeKeys.includes('resolverValueNode')) { + if (keys.includes('resolverValueNode')) { visitor.visitResolverValue = function visitResolverValue(node) { return merge(node, (node.dependsOn ?? []).flatMap(visit(this))); }; } - if (castedNodeKeys.includes('conditionalValueNode')) { + if (keys.includes('conditionalValueNode')) { visitor.visitConditionalValue = function visitConditionalValue(node) { return merge(node, [ ...visit(this)(node.condition), @@ -290,97 +290,97 @@ export function mergeVisitor( }; } - if (castedNodeKeys.includes('pdaValueNode')) { + if (keys.includes('pdaValueNode')) { visitor.visitPdaValue = function visitPdaValue(node) { return merge(node, [...visit(this)(node.pda), ...node.seeds.flatMap(visit(this))]); }; } - if (castedNodeKeys.includes('pdaSeedValueNode')) { + if (keys.includes('pdaSeedValueNode')) { visitor.visitPdaSeedValue = function visitPdaSeedValue(node) { return merge(node, visit(this)(node.value)); }; } - if (castedNodeKeys.includes('fixedSizeTypeNode')) { + if (keys.includes('fixedSizeTypeNode')) { visitor.visitFixedSizeType = function visitFixedSizeType(node) { return merge(node, visit(this)(node.type)); }; } - if (castedNodeKeys.includes('sizePrefixTypeNode')) { + if (keys.includes('sizePrefixTypeNode')) { visitor.visitSizePrefixType = function visitSizePrefixType(node) { return merge(node, [...visit(this)(node.prefix), ...visit(this)(node.type)]); }; } - if (castedNodeKeys.includes('preOffsetTypeNode')) { + if (keys.includes('preOffsetTypeNode')) { visitor.visitPreOffsetType = function visitPreOffsetType(node) { return merge(node, visit(this)(node.type)); }; } - if (castedNodeKeys.includes('postOffsetTypeNode')) { + if (keys.includes('postOffsetTypeNode')) { visitor.visitPostOffsetType = function visitPostOffsetType(node) { return merge(node, visit(this)(node.type)); }; } - if (castedNodeKeys.includes('sentinelTypeNode')) { + if (keys.includes('sentinelTypeNode')) { visitor.visitSentinelType = function visitSentinelType(node) { return merge(node, [...visit(this)(node.sentinel), ...visit(this)(node.type)]); }; } - if (castedNodeKeys.includes('hiddenPrefixTypeNode')) { + if (keys.includes('hiddenPrefixTypeNode')) { visitor.visitHiddenPrefixType = function visitHiddenPrefixType(node) { return merge(node, [...node.prefix.flatMap(visit(this)), ...visit(this)(node.type)]); }; } - if (castedNodeKeys.includes('hiddenSuffixTypeNode')) { + if (keys.includes('hiddenSuffixTypeNode')) { visitor.visitHiddenSuffixType = function visitHiddenSuffixType(node) { return merge(node, [...visit(this)(node.type), ...node.suffix.flatMap(visit(this))]); }; } - if (castedNodeKeys.includes('constantDiscriminatorNode')) { + if (keys.includes('constantDiscriminatorNode')) { visitor.visitConstantDiscriminator = function visitConstantDiscriminator(node) { return merge(node, visit(this)(node.constant)); }; } - if (castedNodeKeys.includes('accountLinkNode')) { + if (keys.includes('accountLinkNode')) { visitor.visitAccountLink = function visitAccountLink(node) { return merge(node, node.program ? visit(this)(node.program) : []); }; } - if (castedNodeKeys.includes('definedTypeLinkNode')) { + if (keys.includes('definedTypeLinkNode')) { visitor.visitDefinedTypeLink = function visitDefinedTypeLink(node) { return merge(node, node.program ? visit(this)(node.program) : []); }; } - if (castedNodeKeys.includes('instructionLinkNode')) { + if (keys.includes('instructionLinkNode')) { visitor.visitInstructionLink = function visitInstructionLink(node) { return merge(node, node.program ? visit(this)(node.program) : []); }; } - if (castedNodeKeys.includes('instructionAccountLinkNode')) { + if (keys.includes('instructionAccountLinkNode')) { visitor.visitInstructionAccountLink = function visitInstructionAccountLink(node) { return merge(node, node.instruction ? visit(this)(node.instruction) : []); }; } - if (castedNodeKeys.includes('instructionArgumentLinkNode')) { + if (keys.includes('instructionArgumentLinkNode')) { visitor.visitInstructionArgumentLink = function visitInstructionArgumentLink(node) { return merge(node, node.instruction ? visit(this)(node.instruction) : []); }; } - if (castedNodeKeys.includes('pdaLinkNode')) { + if (keys.includes('pdaLinkNode')) { visitor.visitPdaLink = function visitPdaLink(node) { return merge(node, node.program ? visit(this)(node.program) : []); }; diff --git a/packages/visitors-core/src/nonNullableIdentityVisitor.ts b/packages/visitors-core/src/nonNullableIdentityVisitor.ts index 6e473c9b9..e771a6dd7 100644 --- a/packages/visitors-core/src/nonNullableIdentityVisitor.ts +++ b/packages/visitors-core/src/nonNullableIdentityVisitor.ts @@ -1,10 +1,10 @@ -import { Node, NodeKind, REGISTERED_NODE_KINDS } from '@codama/nodes'; +import { Node, NodeKind } from '@codama/nodes'; import { identityVisitor } from './identityVisitor'; import { Visitor } from './visitor'; export function nonNullableIdentityVisitor( - nodeKeys: TNodeKind[] = REGISTERED_NODE_KINDS as TNodeKind[], + options: { keys?: TNodeKind[] } = {}, ): Visitor { - return identityVisitor(nodeKeys) as Visitor; + return identityVisitor(options) as Visitor; } diff --git a/packages/visitors-core/src/removeDocsVisitor.ts b/packages/visitors-core/src/removeDocsVisitor.ts index 525abc0b9..e1ad81eec 100644 --- a/packages/visitors-core/src/removeDocsVisitor.ts +++ b/packages/visitors-core/src/removeDocsVisitor.ts @@ -3,8 +3,8 @@ import { NodeKind } from '@codama/nodes'; import { interceptVisitor } from './interceptVisitor'; import { nonNullableIdentityVisitor } from './nonNullableIdentityVisitor'; -export function removeDocsVisitor(nodeKeys?: TNodeKind[]) { - return interceptVisitor(nonNullableIdentityVisitor(nodeKeys), (node, next) => { +export function removeDocsVisitor(options: { keys?: TNodeKind[] } = {}) { + return interceptVisitor(nonNullableIdentityVisitor(options), (node, next) => { if ('docs' in node) { return next({ ...node, docs: [] }); } diff --git a/packages/visitors-core/src/staticVisitor.ts b/packages/visitors-core/src/staticVisitor.ts index c9e03c710..0f9d8f9d4 100644 --- a/packages/visitors-core/src/staticVisitor.ts +++ b/packages/visitors-core/src/staticVisitor.ts @@ -4,10 +4,11 @@ import { getVisitFunctionName, Visitor } from './visitor'; export function staticVisitor( fn: (node: Node) => TReturn, - nodeKeys: TNodeKind[] = REGISTERED_NODE_KINDS as TNodeKind[], + options: { keys?: TNodeKind[] } = {}, ): Visitor { + const keys = options.keys ?? (REGISTERED_NODE_KINDS as TNodeKind[]); const visitor = {} as Visitor; - nodeKeys.forEach(key => { + keys.forEach(key => { visitor[getVisitFunctionName(key)] = fn.bind(visitor); }); return visitor; diff --git a/packages/visitors-core/src/topDownTransformerVisitor.ts b/packages/visitors-core/src/topDownTransformerVisitor.ts index ee5ecf20a..af718653e 100644 --- a/packages/visitors-core/src/topDownTransformerVisitor.ts +++ b/packages/visitors-core/src/topDownTransformerVisitor.ts @@ -17,7 +17,7 @@ export type TopDownNodeTransformerWithSelector = { export function topDownTransformerVisitor( transformers: (TopDownNodeTransformer | TopDownNodeTransformerWithSelector)[], - nodeKeys?: TNodeKind[], + options: { keys?: TNodeKind[]; stack?: NodeStack } = {}, ): Visitor { const transformerFunctions = transformers.map((transformer): TopDownNodeTransformer => { if (typeof transformer === 'function') return transformer; @@ -27,9 +27,9 @@ export function topDownTransformerVisitor : node; }); - const stack = new NodeStack(); + const stack = options.stack ?? new NodeStack(); return pipe( - identityVisitor(nodeKeys), + identityVisitor(options), v => interceptVisitor(v, (node, next) => { const appliedNode = transformerFunctions.reduce( diff --git a/packages/visitors-core/src/voidVisitor.ts b/packages/visitors-core/src/voidVisitor.ts index 96f705579..2ca193e15 100644 --- a/packages/visitors-core/src/voidVisitor.ts +++ b/packages/visitors-core/src/voidVisitor.ts @@ -3,10 +3,12 @@ import type { NodeKind } from '@codama/nodes'; import { mergeVisitor } from './mergeVisitor'; import { Visitor } from './visitor'; -export function voidVisitor(nodeKeys?: TNodeKind[]): Visitor { +export function voidVisitor( + options: { keys?: TNodeKind[] } = {}, +): Visitor { return mergeVisitor( () => undefined, () => undefined, - nodeKeys, + options, ); } diff --git a/packages/visitors-core/test/bottomUpTransformerVisitor.test.ts b/packages/visitors-core/test/bottomUpTransformerVisitor.test.ts index dcf3feeef..f2a44cce1 100644 --- a/packages/visitors-core/test/bottomUpTransformerVisitor.test.ts +++ b/packages/visitors-core/test/bottomUpTransformerVisitor.test.ts @@ -1,7 +1,22 @@ -import { isNode, numberTypeNode, publicKeyTypeNode, stringTypeNode, tupleTypeNode, TYPE_NODES } from '@codama/nodes'; +import { + definedTypeNode, + isNode, + numberTypeNode, + programNode, + publicKeyTypeNode, + stringTypeNode, + tupleTypeNode, + TYPE_NODES, +} from '@codama/nodes'; import { expect, test } from 'vitest'; -import { bottomUpTransformerVisitor, visit } from '../src'; +import { + BottomUpNodeTransformerWithSelector, + bottomUpTransformerVisitor, + findProgramNodeFromPath, + NodeStack, + visit, +} from '../src'; test('it can transform nodes into other nodes', () => { // Given the following tree. @@ -48,10 +63,9 @@ test('it can create partial transformer visitors', () => { // And a transformer visitor that wraps every node into another tuple node // but that does not transform public key nodes. - const visitor = bottomUpTransformerVisitor( - [node => (isNode(node, TYPE_NODES) ? tupleTypeNode([node]) : node)], - ['tupleTypeNode', 'numberTypeNode'], - ); + const visitor = bottomUpTransformerVisitor([node => (isNode(node, TYPE_NODES) ? tupleTypeNode([node]) : node)], { + keys: ['tupleTypeNode', 'numberTypeNode'], + }); // When we visit the tree using that visitor. const result = visit(node, visitor); @@ -107,3 +121,33 @@ test('it can transform nodes using multiple node selectors', () => { tupleTypeNode([numberTypeNode('u32'), tupleTypeNode([stringTypeNode('utf8'), publicKeyTypeNode()])]), ); }); + +test('it can start from an existing stack', () => { + // Given the following tuple node inside a program node. + const tuple = tupleTypeNode([numberTypeNode('u32'), publicKeyTypeNode()]); + const program = programNode({ + definedTypes: [definedTypeNode({ name: 'myTuple', type: tuple })], + name: 'myProgram', + publicKey: '1111', + }); + + // And a transformer that removes all number nodes + // from programs whose public key is '1111'. + const transformer: BottomUpNodeTransformerWithSelector = { + select: ['[numberTypeNode]', path => findProgramNodeFromPath(path)?.publicKey === '1111'], + transform: () => null, + }; + + // When we visit the tuple with an existing stack that contains the program node. + const stack = new NodeStack([program, program.definedTypes[0]]); + const resultWithStack = visit(tuple, bottomUpTransformerVisitor([transformer], { stack })); + + // Then we expect the number node to have been removed. + expect(resultWithStack).toStrictEqual(tupleTypeNode([publicKeyTypeNode()])); + + // But when we visit the tuple without the stack. + const resultWithoutStack = visit(tuple, bottomUpTransformerVisitor([transformer])); + + // Then we expect the number node to have been kept. + expect(resultWithoutStack).toStrictEqual(tuple); +}); diff --git a/packages/visitors-core/test/deleteNodesVisitor.test.ts b/packages/visitors-core/test/deleteNodesVisitor.test.ts index 297b9e62c..5eff4c445 100644 --- a/packages/visitors-core/test/deleteNodesVisitor.test.ts +++ b/packages/visitors-core/test/deleteNodesVisitor.test.ts @@ -23,10 +23,9 @@ test('it can create partial visitors', () => { // And a visitor that deletes all number nodes and public key nodes // but does not support public key nodes. - const visitor = deleteNodesVisitor( - ['[numberTypeNode]', '[publicKeyTypeNode]'], - ['tupleTypeNode', 'numberTypeNode'], - ); + const visitor = deleteNodesVisitor(['[numberTypeNode]', '[publicKeyTypeNode]'], { + keys: ['tupleTypeNode', 'numberTypeNode'], + }); // When we visit the tree using that visitor. const result = visit(node, visitor); diff --git a/packages/visitors-core/test/extendVisitor.test.ts b/packages/visitors-core/test/extendVisitor.test.ts index ca9c75f24..0169bcb70 100644 --- a/packages/visitors-core/test/extendVisitor.test.ts +++ b/packages/visitors-core/test/extendVisitor.test.ts @@ -50,7 +50,7 @@ test('it can visit itself using the exposed self argument', () => { test('it cannot extends nodes that are not supported by the base visitor', () => { // Given a base visitor that only supports tuple nodes. - const baseVisitor = voidVisitor(['tupleTypeNode']); + const baseVisitor = voidVisitor({ keys: ['tupleTypeNode'] }); // Then we expect an error when we try to extend other nodes for that visitor. expect(() => diff --git a/packages/visitors-core/test/getByteSizeVisitor.test.ts b/packages/visitors-core/test/getByteSizeVisitor.test.ts index b58abe503..2293fd014 100644 --- a/packages/visitors-core/test/getByteSizeVisitor.test.ts +++ b/packages/visitors-core/test/getByteSizeVisitor.test.ts @@ -23,9 +23,7 @@ import { expect, test } from 'vitest'; import { getByteSizeVisitor, getRecordLinkablesVisitor, LinkableDictionary, NodeStack, visit, Visitor } from '../src'; const expectSize = (node: Node, expectedSize: number | null) => { - expect(visit(node, getByteSizeVisitor(new LinkableDictionary(), new NodeStack()) as Visitor)).toBe( - expectedSize, - ); + expect(visit(node, getByteSizeVisitor(new LinkableDictionary()) as Visitor)).toBe(expectedSize); }; test.each([ @@ -138,7 +136,7 @@ test('it follows linked nodes using the correct paths', () => { visit(root, getRecordLinkablesVisitor(linkables)); // When we visit the first defined type. - const visitor = getByteSizeVisitor(linkables, new NodeStack([root, programA])); + const visitor = getByteSizeVisitor(linkables, { stack: new NodeStack([root, programA]) }); const result = visit(programA.definedTypes[0], visitor); // Then we expect the final linkable to be resolved. diff --git a/packages/visitors-core/test/identityVisitor.test.ts b/packages/visitors-core/test/identityVisitor.test.ts index 34e8a0867..5d45a8e36 100644 --- a/packages/visitors-core/test/identityVisitor.test.ts +++ b/packages/visitors-core/test/identityVisitor.test.ts @@ -42,7 +42,7 @@ test('it can create partial visitors', () => { // And an identity visitor that only supports 2 of these nodes // whilst using an interceptor to record the events that happened. const events: string[] = []; - const visitor = interceptVisitor(identityVisitor(['tupleTypeNode', 'numberTypeNode']), (node, next) => { + const visitor = interceptVisitor(identityVisitor({ keys: ['tupleTypeNode', 'numberTypeNode'] }), (node, next) => { events.push(`visiting:${node.kind}`); return next(node); }); diff --git a/packages/visitors-core/test/mapVisitor.test.ts b/packages/visitors-core/test/mapVisitor.test.ts index aef5c4144..707ef6d6c 100644 --- a/packages/visitors-core/test/mapVisitor.test.ts +++ b/packages/visitors-core/test/mapVisitor.test.ts @@ -27,7 +27,7 @@ test('it creates partial visitors from partial visitors', () => { const node = tupleTypeNode([numberTypeNode('u32'), publicKeyTypeNode()]); // And partial static visitor A that supports only 2 of these nodes. - const visitorA = staticVisitor(node => node.kind, ['tupleTypeNode', 'numberTypeNode']); + const visitorA = staticVisitor(node => node.kind, { keys: ['tupleTypeNode', 'numberTypeNode'] }); // And a mapped visitor B that returns the number of characters returned by visitor A. const visitorB = mapVisitor(visitorA, value => value.length); diff --git a/packages/visitors-core/test/mergeVisitor.test.ts b/packages/visitors-core/test/mergeVisitor.test.ts index e792cf7e4..fbdb55c50 100644 --- a/packages/visitors-core/test/mergeVisitor.test.ts +++ b/packages/visitors-core/test/mergeVisitor.test.ts @@ -44,7 +44,7 @@ test('it can create partial visitors', () => { const visitor = mergeVisitor( node => node.kind as string, (node, values) => `${node.kind}(${values.join(',')})`, - ['tupleTypeNode', 'numberTypeNode'], + { keys: ['tupleTypeNode', 'numberTypeNode'] }, ); // When we visit the tree using that visitor. diff --git a/packages/visitors-core/test/removeDocsVisitor.test.ts b/packages/visitors-core/test/removeDocsVisitor.test.ts index f1963f6f1..a0fa943c2 100644 --- a/packages/visitors-core/test/removeDocsVisitor.test.ts +++ b/packages/visitors-core/test/removeDocsVisitor.test.ts @@ -76,7 +76,7 @@ test('it can create partial visitors', () => { ]); // And a remove docs visitor that only supports struct type nodes. - const visitor = removeDocsVisitor(['structTypeNode']); + const visitor = removeDocsVisitor({ keys: ['structTypeNode'] }); // When we use it on our struct node. const result = visit(node, visitor); diff --git a/packages/visitors-core/test/staticVisitor.test.ts b/packages/visitors-core/test/staticVisitor.test.ts index 0b01cd0f3..b9b958595 100644 --- a/packages/visitors-core/test/staticVisitor.test.ts +++ b/packages/visitors-core/test/staticVisitor.test.ts @@ -21,7 +21,7 @@ test('it can create partial visitor', () => { const node = tupleTypeNode([numberTypeNode('u32'), publicKeyTypeNode()]); // And a static visitor that supports only 2 of these nodes. - const visitor = staticVisitor(node => node.kind, ['tupleTypeNode', 'numberTypeNode']); + const visitor = staticVisitor(node => node.kind, { keys: ['tupleTypeNode', 'numberTypeNode'] }); // Then we expect the following results when visiting supported nodes. expect(visit(node, visitor)).toBe('tupleTypeNode'); diff --git a/packages/visitors-core/test/topDownTransformerVisitor.test.ts b/packages/visitors-core/test/topDownTransformerVisitor.test.ts index 0a03f5139..0a18f534f 100644 --- a/packages/visitors-core/test/topDownTransformerVisitor.test.ts +++ b/packages/visitors-core/test/topDownTransformerVisitor.test.ts @@ -1,7 +1,21 @@ -import { assertIsNode, isNode, numberTypeNode, publicKeyTypeNode, tupleTypeNode } from '@codama/nodes'; +import { + assertIsNode, + definedTypeNode, + isNode, + numberTypeNode, + programNode, + publicKeyTypeNode, + tupleTypeNode, +} from '@codama/nodes'; import { expect, test } from 'vitest'; -import { topDownTransformerVisitor, visit } from '../src'; +import { + findProgramNodeFromPath, + NodeStack, + TopDownNodeTransformerWithSelector, + topDownTransformerVisitor, + visit, +} from '../src'; test('it can transform nodes to the same kind of node', () => { // Given the following tree. @@ -57,7 +71,7 @@ test('it can create partial transformer visitors', () => { }, }, ], - ['tupleTypeNode'], + { keys: ['tupleTypeNode'] }, ); // When we visit the tree using that visitor. @@ -115,3 +129,33 @@ test('it can transform nodes using multiple node selectors', () => { tupleTypeNode([numberTypeNode('u32'), tupleTypeNode([numberTypeNode('u64'), publicKeyTypeNode()])]), ); }); + +test('it can start from an existing stack', () => { + // Given the following tuple node inside a program node. + const tuple = tupleTypeNode([numberTypeNode('u32'), publicKeyTypeNode()]); + const program = programNode({ + definedTypes: [definedTypeNode({ name: 'myTuple', type: tuple })], + name: 'myProgram', + publicKey: '1111', + }); + + // And a transformer that removes all number nodes + // from programs whose public key is '1111'. + const transformer: TopDownNodeTransformerWithSelector = { + select: ['[numberTypeNode]', path => findProgramNodeFromPath(path)?.publicKey === '1111'], + transform: () => null, + }; + + // When we visit the tuple with an existing stack that contains the program node. + const stack = new NodeStack([program, program.definedTypes[0]]); + const resultWithStack = visit(tuple, topDownTransformerVisitor([transformer], { stack })); + + // Then we expect the number node to have been removed. + expect(resultWithStack).toStrictEqual(tupleTypeNode([publicKeyTypeNode()])); + + // But when we visit the tuple without the stack. + const resultWithoutStack = visit(tuple, topDownTransformerVisitor([transformer])); + + // Then we expect the number node to have been kept. + expect(resultWithoutStack).toStrictEqual(tuple); +}); diff --git a/packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts b/packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts index 5cc85e3cd..10da2055d 100644 --- a/packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts +++ b/packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts @@ -42,7 +42,7 @@ export function fillDefaultPdaSeedValuesVisitor( strictMode: boolean = false, ) { const instruction = getLastNodeFromPath(instructionPath); - return pipe(identityVisitor(INSTRUCTION_INPUT_VALUE_NODES), v => + return pipe(identityVisitor({ keys: INSTRUCTION_INPUT_VALUE_NODES }), v => extendVisitor(v, { visitPdaValue(node, { next }) { const visitedNode = next(node); diff --git a/packages/visitors/src/setFixedAccountSizesVisitor.ts b/packages/visitors/src/setFixedAccountSizesVisitor.ts index 90db9e6f6..043919146 100644 --- a/packages/visitors/src/setFixedAccountSizesVisitor.ts +++ b/packages/visitors/src/setFixedAccountSizesVisitor.ts @@ -19,13 +19,13 @@ export function setFixedAccountSizesVisitor() { select: path => isNodePath(path, 'accountNode') && getLastNodeFromPath(path).size === undefined, transform: (node, stack) => { assertIsNode(node, 'accountNode'); - const size = visit(node.data, getByteSizeVisitor(linkables, stack)); + const size = visit(node.data, getByteSizeVisitor(linkables, { stack })); if (size === null) return node; return accountNode({ ...node, size }) as typeof node; }, }, ], - ['rootNode', 'programNode', 'accountNode'], + { keys: ['rootNode', 'programNode', 'accountNode'] }, ); return pipe(visitor, v => recordLinkablesOnFirstVisitVisitor(v, linkables)); diff --git a/packages/visitors/src/setInstructionAccountDefaultValuesVisitor.ts b/packages/visitors/src/setInstructionAccountDefaultValuesVisitor.ts index 1263d7dc3..c22632d2a 100644 --- a/packages/visitors/src/setInstructionAccountDefaultValuesVisitor.ts +++ b/packages/visitors/src/setInstructionAccountDefaultValuesVisitor.ts @@ -163,7 +163,7 @@ export function setInstructionAccountDefaultValuesVisitor(rules: InstructionAcco } return pipe( - nonNullableIdentityVisitor(['rootNode', 'programNode', 'instructionNode']), + nonNullableIdentityVisitor({ keys: ['rootNode', 'programNode', 'instructionNode'] }), v => extendVisitor(v, { visitInstruction(node) { diff --git a/packages/visitors/src/transformDefinedTypesIntoAccountsVisitor.ts b/packages/visitors/src/transformDefinedTypesIntoAccountsVisitor.ts index 8a10cae4a..14bc2a296 100644 --- a/packages/visitors/src/transformDefinedTypesIntoAccountsVisitor.ts +++ b/packages/visitors/src/transformDefinedTypesIntoAccountsVisitor.ts @@ -2,7 +2,7 @@ import { accountNode, assertIsNode, programNode } from '@codama/nodes'; import { extendVisitor, nonNullableIdentityVisitor, pipe } from '@codama/visitors-core'; export function transformDefinedTypesIntoAccountsVisitor(definedTypes: string[]) { - return pipe(nonNullableIdentityVisitor(['rootNode', 'programNode']), v => + return pipe(nonNullableIdentityVisitor({ keys: ['rootNode', 'programNode'] }), v => extendVisitor(v, { visitProgram(program) { const typesToExtract = program.definedTypes.filter(node => definedTypes.includes(node.name));