Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/curly-berries-jog.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@codama/visitors': minor
---

Use `NodePaths` in `fillDefaultPdaSeedValuesVisitor`
4 changes: 2 additions & 2 deletions packages/visitors/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ codama.update(deduplicateIdenticalDefinedTypesVisitor());
### `fillDefaultPdaSeedValuesVisitor`
This visitor fills any missing `PdaSeedValueNodes` from `PdaValueNodes` using the provided `InstructionNode` such that:
This visitor fills any missing `PdaSeedValueNodes` from `PdaValueNodes` using the provided `NodePath<InstructionNode>` such that:
- If a `VariablePdaSeedNode` is of type `PublicKeyTypeNode` and the name of the seed matches the name of an account in the `InstructionNode`, then a new `PdaSeedValueNode` will be added with the matching account.
- Otherwise, if a `VariablePdaSeedNode` is of any other type and the name of the seed matches the name of an argument in the `InstructionNode`, then a new `PdaSeedValueNode` will be added with the matching argument.
Expand All @@ -107,7 +107,7 @@ It also requires a [`LinkableDictionary`](../visitors-core/README.md#linkable-di
Note that this visitor is mainly used for internal purposes.
```ts
codama.update(fillDefaultPdaSeedValuesVisitor(instructionNode, linkables, strictMode));
codama.update(fillDefaultPdaSeedValuesVisitor(instructionPath, linkables, strictMode));
```
### `flattenInstructionDataArgumentsVisitor`
Expand Down
16 changes: 12 additions & 4 deletions packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@ import {
pdaSeedValueNode,
pdaValueNode,
} from '@codama/nodes';
import { extendVisitor, identityVisitor, LinkableDictionary, NodeStack, pipe, Visitor } from '@codama/visitors-core';
import {
extendVisitor,
getLastNodeFromPath,
identityVisitor,
LinkableDictionary,
NodePath,
pipe,
Visitor,
} from '@codama/visitors-core';

/**
* Fills in default values for variable PDA seeds that are not explicitly provided.
Expand All @@ -29,19 +37,19 @@ import { extendVisitor, identityVisitor, LinkableDictionary, NodeStack, pipe, Vi
* pdaSeedValueNodes contains invalid seeds or if there aren't enough variable seeds.
*/
export function fillDefaultPdaSeedValuesVisitor(
instruction: InstructionNode,
stack: NodeStack,
instructionPath: NodePath<InstructionNode>,
linkables: LinkableDictionary,
strictMode: boolean = false,
) {
const instruction = getLastNodeFromPath(instructionPath);
return pipe(identityVisitor(INSTRUCTION_INPUT_VALUE_NODES), v =>
extendVisitor(v, {
visitPdaValue(node, { next }) {
const visitedNode = next(node);
assertIsNode(visitedNode, 'pdaValueNode');
const foundPda = isNode(visitedNode.pda, 'pdaNode')
? visitedNode.pda
: linkables.get([...stack.getPath(), visitedNode.pda]);
: linkables.get([...instructionPath, visitedNode.pda]);
if (!foundPda) return visitedNode;
const seeds = addDefaultSeedValuesFromPdaWhenMissing(instruction, foundPda, visitedNode.seeds);
if (strictMode && !allSeedsAreValid(instruction, foundPda, seeds)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ export function setInstructionAccountDefaultValuesVisitor(rules: InstructionAcco
v =>
extendVisitor(v, {
visitInstruction(node) {
const instructionPath = stack.getPath('instructionNode');
const instructionAccounts = node.accounts.map((account): InstructionAccountNode => {
const rule = matchRule(node, account);
if (!rule) return account;
Expand All @@ -180,7 +181,7 @@ export function setInstructionAccountDefaultValuesVisitor(rules: InstructionAcco
...account,
defaultValue: visit(
rule.defaultValue,
fillDefaultPdaSeedValuesVisitor(node, stack, linkables, true),
fillDefaultPdaSeedValuesVisitor(instructionPath, linkables, true),
),
};
} catch (error) {
Expand Down
9 changes: 5 additions & 4 deletions packages/visitors/src/updateInstructionsVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
BottomUpNodeTransformerWithSelector,
bottomUpTransformerVisitor,
LinkableDictionary,
NodePath,
NodeStack,
pipe,
recordLinkablesOnFirstVisitVisitor,
Expand Down Expand Up @@ -72,10 +73,11 @@ export function updateInstructionsVisitor(map: Record<string, InstructionUpdates
return null;
}

const instructionPath = stack.getPath('instructionNode');
const { accounts: accountUpdates, arguments: argumentUpdates, ...metadataUpdates } = updates;
const { newArguments, newExtraArguments } = handleInstructionArguments(node, argumentUpdates ?? {});
const newAccounts = node.accounts.map(account =>
handleInstructionAccount(node, stack, account, accountUpdates ?? {}, linkables),
handleInstructionAccount(instructionPath, account, accountUpdates ?? {}, linkables),
);
return instructionNode({
...node,
Expand All @@ -96,8 +98,7 @@ export function updateInstructionsVisitor(map: Record<string, InstructionUpdates
}

function handleInstructionAccount(
instruction: InstructionNode,
stack: NodeStack,
instructionPath: NodePath<InstructionNode>,
account: InstructionAccountNode,
accountUpdates: InstructionAccountUpdates,
linkables: LinkableDictionary,
Expand All @@ -115,7 +116,7 @@ function handleInstructionAccount(

return instructionAccountNode({
...acountWithoutDefault,
defaultValue: visit(defaultValue, fillDefaultPdaSeedValuesVisitor(instruction, stack, linkables)),
defaultValue: visit(defaultValue, fillDefaultPdaSeedValuesVisitor(instructionPath, linkables)),
});
}

Expand Down
11 changes: 4 additions & 7 deletions packages/visitors/test/fillDefaultPdaSeedValuesVisitor.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
publicKeyTypeNode,
variablePdaSeedNode,
} from '@codama/nodes';
import { LinkableDictionary, NodeStack, visit } from '@codama/visitors-core';
import { LinkableDictionary, visit } from '@codama/visitors-core';
import { expect, test } from 'vitest';

import { fillDefaultPdaSeedValuesVisitor } from '../src';
Expand Down Expand Up @@ -56,10 +56,9 @@ test('it fills missing pda seed values with default values', () => {
arguments: [instructionArgumentNode({ name: 'seed2', type: numberTypeNode('u64') })],
name: 'myInstruction',
});
const instructionStack = new NodeStack([program, instruction]);

// When we fill the PDA seeds with default values.
const result = visit(node, fillDefaultPdaSeedValuesVisitor(instruction, instructionStack, linkables));
const result = visit(node, fillDefaultPdaSeedValuesVisitor([program, instruction], linkables));

// Then we expect the following pdaValueNode to be returned.
expect(result).toEqual(
Expand Down Expand Up @@ -111,10 +110,9 @@ test('it fills nested pda value nodes', () => {
arguments: [instructionArgumentNode({ name: 'seed2', type: numberTypeNode('u64') })],
name: 'myInstruction',
});
const instructionStack = new NodeStack([program, instruction]);

// When we fill the PDA seeds with default values.
const result = visit(node, fillDefaultPdaSeedValuesVisitor(instruction, instructionStack, linkables));
const result = visit(node, fillDefaultPdaSeedValuesVisitor([program, instruction], linkables));

// Then we expect the following conditionalValueNode to be returned.
expect(result).toEqual(
Expand Down Expand Up @@ -159,10 +157,9 @@ test('it ignores default seeds missing from the instruction', () => {
arguments: [instructionArgumentNode({ name: 'seed2', type: numberTypeNode('u64') })],
name: 'myInstruction',
});
const instructionStack = new NodeStack([program, instruction]);

// When we fill the PDA seeds with default values.
const result = visit(node, fillDefaultPdaSeedValuesVisitor(instruction, instructionStack, linkables));
const result = visit(node, fillDefaultPdaSeedValuesVisitor([program, instruction], linkables));

// Then we expect the following pdaValueNode to be returned.
expect(result).toEqual(
Expand Down