From 465bc04a93e6e1121e1ab0b6fd846a0062ccfb40 Mon Sep 17 00:00:00 2001 From: Loris Leiva Date: Thu, 31 Oct 2024 15:39:21 +0000 Subject: [PATCH] Use NodePath in LinkableDictionary --- .changeset/red-plants-collect.md | 5 + packages/errors/src/context.ts | 4 +- .../src/getRenderMapVisitor.ts | 2 +- .../src/getTypeManifestVisitor.ts | 2 +- .../src/fragments/accountPdaHelpers.ts | 2 +- .../fragments/instructionAccountTypeParam.ts | 3 +- .../src/getTypeManifestVisitor.ts | 2 +- .../renderers-rust/src/getRenderMapVisitor.ts | 2 +- packages/validators/README.md | 4 +- packages/validators/src/ValidationItem.ts | 4 +- .../src/getValidationItemsVisitor.ts | 2 +- packages/visitors-core/README.md | 7 +- .../visitors-core/src/LinkableDictionary.ts | 166 ++++++++++-------- packages/visitors-core/src/NodePath.ts | 57 +++++- packages/visitors-core/src/NodeStack.ts | 9 +- .../visitors-core/src/getByteSizeVisitor.ts | 3 +- .../src/recordLinkablesVisitor.ts | 2 +- .../test/recordLinkablesVisitor.test.ts | 83 ++++----- ...reateSubInstructionsFromEnumArgsVisitor.ts | 18 +- .../src/fillDefaultPdaSeedValuesVisitor.ts | 2 +- .../visitors/src/unwrapDefinedTypesVisitor.ts | 4 +- .../src/unwrapTypeDefinedLinksVisitor.ts | 16 +- .../fillDefaultPdaSeedValuesVisitor.test.ts | 6 +- 23 files changed, 233 insertions(+), 172 deletions(-) create mode 100644 .changeset/red-plants-collect.md diff --git a/.changeset/red-plants-collect.md b/.changeset/red-plants-collect.md new file mode 100644 index 000000000..bbae88d29 --- /dev/null +++ b/.changeset/red-plants-collect.md @@ -0,0 +1,5 @@ +--- +'@codama/visitors-core': minor +--- + +Record and resolve `NodePaths` instead of `Nodes` in `LinkableDictionary` diff --git a/packages/errors/src/context.ts b/packages/errors/src/context.ts index 2323d8f12..08dd224c9 100644 --- a/packages/errors/src/context.ts +++ b/packages/errors/src/context.ts @@ -75,7 +75,7 @@ export type CodamaErrorContext = DefaultUnspecifiedErrorContextToUndefined<{ kind: LinkNode['kind']; linkNode: LinkNode; name: CamelCaseString; - stack: Node[]; + path: readonly Node[]; }; [CODAMA_ERROR__NODE_FILESYSTEM_FUNCTION_UNAVAILABLE]: { fsFunction: string; @@ -169,7 +169,7 @@ type ValidationItem = { level: 'debug' | 'error' | 'info' | 'trace' | 'warn'; message: string; node: Node; - stack: Node[]; + path: Node[]; }; export function decodeEncodedContext(encodedContext: string): object { diff --git a/packages/renderers-js-umi/src/getRenderMapVisitor.ts b/packages/renderers-js-umi/src/getRenderMapVisitor.ts index 45d1af17c..ad5cddd59 100644 --- a/packages/renderers-js-umi/src/getRenderMapVisitor.ts +++ b/packages/renderers-js-umi/src/getRenderMapVisitor.ts @@ -204,7 +204,7 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor< } // Seeds. - const pda = node.pda ? linkables.get(node.pda, stack) : undefined; + const pda = node.pda ? linkables.get([...stack.getPath(), node.pda]) : undefined; const pdaSeeds = pda?.seeds ?? []; const seeds = pdaSeeds.map(seed => { if (isNode(seed, 'variablePdaSeedNode')) { diff --git a/packages/renderers-js-umi/src/getTypeManifestVisitor.ts b/packages/renderers-js-umi/src/getTypeManifestVisitor.ts index 1db555e41..885317fb7 100644 --- a/packages/renderers-js-umi/src/getTypeManifestVisitor.ts +++ b/packages/renderers-js-umi/src/getTypeManifestVisitor.ts @@ -429,7 +429,7 @@ export function getTypeManifestVisitor(input: { const importFrom = getImportFrom(node.enum); // FIXME(loris): No program node can ever be in this stack. - const enumNode = linkables.get(node.enum, stack)?.type; + const enumNode = linkables.get([...stack.getPath(), node.enum])?.type; const isScalar = enumNode && isNode(enumNode, 'enumTypeNode') ? isScalarEnum(enumNode) diff --git a/packages/renderers-js/src/fragments/accountPdaHelpers.ts b/packages/renderers-js/src/fragments/accountPdaHelpers.ts index 77fa9533c..67f17b057 100644 --- a/packages/renderers-js/src/fragments/accountPdaHelpers.ts +++ b/packages/renderers-js/src/fragments/accountPdaHelpers.ts @@ -13,7 +13,7 @@ export function getAccountPdaHelpersFragment( }, ): Fragment { const { accountNode, accountStack, nameApi, linkables, customAccountData, typeManifest } = scope; - const pdaNode = accountNode.pda ? linkables.get(accountNode.pda, accountStack) : undefined; + const pdaNode = accountNode.pda ? linkables.get([...accountStack.getPath(), accountNode.pda]) : undefined; if (!pdaNode) { return fragment(''); } diff --git a/packages/renderers-js/src/fragments/instructionAccountTypeParam.ts b/packages/renderers-js/src/fragments/instructionAccountTypeParam.ts index f1b248878..a9abea487 100644 --- a/packages/renderers-js/src/fragments/instructionAccountTypeParam.ts +++ b/packages/renderers-js/src/fragments/instructionAccountTypeParam.ts @@ -43,9 +43,8 @@ function getDefaultAddress( case 'publicKeyValueNode': return `"${defaultValue.publicKey}"`; case 'programLinkNode': - // FIXME(loris): No need for a stack here. // eslint-disable-next-line no-case-declarations - const programNode = linkables.get(defaultValue, new NodeStack()); + const programNode = linkables.get([defaultValue]); return programNode ? `"${programNode.publicKey}"` : 'string'; case 'programIdValueNode': return `"${programId}"`; diff --git a/packages/renderers-js/src/getTypeManifestVisitor.ts b/packages/renderers-js/src/getTypeManifestVisitor.ts index 2fdf09c66..2203c4319 100644 --- a/packages/renderers-js/src/getTypeManifestVisitor.ts +++ b/packages/renderers-js/src/getTypeManifestVisitor.ts @@ -355,7 +355,7 @@ export function getTypeManifestVisitor(input: { const importFrom = getImportFrom(node.enum); // FIXME(loris): No program node can ever be in this stack. - const enumNode = linkables.get(node.enum, stack)?.type; + const enumNode = linkables.get([...stack.getPath(), node.enum])?.type; const isScalar = enumNode && isNode(enumNode, 'enumTypeNode') ? isScalarEnum(enumNode) diff --git a/packages/renderers-rust/src/getRenderMapVisitor.ts b/packages/renderers-rust/src/getRenderMapVisitor.ts index 907552929..c1e34a0b9 100644 --- a/packages/renderers-rust/src/getRenderMapVisitor.ts +++ b/packages/renderers-rust/src/getRenderMapVisitor.ts @@ -64,7 +64,7 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) { // Seeds. const seedsImports = new ImportMap(); - const pda = node.pda ? linkables.get(node.pda, stack) : undefined; + const pda = node.pda ? linkables.get([...stack.getPath(), node.pda]) : undefined; const pdaSeeds = pda?.seeds ?? []; const seeds = pdaSeeds.map(seed => { if (isNode(seed, 'variablePdaSeedNode')) { diff --git a/packages/validators/README.md b/packages/validators/README.md index 557d3b995..170dbd050 100644 --- a/packages/validators/README.md +++ b/packages/validators/README.md @@ -36,8 +36,8 @@ type ValidationItem = { message: string; // The node that the validation item is related to. node: Node; - // The stack of nodes that led to the node above. - stack: readonly Node[]; + // The path of nodes that led to the node above (including the node itself). + path: readonly Node[]; }; ``` diff --git a/packages/validators/src/ValidationItem.ts b/packages/validators/src/ValidationItem.ts index cec08c5d3..a1d5ab3fd 100644 --- a/packages/validators/src/ValidationItem.ts +++ b/packages/validators/src/ValidationItem.ts @@ -8,7 +8,7 @@ export type ValidationItem = { level: LogLevel; message: string; node: Node; - stack: readonly Node[]; + path: readonly Node[]; }; export function validationItem( @@ -21,7 +21,7 @@ export function validationItem( level, message, node, - stack: Array.isArray(stack) ? [...stack] : stack.all(), + path: Array.isArray(stack) ? [...stack] : stack.all(), }; } diff --git a/packages/validators/src/getValidationItemsVisitor.ts b/packages/validators/src/getValidationItemsVisitor.ts index f976ae77d..3841ccf45 100644 --- a/packages/validators/src/getValidationItemsVisitor.ts +++ b/packages/validators/src/getValidationItemsVisitor.ts @@ -47,7 +47,7 @@ export function getValidationItemsVisitor(): Visitor const items = [] as ValidationItem[]; if (!node.name) { items.push(validationItem('error', 'Pointing to a defined type with no name.', node, stack)); - } else if (!linkables.has(node, stack)) { + } else if (!linkables.has(stack.getPath())) { items.push( validationItem( 'error', diff --git a/packages/visitors-core/README.md b/packages/visitors-core/README.md index a7e14874d..262e4b624 100644 --- a/packages/visitors-core/README.md +++ b/packages/visitors-core/README.md @@ -653,11 +653,8 @@ It offers the following API: ```ts const linkables = new LinkableDictionary(); -// Record program nodes. -linkables.record(programNode, stack); - -// Record other linkable nodes with their associated program node. -linkables.record(accountNode, stack); +// Record linkable nodes via their full path. +linkables.recordPath([rootNode, programNode, accountNode]); // Get a linkable node using a link node, or throw an error if it is not found. const programNode = linkables.getOrThrow(programLinkNode, stack); diff --git a/packages/visitors-core/src/LinkableDictionary.ts b/packages/visitors-core/src/LinkableDictionary.ts index 7e812d2fa..132eb61c6 100644 --- a/packages/visitors-core/src/LinkableDictionary.ts +++ b/packages/visitors-core/src/LinkableDictionary.ts @@ -1,25 +1,25 @@ import { CODAMA_ERROR__LINKED_NODE_NOT_FOUND, CodamaError } from '@codama/errors'; import { - AccountLinkNode, AccountNode, CamelCaseString, - DefinedTypeLinkNode, DefinedTypeNode, - InstructionAccountLinkNode, InstructionAccountNode, - InstructionArgumentLinkNode, InstructionArgumentNode, - InstructionLinkNode, InstructionNode, isNode, LinkNode, - PdaLinkNode, PdaNode, - ProgramLinkNode, ProgramNode, } from '@codama/nodes'; -import { NodeStack } from './NodeStack'; +import { + findInstructionNodeFromPath, + findProgramNodeFromPath, + getLastNodeFromPath, + getNodePathUntilLastNode, + isNodePath, + NodePath, +} from './NodePath'; export type LinkableNode = | AccountNode @@ -40,100 +40,114 @@ export const LINKABLE_NODES: LinkableNode['kind'][] = [ 'programNode', ]; +export type GetLinkableFromLinkNode = { + accountLinkNode: AccountNode; + definedTypeLinkNode: DefinedTypeNode; + instructionAccountLinkNode: InstructionAccountNode; + instructionArgumentLinkNode: InstructionArgumentNode; + instructionLinkNode: InstructionNode; + pdaLinkNode: PdaNode; + programLinkNode: ProgramNode; +}[TLinkNode['kind']]; + type ProgramDictionary = { - accounts: Map; - definedTypes: Map; + accounts: Map>; + definedTypes: Map>; instructions: Map; - pdas: Map; - program: ProgramNode; + pdas: Map>; + program: NodePath; }; type InstructionDictionary = { - accounts: Map; - arguments: Map; - instruction: InstructionNode; + accounts: Map>; + arguments: Map>; + instruction: NodePath; }; export class LinkableDictionary { readonly programs: Map = new Map(); - record(node: LinkableNode, stack: NodeStack): this { - const programDictionary = this.getOrCreateProgramDictionary(node, stack); + recordPath(linkablePath: NodePath): this { + const linkableNode = getLastNodeFromPath(linkablePath); + const programDictionary = this.getOrCreateProgramDictionary(linkablePath); if (!programDictionary) return this; // Do not record nodes that are outside of a program. - const instructionDictionary = this.getOrCreateInstructionDictionary(programDictionary, node, stack); - - if (isNode(node, 'accountNode')) { - programDictionary.accounts.set(node.name, node); - } else if (isNode(node, 'definedTypeNode')) { - programDictionary.definedTypes.set(node.name, node); - } else if (isNode(node, 'pdaNode')) { - programDictionary.pdas.set(node.name, node); - } else if (instructionDictionary && isNode(node, 'instructionAccountNode')) { - instructionDictionary.accounts.set(node.name, node); - } else if (instructionDictionary && isNode(node, 'instructionArgumentNode')) { - instructionDictionary.arguments.set(node.name, node); + const instructionDictionary = this.getOrCreateInstructionDictionary(programDictionary, linkablePath); + + if (isNodePath(linkablePath, 'accountNode')) { + programDictionary.accounts.set(linkableNode.name, linkablePath); + } else if (isNodePath(linkablePath, 'definedTypeNode')) { + programDictionary.definedTypes.set(linkableNode.name, linkablePath); + } else if (isNodePath(linkablePath, 'pdaNode')) { + programDictionary.pdas.set(linkableNode.name, linkablePath); + } else if (instructionDictionary && isNodePath(linkablePath, 'instructionAccountNode')) { + instructionDictionary.accounts.set(linkableNode.name, linkablePath); + } else if (instructionDictionary && isNodePath(linkablePath, 'instructionArgumentNode')) { + instructionDictionary.arguments.set(linkableNode.name, linkablePath); } return this; } - getOrThrow(linkNode: AccountLinkNode, stack: NodeStack): AccountNode; - getOrThrow(linkNode: DefinedTypeLinkNode, stack: NodeStack): DefinedTypeNode; - getOrThrow(linkNode: InstructionAccountLinkNode, stack: NodeStack): InstructionAccountNode; - getOrThrow(linkNode: InstructionArgumentLinkNode, stack: NodeStack): InstructionArgumentNode; - getOrThrow(linkNode: InstructionLinkNode, stack: NodeStack): InstructionNode; - getOrThrow(linkNode: PdaLinkNode, stack: NodeStack): PdaNode; - getOrThrow(linkNode: ProgramLinkNode, stack: NodeStack): ProgramNode; - getOrThrow(linkNode: LinkNode, stack: NodeStack): LinkableNode { - const node = this.get(linkNode as ProgramLinkNode, stack) as LinkableNode | undefined; - - if (!node) { + getPathOrThrow( + linkPath: NodePath, + ): NodePath> { + const linkablePath = this.getPath(linkPath); + + if (!linkablePath) { + const linkNode = getLastNodeFromPath(linkPath); throw new CodamaError(CODAMA_ERROR__LINKED_NODE_NOT_FOUND, { kind: linkNode.kind, linkNode, name: linkNode.name, - stack: stack.all(), + path: linkablePath, }); } - return node; + return linkablePath; } - get(linkNode: AccountLinkNode, stack: NodeStack): AccountNode | undefined; - get(linkNode: DefinedTypeLinkNode, stack: NodeStack): DefinedTypeNode | undefined; - get(linkNode: InstructionAccountLinkNode, stack: NodeStack): InstructionAccountNode | undefined; - get(linkNode: InstructionArgumentLinkNode, stack: NodeStack): InstructionArgumentNode | undefined; - get(linkNode: InstructionLinkNode, stack: NodeStack): InstructionNode | undefined; - get(linkNode: PdaLinkNode, stack: NodeStack): PdaNode | undefined; - get(linkNode: ProgramLinkNode, stack: NodeStack): ProgramNode | undefined; - get(linkNode: LinkNode, stack: NodeStack): LinkableNode | undefined { - const programDictionary = this.getProgramDictionary(linkNode, stack); + getPath( + linkPath: NodePath, + ): NodePath> | undefined { + const linkNode = getLastNodeFromPath(linkPath); + const programDictionary = this.getProgramDictionary(linkPath); if (!programDictionary) return undefined; - const instructionDictionary = this.getInstructionDictionary(programDictionary, linkNode, stack); + const instructionDictionary = this.getInstructionDictionary(programDictionary, linkPath); + type LinkablePath = NodePath> | undefined; if (isNode(linkNode, 'accountLinkNode')) { - return programDictionary.accounts.get(linkNode.name); + return programDictionary.accounts.get(linkNode.name) as LinkablePath; } else if (isNode(linkNode, 'definedTypeLinkNode')) { - return programDictionary.definedTypes.get(linkNode.name); + return programDictionary.definedTypes.get(linkNode.name) as LinkablePath; } else if (isNode(linkNode, 'instructionAccountLinkNode')) { - return instructionDictionary?.accounts.get(linkNode.name); + return instructionDictionary?.accounts.get(linkNode.name) as LinkablePath; } else if (isNode(linkNode, 'instructionArgumentLinkNode')) { - return instructionDictionary?.arguments.get(linkNode.name); + return instructionDictionary?.arguments.get(linkNode.name) as LinkablePath; } else if (isNode(linkNode, 'instructionLinkNode')) { - return instructionDictionary?.instruction; + return instructionDictionary?.instruction as LinkablePath; } else if (isNode(linkNode, 'pdaLinkNode')) { - return programDictionary.pdas.get(linkNode.name); + return programDictionary.pdas.get(linkNode.name) as LinkablePath; } else if (isNode(linkNode, 'programLinkNode')) { - return programDictionary.program; + return programDictionary.program as LinkablePath; } return undefined; } - has(linkNode: LinkNode, stack: NodeStack): boolean { - const programDictionary = this.getProgramDictionary(linkNode, stack); + getOrThrow(linkPath: NodePath): GetLinkableFromLinkNode { + return getLastNodeFromPath(this.getPathOrThrow(linkPath)); + } + + get(linkPath: NodePath): GetLinkableFromLinkNode | undefined { + const path = this.getPath(linkPath); + return path ? getLastNodeFromPath(path) : undefined; + } + + has(linkPath: NodePath): boolean { + const linkNode = getLastNodeFromPath(linkPath); + const programDictionary = this.getProgramDictionary(linkPath); if (!programDictionary) return false; - const instructionDictionary = this.getInstructionDictionary(programDictionary, linkNode, stack); + const instructionDictionary = this.getInstructionDictionary(programDictionary, linkPath); if (isNode(linkNode, 'accountLinkNode')) { return programDictionary.accounts.has(linkNode.name); @@ -154,8 +168,9 @@ export class LinkableDictionary { return false; } - private getOrCreateProgramDictionary(node: LinkableNode, stack: NodeStack): ProgramDictionary | undefined { - const programNode = isNode(node, 'programNode') ? node : stack.getProgram(); + private getOrCreateProgramDictionary(linkablePath: NodePath): ProgramDictionary | undefined { + const linkableNode = getLastNodeFromPath(linkablePath); + const programNode = isNode(linkableNode, 'programNode') ? linkableNode : findProgramNodeFromPath(linkablePath); if (!programNode) return undefined; let programDictionary = this.programs.get(programNode.name); @@ -165,7 +180,7 @@ export class LinkableDictionary { definedTypes: new Map(), instructions: new Map(), pdas: new Map(), - program: programNode, + program: getNodePathUntilLastNode(linkablePath, 'programNode')!, }; this.programs.set(programNode.name, programDictionary); } @@ -175,10 +190,12 @@ export class LinkableDictionary { private getOrCreateInstructionDictionary( programDictionary: ProgramDictionary, - node: LinkableNode, - stack: NodeStack, + linkablePath: NodePath, ): InstructionDictionary | undefined { - const instructionNode = isNode(node, 'instructionNode') ? node : stack.getInstruction(); + const linkableNode = getLastNodeFromPath(linkablePath); + const instructionNode = isNode(linkableNode, 'instructionNode') + ? linkableNode + : findInstructionNodeFromPath(linkablePath); if (!instructionNode) return undefined; let instructionDictionary = programDictionary.instructions.get(instructionNode.name); @@ -186,7 +203,7 @@ export class LinkableDictionary { instructionDictionary = { accounts: new Map(), arguments: new Map(), - instruction: instructionNode, + instruction: getNodePathUntilLastNode(linkablePath, 'instructionNode')!, }; programDictionary.instructions.set(instructionNode.name, instructionDictionary); } @@ -194,7 +211,8 @@ export class LinkableDictionary { return instructionDictionary; } - private getProgramDictionary(linkNode: LinkNode, stack: NodeStack): ProgramDictionary | undefined { + private getProgramDictionary(linkPath: NodePath): ProgramDictionary | undefined { + const linkNode = getLastNodeFromPath(linkPath); let programName: CamelCaseString | undefined = undefined; if (isNode(linkNode, 'programLinkNode')) { programName = linkNode.name; @@ -203,23 +221,23 @@ export class LinkableDictionary { } else if ('instruction' in linkNode) { programName = linkNode.instruction?.program?.name; } - programName = programName ?? stack.getProgram()?.name; + programName = programName ?? findProgramNodeFromPath(linkPath)?.name; return programName ? this.programs.get(programName) : undefined; } private getInstructionDictionary( programDictionary: ProgramDictionary, - linkNode: LinkNode, - stack: NodeStack, + linkPath: NodePath, ): InstructionDictionary | undefined { + const linkNode = getLastNodeFromPath(linkPath); let instructionName: CamelCaseString | undefined = undefined; if (isNode(linkNode, 'instructionLinkNode')) { instructionName = linkNode.name; } else if ('instruction' in linkNode) { instructionName = linkNode.instruction?.name; } - instructionName = instructionName ?? stack.getInstruction()?.name; + instructionName = instructionName ?? findInstructionNodeFromPath(linkPath)?.name; return instructionName ? programDictionary.instructions.get(instructionName) : undefined; } diff --git a/packages/visitors-core/src/NodePath.ts b/packages/visitors-core/src/NodePath.ts index ed6d3b294..b1685c9e5 100644 --- a/packages/visitors-core/src/NodePath.ts +++ b/packages/visitors-core/src/NodePath.ts @@ -1,7 +1,62 @@ -import { Node } from '@codama/nodes'; +import { assertIsNode, GetNodeFromKind, InstructionNode, isNode, Node, NodeKind, ProgramNode } from '@codama/nodes'; export type NodePath = readonly [...Node[], TNode]; export function getLastNodeFromPath(path: NodePath): TNode { return path[path.length - 1] as TNode; } + +export function findFirstNodeFromPath( + path: NodePath, + kind: TKind | TKind[], +): GetNodeFromKind | undefined { + return path.find(node => isNode(node, kind)); +} + +export function findLastNodeFromPath( + path: NodePath, + kind: TKind | TKind[], +): GetNodeFromKind | undefined { + for (let index = path.length - 1; index >= 0; index--) { + const node = path[index]; + if (isNode(node, kind)) return node; + } + return undefined; +} + +export function findProgramNodeFromPath(path: NodePath): ProgramNode | undefined { + return findLastNodeFromPath(path, 'programNode'); +} + +export function findInstructionNodeFromPath(path: NodePath): InstructionNode | undefined { + return findLastNodeFromPath(path, 'instructionNode'); +} + +export function getNodePathUntilLastNode( + path: NodePath, + kind: TKind | TKind[], +): NodePath> | undefined { + const lastIndex = (() => { + for (let index = path.length - 1; index >= 0; index--) { + const node = path[index]; + if (isNode(node, kind)) return index; + } + return -1; + })(); + if (lastIndex === -1) return undefined; + return path.slice(0, lastIndex + 1) as unknown as NodePath>; +} + +export function isNodePath( + path: NodePath | null | undefined, + kind: TKind | TKind[], +): path is NodePath> { + return isNode(path ? getLastNodeFromPath(path) : null, kind); +} + +export function assertIsNodePath( + path: NodePath | null | undefined, + kind: TKind | TKind[], +): asserts path is NodePath> { + assertIsNode(path ? getLastNodeFromPath(path) : null, kind); +} diff --git a/packages/visitors-core/src/NodeStack.ts b/packages/visitors-core/src/NodeStack.ts index 973274e80..81206bd48 100644 --- a/packages/visitors-core/src/NodeStack.ts +++ b/packages/visitors-core/src/NodeStack.ts @@ -2,14 +2,13 @@ import { assertIsNode, GetNodeFromKind, InstructionNode, - isNode, Node, NodeKind, ProgramNode, REGISTERED_NODE_KINDS, } from '@codama/nodes'; -import { NodePath } from './NodePath'; +import { findLastNodeFromPath, NodePath } from './NodePath'; export class NodeStack { /** @@ -58,11 +57,7 @@ export class NodeStack { } public find(kind: TKind | TKind[]): GetNodeFromKind | undefined { - for (let index = this.stack.length - 1; index >= 0; index--) { - const node = this.stack[index]; - if (isNode(node, kind)) return node; - } - return undefined; + return findLastNodeFromPath([...this.stack] as unknown as NodePath>, kind); } public getProgram(): ProgramNode | undefined { diff --git a/packages/visitors-core/src/getByteSizeVisitor.ts b/packages/visitors-core/src/getByteSizeVisitor.ts index 12c1908a1..e0ca181c5 100644 --- a/packages/visitors-core/src/getByteSizeVisitor.ts +++ b/packages/visitors-core/src/getByteSizeVisitor.ts @@ -69,7 +69,8 @@ export function getByteSizeVisitor( visitDefinedTypeLink(node, { self }) { // Fetch the linked type and return null if not found. // The validator visitor will throw a proper error later on. - const linkedDefinedType = linkables.get(node, stack); + // FIXME: Keep track of our own internal stack within this visitor (starting from a provided NodePath). + const linkedDefinedType = linkables.get([...stack.getPath(), node]); if (!linkedDefinedType) { return null; } diff --git a/packages/visitors-core/src/recordLinkablesVisitor.ts b/packages/visitors-core/src/recordLinkablesVisitor.ts index cce165ab0..4702577fd 100644 --- a/packages/visitors-core/src/recordLinkablesVisitor.ts +++ b/packages/visitors-core/src/recordLinkablesVisitor.ts @@ -18,7 +18,7 @@ export function getRecordLinkablesVisitor( v => interceptVisitor(v, (node, next) => { if (isNode(node, LINKABLE_NODES)) { - linkables.record(node, stack); + linkables.recordPath(stack.getPath()); } return next(node); }), diff --git a/packages/visitors-core/test/recordLinkablesVisitor.test.ts b/packages/visitors-core/test/recordLinkablesVisitor.test.ts index 77d826fd5..657ffc4a4 100644 --- a/packages/visitors-core/test/recordLinkablesVisitor.test.ts +++ b/packages/visitors-core/test/recordLinkablesVisitor.test.ts @@ -45,10 +45,9 @@ test('it records program nodes', () => { // When we visit the tree. visit(node, visitor); - // Then we expect program nodes to be recorded and retrievable. - const emptyStack = new NodeStack(); - expect(linkables.get(programLinkNode('programA'), emptyStack)).toEqual(node.program); - expect(linkables.get(programLinkNode('programB'), emptyStack)).toEqual(node.additionalPrograms[0]); + // Then we expect program paths to be recorded and retrievable. + expect(linkables.getPath([programLinkNode('programA')])).toEqual([node, node.program]); + expect(linkables.getPath([programLinkNode('programB')])).toEqual([node, node.additionalPrograms[0]]); }); test('it records account nodes', () => { @@ -66,10 +65,9 @@ test('it records account nodes', () => { // When we visit the tree. visit(node, visitor); - // Then we expect account nodes to be recorded and retrievable. - const emptyStack = new NodeStack(); - expect(linkables.get(accountLinkNode('accountA', 'myProgram'), emptyStack)).toEqual(node.accounts[0]); - expect(linkables.get(accountLinkNode('accountB', 'myProgram'), emptyStack)).toEqual(node.accounts[1]); + // Then we expect account paths to be recorded and retrievable. + expect(linkables.getPath([accountLinkNode('accountA', 'myProgram')])).toEqual([node, node.accounts[0]]); + expect(linkables.getPath([accountLinkNode('accountB', 'myProgram')])).toEqual([node, node.accounts[1]]); }); test('it records defined type nodes', () => { @@ -90,10 +88,9 @@ test('it records defined type nodes', () => { // When we visit the tree. visit(node, visitor); - // Then we expect defined type nodes to be recorded and retrievable. - const emptyStack = new NodeStack(); - expect(linkables.get(definedTypeLinkNode('typeA', 'myProgram'), emptyStack)).toEqual(node.definedTypes[0]); - expect(linkables.get(definedTypeLinkNode('typeB', 'myProgram'), emptyStack)).toEqual(node.definedTypes[1]); + // Then we expect defined type paths to be recorded and retrievable. + expect(linkables.getPath([definedTypeLinkNode('typeA', 'myProgram')])).toEqual([node, node.definedTypes[0]]); + expect(linkables.getPath([definedTypeLinkNode('typeB', 'myProgram')])).toEqual([node, node.definedTypes[1]]); }); test('it records pda nodes', () => { @@ -111,10 +108,9 @@ test('it records pda nodes', () => { // When we visit the tree. visit(node, visitor); - // Then we expect pda nodes to be recorded and retrievable. - const emptyStack = new NodeStack(); - expect(linkables.get(pdaLinkNode('pdaA', 'myProgram'), emptyStack)).toEqual(node.pdas[0]); - expect(linkables.get(pdaLinkNode('pdaB', 'myProgram'), emptyStack)).toEqual(node.pdas[1]); + // Then we expect pda paths to be recorded and retrievable. + expect(linkables.getPath([pdaLinkNode('pdaA', 'myProgram')])).toEqual([node, node.pdas[0]]); + expect(linkables.getPath([pdaLinkNode('pdaB', 'myProgram')])).toEqual([node, node.pdas[1]]); }); test('it records instruction nodes', () => { @@ -132,10 +128,9 @@ test('it records instruction nodes', () => { // When we visit the tree. visit(node, visitor); - // Then we expect instruction nodes to be recorded and retrievable. - const emptyStack = new NodeStack(); - expect(linkables.get(instructionLinkNode('instructionA', 'myProgram'), emptyStack)).toEqual(node.instructions[0]); - expect(linkables.get(instructionLinkNode('instructionB', 'myProgram'), emptyStack)).toEqual(node.instructions[1]); + // Then we expect instruction paths to be recorded and retrievable. + expect(linkables.getPath([instructionLinkNode('instructionA', 'myProgram')])).toEqual([node, node.instructions[0]]); + expect(linkables.getPath([instructionLinkNode('instructionB', 'myProgram')])).toEqual([node, node.instructions[1]]); }); test('it records instruction account nodes', () => { @@ -157,15 +152,18 @@ test('it records instruction account nodes', () => { // When we visit the tree. visit(node, visitor); - // Then we expect instruction account nodes to be recorded and retrievable. - const emptyStack = new NodeStack(); + // Then we expect instruction account paths to be recorded and retrievable. const instruction = instructionLinkNode('myInstruction', 'myProgram'); - expect(linkables.get(instructionAccountLinkNode('accountA', instruction), emptyStack)).toEqual( + expect(linkables.getPath([instructionAccountLinkNode('accountA', instruction)])).toEqual([ + node, + node.instructions[0], instructionAccounts[0], - ); - expect(linkables.get(instructionAccountLinkNode('accountB', instruction), emptyStack)).toEqual( + ]); + expect(linkables.getPath([instructionAccountLinkNode('accountB', instruction)])).toEqual([ + node, + node.instructions[0], instructionAccounts[1], - ); + ]); }); test('it records instruction argument nodes', () => { @@ -187,15 +185,18 @@ test('it records instruction argument nodes', () => { // When we visit the tree. visit(node, visitor); - // Then we expect instruction argument nodes to be recorded and retrievable. - const emptyStack = new NodeStack(); + // Then we expect instruction argument paths to be recorded and retrievable. const instruction = instructionLinkNode('myInstruction', 'myProgram'); - expect(linkables.get(instructionArgumentLinkNode('argumentA', instruction), emptyStack)).toEqual( + expect(linkables.getPath([instructionArgumentLinkNode('argumentA', instruction)])).toEqual([ + node, + node.instructions[0], instructionArguments[0], - ); - expect(linkables.get(instructionArgumentLinkNode('argumentB', instruction), emptyStack)).toEqual( + ]); + expect(linkables.getPath([instructionArgumentLinkNode('argumentB', instruction)])).toEqual([ + node, + node.instructions[0], instructionArguments[1], - ); + ]); }); test('it records all linkable before the first visit of the base visitor', () => { @@ -207,11 +208,10 @@ test('it records all linkable before the first visit of the base visitor', () => // And a recordLinkablesOnFirstVisitVisitor extending a base visitor that // stores the linkable programs available at every visit. const linkables = new LinkableDictionary(); - const emptyStack = new NodeStack(); const events: string[] = []; const baseVisitor = interceptFirstVisitVisitor(voidVisitor(), (node, next) => { - events.push(`programA:${linkables.has(programLinkNode('programA'), emptyStack)}`); - events.push(`programB:${linkables.has(programLinkNode('programB'), emptyStack)}`); + events.push(`programA:${linkables.has([programLinkNode('programA')])}`); + events.push(`programB:${linkables.has([programLinkNode('programB')])}`); next(node); }); const visitor = recordLinkablesOnFirstVisitVisitor(baseVisitor, linkables); @@ -245,7 +245,7 @@ test('it keeps track of the current program when extending a visitor', () => { const baseVisitor = interceptVisitor(voidVisitor(), (node, next) => { stack.push(node); if (isNode(node, 'programNode')) { - dictionary[node.name] = linkables.getOrThrow(accountLinkNode('someAccount'), stack); + dictionary[node.name] = linkables.getOrThrow([...stack.getPath(), accountLinkNode('someAccount')]); } next(node); stack.pop(); @@ -285,7 +285,10 @@ test('it keeps track of the current instruction when extending a visitor', () => const baseVisitor = interceptVisitor(voidVisitor(), (node, next) => { stack.push(node); if (isNode(node, 'instructionNode')) { - dictionary[node.name] = linkables.getOrThrow(instructionAccountLinkNode('someAccount'), stack); + dictionary[node.name] = linkables.getOrThrow([ + ...stack.getPath(), + instructionAccountLinkNode('someAccount'), + ]); } next(node); stack.pop(); @@ -312,8 +315,7 @@ test('it does not record linkable types that are not under a program node', () = visit(node, visitor); // Then we expect the account node to not be recorded. - const emptyStack = new NodeStack(); - expect(linkables.has(accountLinkNode('someAccount'), emptyStack)).toBe(false); + expect(linkables.has([accountLinkNode('someAccount')])).toBe(false); }); test('it can throw an exception when trying to retrieve a missing linked node', () => { @@ -330,8 +332,7 @@ test('it can throw an exception when trying to retrieve a missing linked node', visit(node, visitor); // When we try to retrieve a missing account node. - const emptyStack = new NodeStack(); - const getMissingAccount = () => linkables.getOrThrow(accountLinkNode('missingAccount', 'myProgram'), emptyStack); + const getMissingAccount = () => linkables.getOrThrow([node, accountLinkNode('missingAccount', 'myProgram')]); // Then we expect an exception to be thrown. expect(getMissingAccount).toThrow( diff --git a/packages/visitors/src/createSubInstructionsFromEnumArgsVisitor.ts b/packages/visitors/src/createSubInstructionsFromEnumArgsVisitor.ts index aeeddff1e..d492db438 100644 --- a/packages/visitors/src/createSubInstructionsFromEnumArgsVisitor.ts +++ b/packages/visitors/src/createSubInstructionsFromEnumArgsVisitor.ts @@ -14,23 +14,20 @@ import { BottomUpNodeTransformerWithSelector, bottomUpTransformerVisitor, LinkableDictionary, - NodeStack, pipe, recordLinkablesOnFirstVisitVisitor, - recordNodeStackVisitor, } from '@codama/visitors-core'; import { flattenInstructionArguments } from './flattenInstructionDataArgumentsVisitor'; export function createSubInstructionsFromEnumArgsVisitor(map: Record) { const linkables = new LinkableDictionary(); - const stack = new NodeStack(); const visitor = bottomUpTransformerVisitor( Object.entries(map).map( ([selector, argNameInput]): BottomUpNodeTransformerWithSelector => ({ select: ['[instructionNode]', selector], - transform: node => { + transform: (node, stack) => { assertIsNode(node, 'instructionNode'); const argFields = node.arguments; @@ -48,8 +45,11 @@ export function createSubInstructionsFromEnumArgsVisitor(map: Record recordNodeStackVisitor(v, stack), - v => recordLinkablesOnFirstVisitVisitor(v, linkables), - ); + return pipe(visitor, v => recordLinkablesOnFirstVisitVisitor(v, linkables)); } diff --git a/packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts b/packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts index 4d879d755..5a622b26b 100644 --- a/packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts +++ b/packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts @@ -41,7 +41,7 @@ export function fillDefaultPdaSeedValuesVisitor( assertIsNode(visitedNode, 'pdaValueNode'); const foundPda = isNode(visitedNode.pda, 'pdaNode') ? visitedNode.pda - : linkables.get(visitedNode.pda, stack); + : linkables.get([...stack.getPath(), visitedNode.pda]); if (!foundPda) return visitedNode; const seeds = addDefaultSeedValuesFromPdaWhenMissing(instruction, foundPda, visitedNode.seeds); if (strictMode && !allSeedsAreValid(instruction, foundPda, seeds)) { diff --git a/packages/visitors/src/unwrapDefinedTypesVisitor.ts b/packages/visitors/src/unwrapDefinedTypesVisitor.ts index cc6092f2d..5e9d9f078 100644 --- a/packages/visitors/src/unwrapDefinedTypesVisitor.ts +++ b/packages/visitors/src/unwrapDefinedTypesVisitor.ts @@ -25,7 +25,9 @@ export function unwrapDefinedTypesVisitor(typesToInline: string[] | '*' = '*') { if (!shouldInline(linkType.name)) { return linkType; } - return visit(linkables.getOrThrow(linkType, stack).type, self); + const definedType = linkables.getOrThrow(stack.getPath('definedTypeLinkNode')); + // FIXME: Wrap in heap.pushStack() and heap.popStack(). + return visit(definedType.type, self); }, visitProgram(program, { self }) { diff --git a/packages/visitors/src/unwrapTypeDefinedLinksVisitor.ts b/packages/visitors/src/unwrapTypeDefinedLinksVisitor.ts index e2b15548c..61c354ae0 100644 --- a/packages/visitors/src/unwrapTypeDefinedLinksVisitor.ts +++ b/packages/visitors/src/unwrapTypeDefinedLinksVisitor.ts @@ -1,29 +1,21 @@ -import { assertIsNode } from '@codama/nodes'; import { BottomUpNodeTransformerWithSelector, bottomUpTransformerVisitor, LinkableDictionary, - NodeStack, pipe, recordLinkablesOnFirstVisitVisitor, - recordNodeStackVisitor, } from '@codama/visitors-core'; export function unwrapTypeDefinedLinksVisitor(definedLinksType: string[]) { const linkables = new LinkableDictionary(); - const stack = new NodeStack(); const transformers: BottomUpNodeTransformerWithSelector[] = definedLinksType.map(selector => ({ select: ['[definedTypeLinkNode]', selector], - transform: node => { - assertIsNode(node, 'definedTypeLinkNode'); - return linkables.getOrThrow(node, stack).type; + transform: (_, stack) => { + const definedType = linkables.getOrThrow(stack.getPath('definedTypeLinkNode')); + return definedType.type; }, })); - return pipe( - bottomUpTransformerVisitor(transformers), - v => recordNodeStackVisitor(v, stack), - v => recordLinkablesOnFirstVisitVisitor(v, linkables), - ); + return pipe(bottomUpTransformerVisitor(transformers), v => recordLinkablesOnFirstVisitVisitor(v, linkables)); } diff --git a/packages/visitors/test/fillDefaultPdaSeedValuesVisitor.test.ts b/packages/visitors/test/fillDefaultPdaSeedValuesVisitor.test.ts index 8029fac3f..980e953ac 100644 --- a/packages/visitors/test/fillDefaultPdaSeedValuesVisitor.test.ts +++ b/packages/visitors/test/fillDefaultPdaSeedValuesVisitor.test.ts @@ -39,7 +39,7 @@ test('it fills missing pda seed values with default values', () => { // And a linkable dictionary that recorded this PDA. const linkables = new LinkableDictionary(); - linkables.record(pda, new NodeStack([program, pda])); + linkables.recordPath([program, pda]); // And a pdaValueNode with a single seed filled. const node = pdaValueNode('myPda', [pdaSeedValueNode('seed1', numberValueNode(42))]); @@ -91,7 +91,7 @@ test('it fills nested pda value nodes', () => { // And a linkable dictionary that recorded this PDA. const linkables = new LinkableDictionary(); - linkables.record(pda, new NodeStack([program, pda])); + linkables.recordPath([program, pda]); // And a pdaValueNode nested inside a conditionalValueNode. const node = conditionalValueNode({ @@ -149,7 +149,7 @@ test('it ignores default seeds missing from the instruction', () => { // And a linkable dictionary that recorded this PDA. const linkables = new LinkableDictionary(); - linkables.record(pda, new NodeStack([program, pda])); + linkables.recordPath([program, pda]); // And a pdaValueNode with a single seed filled. const node = pdaValueNode('myPda', [pdaSeedValueNode('seed1', numberValueNode(42))]);