diff --git a/.changeset/spicy-spoons-worry.md b/.changeset/spicy-spoons-worry.md new file mode 100644 index 000000000..a62fae605 --- /dev/null +++ b/.changeset/spicy-spoons-worry.md @@ -0,0 +1,5 @@ +--- +'@codama/renderers-rust': patch +--- + +Add options to configure how traits are rendered in Rust diff --git a/packages/renderers-rust/README.md b/packages/renderers-rust/README.md index 7230c8708..af05de381 100644 --- a/packages/renderers-rust/README.md +++ b/packages/renderers-rust/README.md @@ -37,12 +37,114 @@ codama.accept(renderVisitor(pathToGeneratedFolder, options)); The `renderVisitor` accepts the following options. -| Name | Type | Default | Description | -| ----------------------------- | ----------------------------------------------------------------------------------------------------------------------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `deleteFolderBeforeRendering` | `boolean` | `true` | Whether the base directory should be cleaned before generating new files. | -| `formatCode` | `boolean` | `false` | Whether we should use `cargo fmt` to format the generated code. When set to `true`, the `crateFolder` option must be provided. | -| `toolchain` | `string` | `"+stable"` | The toolchain to use when formatting the generated code. | -| `crateFolder` | `string` | none | The path to the root folder of the Rust crate. This option is required when `formatCode` is set to `true`. | -| `linkOverrides` | `Record<'accounts' \| 'definedTypes' \| 'instructions' \| 'pdas' \| 'programs' \| 'resolvers', Record>` | `{}` | A object that overrides the import path of link nodes. For instance, `{ definedTypes: { counter: 'hooked' } }` uses the `hooked` folder to import any link node referring to the `counter` type. | -| `dependencyMap` | `Record` | `{}` | A mapping between import aliases and their actual crate name or path in Rust. | -| `renderParentInstructions` | `boolean` | `false` | When using nested instructions, whether the parent instructions should also be rendered. When set to `false` (default), only the instruction leaves are being rendered. | +| Name | Type | Default | Description | +| ----------------------------- | ----------------------------------------------------------------------------------------------------------------------- | ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `deleteFolderBeforeRendering` | `boolean` | `true` | Whether the base directory should be cleaned before generating new files. | +| `formatCode` | `boolean` | `false` | Whether we should use `cargo fmt` to format the generated code. When set to `true`, the `crateFolder` option must be provided. | +| `toolchain` | `string` | `"+stable"` | The toolchain to use when formatting the generated code. | +| `crateFolder` | `string` | none | The path to the root folder of the Rust crate. This option is required when `formatCode` is set to `true`. | +| `linkOverrides` | `Record<'accounts' \| 'definedTypes' \| 'instructions' \| 'pdas' \| 'programs' \| 'resolvers', Record>` | `{}` | A object that overrides the import path of link nodes. For instance, `{ definedTypes: { counter: 'hooked' } }` uses the `hooked` folder to import any link node referring to the `counter` type. | +| `dependencyMap` | `Record` | `{}` | A mapping between import aliases and their actual crate name or path in Rust. | +| `renderParentInstructions` | `boolean` | `false` | When using nested instructions, whether the parent instructions should also be rendered. When set to `false` (default), only the instruction leaves are being rendered. | +| `traitOptions` | [`TraitOptions`](#trait-options) | `DEFAULT_TRAIT_OPTIONS` | A set of options that can be used to configure how traits are rendered for every Rust types. See [documentation below](#trait-options) for more information. | + +## Trait Options + +The Rust renderer provides sensible default traits when generating the various Rust types you client will use. However, you may wish to configure these traits to better suit your needs. The `traitOptions` attribute is here to help you with that. Let's see the various settings it provides. + +### Default traits + +Using the `traitOptions` attribute, you may configure the default traits that will be applied to every Rust type. These default traits can be configured using 4 different attributes: + +- `baseDefaults`: The default traits to implement for all types. +- `dataEnumDefaults`: The default traits to implement for all data enum types, in addition to the `baseDefaults` traits. Data enums are enums with at least one non-unit variant — e.g. `pub enum Command { Write(String), Quit }`. +- `scalarEnumDefaults`: The default traits to implement for all scalar enum types, in addition to the `baseDefaults` traits. Scalar enums are enums with unit variants only — e.g. `pub enum Feedback { Good, Bad }`. +- `structDefaults`: The default traits to implement for all struct types, in addition to the `baseDefaults` traits. + +Note that you must provide the fully qualified name of the traits you provide (e.g. `serde::Serialize`). Here are the default values for these attributes: + +```ts +const traitOptions = { + baseDefaults: [ + 'borsh::BorshSerialize', + 'borsh::BorshDeserialize', + 'serde::Serialize', + 'serde::Deserialize', + 'Clone', + 'Debug', + 'Eq', + 'PartialEq', + ], + dataEnumDefaults: [], + scalarEnumDefaults: ['Copy', 'PartialOrd', 'Hash', 'num_derive::FromPrimitive'], + structDefaults: [], +}; +``` + +### Overridden traits + +In addition to configure the default traits, you may also override the traits for specific types. This will completely replace the default traits for the given type. To do so, you may use the `overrides` attribute of the `traitOptions` object. + +This attribute is a map where the keys are the names of the types you want to override, and the values are the traits you want to apply to these types. Here is an example: + +```ts +const traitOptions = { + overrides: { + myCustomType: ['Clone', 'my::custom::Trait', 'my::custom::OtherTrait'], + myTypeWithNoTraits: [], + }, +}; +``` + +### Feature Flags + +You may also configure which traits should be rendered under a feature flag by using the `featureFlags` attribute. This attribute is a map where the keys are feature flag names and the values are the traits that should be rendered under that feature flag. Here is an example: + +```ts +const traitOptions = { + featureFlags: { fruits: ['fruits::Apple', 'fruits::Banana'] }, +}; +``` + +Now, if at any point, we encounter a `fruits::Apple` or `fruits::Banana` trait to be rendered (either as default traits or as overridden traits), they will be rendered under the `fruits` feature flag. For instance: + +```rust +#[cfg_attr(feature = "fruits", derive(fruits::Apple, fruits::Banana))] +``` + +By default, the `featureFlags` option is set to the following: + +```ts +const traitOptions = { + featureFlags: { serde: ['serde::Serialize', 'serde::Deserialize'] }, +}; +``` + +Note that for feature flags to be effective, they must be added to the `Cargo.toml` file of the generated Rust client. + +### Using the Fully Qualified Name + +By default, all traits are imported using the provided Fully Qualified Name which means their short name will be used within the `derive` attributes. + +However, you may want to avoid importing these traits and use the Fully Qualified Name directly in the generated code. To do so, you may use the `useFullyQualifiedName` attribute of the `traitOptions` object by setting it to `true`: + +```ts +const traitOptions = { + useFullyQualifiedName: true, +}; +``` + +Here is an example of rendered traits with this option set to `true` and `false` (which is the default): + +```rust +// With `useFullyQualifiedName` set to `false` (default). +use serde::Serialize; +use serde::Deserialize; +// ... +#[derive(Serialize, Deserialize)] + +// With `useFullyQualifiedName` set to `true`. +#[derive(serde::Serialize, serde::Deserialize)] +``` + +Note that any trait rendered under a feature flag will always use the Fully Qualified Name in order to ensure we only reference the trait when the feature is enabled. diff --git a/packages/renderers-rust/src/getRenderMapVisitor.ts b/packages/renderers-rust/src/getRenderMapVisitor.ts index 05174b3f0..27835fde8 100644 --- a/packages/renderers-rust/src/getRenderMapVisitor.ts +++ b/packages/renderers-rust/src/getRenderMapVisitor.ts @@ -27,12 +27,14 @@ import { import { getTypeManifestVisitor } from './getTypeManifestVisitor'; import { ImportMap } from './ImportMap'; import { renderValueNode } from './renderValueNodeVisitor'; -import { getImportFromFactory, LinkOverrides, render } from './utils'; +import { getImportFromFactory, getTraitsFromNodeFactory, LinkOverrides, render, TraitOptions } from './utils'; export type GetRenderMapOptions = { + defaultTraitOverrides?: string[]; dependencyMap?: Record; linkOverrides?: LinkOverrides; renderParentInstructions?: boolean; + traitOptions?: TraitOptions; }; export function getRenderMapVisitor(options: GetRenderMapOptions = {}) { @@ -42,7 +44,8 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) { const renderParentInstructions = options.renderParentInstructions ?? false; const dependencyMap = options.dependencyMap ?? {}; const getImportFrom = getImportFromFactory(options.linkOverrides ?? {}); - const typeManifestVisitor = getTypeManifestVisitor({ getImportFrom }); + const getTraitsFromNode = getTraitsFromNodeFactory(options.traitOptions); + const typeManifestVisitor = getTypeManifestVisitor({ getImportFrom, getTraitsFromNode }); return pipe( staticVisitor( @@ -147,6 +150,7 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) { node.arguments.forEach(argument => { const argumentVisitor = getTypeManifestVisitor({ getImportFrom, + getTraitsFromNode, nestedStruct: true, parentName: `${pascalCase(node.name)}InstructionData`, }); @@ -187,6 +191,7 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) { const struct = structTypeNodeFromInstructionArgumentNodes(node.arguments); const structVisitor = getTypeManifestVisitor({ getImportFrom, + getTraitsFromNode, parentName: `${pascalCase(node.name)}InstructionData`, }); const typeManifest = visit(struct, structVisitor); diff --git a/packages/renderers-rust/src/getTypeManifestVisitor.ts b/packages/renderers-rust/src/getTypeManifestVisitor.ts index 638b4ba67..e157471d6 100644 --- a/packages/renderers-rust/src/getTypeManifestVisitor.ts +++ b/packages/renderers-rust/src/getTypeManifestVisitor.ts @@ -2,9 +2,9 @@ import { CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE, CodamaError } from '@codama/ import { arrayTypeNode, CountNode, + definedTypeNode, fixedCountNode, isNode, - isScalarEnum, NumberTypeNode, numberTypeNode, pascalCase, @@ -17,7 +17,7 @@ import { import { extendVisitor, mergeVisitor, pipe, visit } from '@codama/visitors-core'; import { ImportMap } from './ImportMap'; -import { GetImportFromFunction, rustDocblock } from './utils'; +import { GetImportFromFunction, GetTraitsFromNodeFunction, rustDocblock } from './utils'; export type TypeManifest = { imports: ImportMap; @@ -27,10 +27,11 @@ export type TypeManifest = { export function getTypeManifestVisitor(options: { getImportFrom: GetImportFromFunction; + getTraitsFromNode: GetTraitsFromNodeFunction; nestedStruct?: boolean; parentName?: string | null; }) { - const { getImportFrom } = options; + const { getImportFrom, getTraitsFromNode } = options; let parentName: string | null = options.parentName ?? null; let nestedStruct: boolean = options.nestedStruct ?? false; let inlineStruct: boolean = false; @@ -50,14 +51,12 @@ export function getTypeManifestVisitor(options: { visitAccount(account, { self }) { parentName = pascalCase(account.name); const manifest = visit(account.data, self); - manifest.imports.add(['borsh::BorshSerialize', 'borsh::BorshDeserialize']); + const traits = getTraitsFromNode(account); + manifest.imports.mergeWith(traits.imports); parentName = null; return { ...manifest, - type: - '#[derive(BorshSerialize, BorshDeserialize, Clone, Debug, Eq, PartialEq)]\n' + - '#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]\n' + - `${manifest.type}`, + type: traits.render + manifest.type, }; }, @@ -141,41 +140,15 @@ export function getTypeManifestVisitor(options: { visitDefinedType(definedType, { self }) { parentName = pascalCase(definedType.name); const manifest = visit(definedType.type, self); + const traits = getTraitsFromNode(definedType); + manifest.imports.mergeWith(traits.imports); parentName = null; - const traits = ['BorshSerialize', 'BorshDeserialize', 'Clone', 'Debug', 'Eq', 'PartialEq']; - if (isNode(definedType.type, 'enumTypeNode') && isScalarEnum(definedType.type)) { - traits.push('Copy', 'PartialOrd', 'Hash', 'FromPrimitive'); - manifest.imports.add(['num_derive::FromPrimitive']); - } - - const nestedStructs = manifest.nestedStructs.map( - struct => - `#[derive(${traits.join(', ')})]\n` + - '#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]\n' + - `${struct}`, - ); - - if (!isNode(definedType.type, ['enumTypeNode', 'structTypeNode'])) { - if (nestedStructs.length > 0) { - manifest.imports.add(['borsh::BorshSerialize', 'borsh::BorshDeserialize']); - } - return { - ...manifest, - nestedStructs, - type: `pub type ${pascalCase(definedType.name)} = ${manifest.type};`, - }; - } + const renderedType = isNode(definedType.type, ['enumTypeNode', 'structTypeNode']) + ? manifest.type + : `pub type ${pascalCase(definedType.name)} = ${manifest.type};`; - manifest.imports.add(['borsh::BorshSerialize', 'borsh::BorshDeserialize']); - return { - ...manifest, - nestedStructs, - type: - `#[derive(${traits.join(', ')})]\n` + - '#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]\n' + - `${manifest.type}`, - }; + return { ...manifest, type: `${traits.render}${renderedType}` }; }, visitDefinedTypeLink(node) { @@ -450,11 +423,15 @@ export function getTypeManifestVisitor(options: { const mergedManifest = mergeManifests(fields); if (nestedStruct) { + const nestedTraits = getTraitsFromNode( + definedTypeNode({ name: originalParentName, type: structType }), + ); + mergedManifest.imports.mergeWith(nestedTraits.imports); return { ...mergedManifest, nestedStructs: [ ...mergedManifest.nestedStructs, - `pub struct ${pascalCase(originalParentName)} {\n${fieldTypes}\n}`, + `${nestedTraits.render}pub struct ${pascalCase(originalParentName)} {\n${fieldTypes}\n}`, ], type: pascalCase(originalParentName), }; diff --git a/packages/renderers-rust/src/utils/index.ts b/packages/renderers-rust/src/utils/index.ts index 258a76212..4de9615fb 100644 --- a/packages/renderers-rust/src/utils/index.ts +++ b/packages/renderers-rust/src/utils/index.ts @@ -1,3 +1,4 @@ export * from './codecs'; export * from './linkOverrides'; export * from './render'; +export * from './traitOptions'; diff --git a/packages/renderers-rust/src/utils/traitOptions.ts b/packages/renderers-rust/src/utils/traitOptions.ts new file mode 100644 index 000000000..04615a89d --- /dev/null +++ b/packages/renderers-rust/src/utils/traitOptions.ts @@ -0,0 +1,163 @@ +import { AccountNode, assertIsNode, camelCase, DefinedTypeNode, isNode, isScalarEnum } from '@codama/nodes'; + +import { ImportMap } from '../ImportMap'; + +export type TraitOptions = { + /** The default traits to implement for all types. */ + baseDefaults?: string[]; + /** + * The default traits to implement for data enums only — on top of the base defaults. + * Data enums are enums with at least one non-unit variant. + */ + dataEnumDefaults?: string[]; + /** + * The mapping of feature flags to traits. + * For each entry, the traits will be rendered within a + * `#[cfg_attr(feature = "feature_name", derive(Traits))]` attribute. + */ + featureFlags?: Record; + /** The complete trait overrides of specific types. */ + overrides?: Record; + /** + * The default traits to implement for scalar enums only — on top of the base defaults. + * Scalar enums are enums with no variants or only unit variants. + */ + scalarEnumDefaults?: string[]; + /** The default traits to implement for structs only — on top of the base defaults. */ + structDefaults?: string[]; + /** Whether or not to use the fully qualified name for traits, instead of importing them. */ + useFullyQualifiedName?: boolean; +}; + +export const DEFAULT_TRAIT_OPTIONS: Required = { + baseDefaults: [ + 'borsh::BorshSerialize', + 'borsh::BorshDeserialize', + 'serde::Serialize', + 'serde::Deserialize', + 'Clone', + 'Debug', + 'Eq', + 'PartialEq', + ], + dataEnumDefaults: [], + featureFlags: { serde: ['serde::Serialize', 'serde::Deserialize'] }, + overrides: {}, + scalarEnumDefaults: ['Copy', 'PartialOrd', 'Hash', 'num_derive::FromPrimitive'], + structDefaults: [], + useFullyQualifiedName: false, +}; + +export type GetTraitsFromNodeFunction = (node: AccountNode | DefinedTypeNode) => { imports: ImportMap; render: string }; + +export function getTraitsFromNodeFactory(options: TraitOptions = {}): GetTraitsFromNodeFunction { + return node => getTraitsFromNode(node, options); +} + +export function getTraitsFromNode( + node: AccountNode | DefinedTypeNode, + userOptions: TraitOptions = {}, +): { imports: ImportMap; render: string } { + assertIsNode(node, ['accountNode', 'definedTypeNode']); + const options: Required = { ...DEFAULT_TRAIT_OPTIONS, ...userOptions }; + + // Get the node type and return early if it's a type alias. + const nodeType = getNodeType(node); + if (nodeType === 'alias') { + return { imports: new ImportMap(), render: '' }; + } + + // Find all the FQN traits for the node. + const sanitizedOverrides = Object.fromEntries( + Object.entries(options.overrides).map(([key, value]) => [camelCase(key), value]), + ); + const nodeOverrides: string[] | undefined = sanitizedOverrides[node.name]; + const allTraits = nodeOverrides === undefined ? getDefaultTraits(nodeType, options) : nodeOverrides; + + // Wrap the traits in feature flags if necessary. + const partitionedTraits = partitionTraitsInFeatures(allTraits, options.featureFlags); + let unfeaturedTraits = partitionedTraits[0]; + const featuredTraits = partitionedTraits[1]; + + // Import the traits if necessary. + const imports = new ImportMap(); + if (!options.useFullyQualifiedName) { + unfeaturedTraits = extractFullyQualifiedNames(unfeaturedTraits, imports); + } + + // Render the trait lines. + const traitLines: string[] = [ + ...(unfeaturedTraits.length > 0 ? [`#[derive(${unfeaturedTraits.join(', ')})]\n`] : []), + ...Object.entries(featuredTraits).map(([feature, traits]) => { + return `#[cfg_attr(feature = "${feature}", derive(${traits.join(', ')}))]\n`; + }), + ]; + + return { imports, render: traitLines.join('') }; +} + +function getNodeType(node: AccountNode | DefinedTypeNode): 'alias' | 'dataEnum' | 'scalarEnum' | 'struct' { + if (isNode(node, 'accountNode')) return 'struct'; + if (isNode(node.type, 'structTypeNode')) return 'struct'; + if (isNode(node.type, 'enumTypeNode')) { + return isScalarEnum(node.type) ? 'scalarEnum' : 'dataEnum'; + } + return 'alias'; +} + +function getDefaultTraits( + nodeType: 'dataEnum' | 'scalarEnum' | 'struct', + options: Pick< + Required, + 'baseDefaults' | 'dataEnumDefaults' | 'scalarEnumDefaults' | 'structDefaults' + >, +): string[] { + switch (nodeType) { + case 'dataEnum': + return [...options.baseDefaults, ...options.dataEnumDefaults]; + case 'scalarEnum': + return [...options.baseDefaults, ...options.scalarEnumDefaults]; + case 'struct': + return [...options.baseDefaults, ...options.structDefaults]; + } +} + +function partitionTraitsInFeatures( + traits: string[], + featureFlags: Record, +): [string[], Record] { + // Reverse the feature flags option for quick lookup. + // If there are any duplicate traits, the first one encountered will be used. + const reverseFeatureFlags = Object.entries(featureFlags).reduce( + (acc, [feature, traits]) => { + for (const trait of traits) { + if (!acc[trait]) acc[trait] = feature; + } + return acc; + }, + {} as Record, + ); + + const unfeaturedTraits: string[] = []; + const featuredTraits: Record = {}; + for (const trait of traits) { + const feature: string | undefined = reverseFeatureFlags[trait]; + if (feature === undefined) { + unfeaturedTraits.push(trait); + } else { + if (!featuredTraits[feature]) featuredTraits[feature] = []; + featuredTraits[feature].push(trait); + } + } + + return [unfeaturedTraits, featuredTraits]; +} + +function extractFullyQualifiedNames(traits: string[], imports: ImportMap): string[] { + return traits.map(trait => { + const index = trait.lastIndexOf('::'); + if (index === -1) return trait; + imports.add(trait); + return trait.slice(index + 2); + }); +} diff --git a/packages/renderers-rust/test/utils/traitOptions.test.ts b/packages/renderers-rust/test/utils/traitOptions.test.ts new file mode 100644 index 000000000..4ef1b4a4e --- /dev/null +++ b/packages/renderers-rust/test/utils/traitOptions.test.ts @@ -0,0 +1,360 @@ +import { + accountNode, + definedTypeNode, + enumEmptyVariantTypeNode, + enumStructVariantTypeNode, + enumTypeNode, + numberTypeNode, + structFieldTypeNode, + structTypeNode, +} from '@codama/nodes'; +import { describe, expect, test } from 'vitest'; + +import { getTraitsFromNode, TraitOptions } from '../../src/utils'; + +describe('default values', () => { + test('it defaults to a set of traits for data enums', () => { + // Given a data enum defined type. + const node = definedTypeNode({ + name: 'Command', + type: enumTypeNode([ + enumStructVariantTypeNode( + 'Play', + structTypeNode([structFieldTypeNode({ name: 'guess', type: numberTypeNode('u16') })]), + ), + enumEmptyVariantTypeNode('Quit'), + ]), + }); + + // When we get the traits from the node using the default options. + const { render, imports } = getTraitsFromNode(node); + + // Then we expect the following traits to be rendered. + expect(render).toBe( + `#[derive(BorshSerialize, BorshDeserialize, Clone, Debug, Eq, PartialEq)]\n` + + `#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]\n`, + ); + + // And the following imports to be used. + expect([...imports.imports]).toStrictEqual(['borsh::BorshSerialize', 'borsh::BorshDeserialize']); + }); + + test('it defaults to a set of traits for scalar enums', () => { + // Given a scalar enum defined type. + const node = definedTypeNode({ + name: 'Feedback', + type: enumTypeNode([enumEmptyVariantTypeNode('Good'), enumEmptyVariantTypeNode('Bad')]), + }); + + // When we get the traits from the node using the default options. + const { render, imports } = getTraitsFromNode(node); + + // Then we expect the following traits to be rendered. + expect(render).toBe( + `#[derive(BorshSerialize, BorshDeserialize, Clone, Debug, Eq, PartialEq, Copy, PartialOrd, Hash, FromPrimitive)]\n` + + `#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]\n`, + ); + + // And the following imports to be used. + expect([...imports.imports]).toStrictEqual([ + 'borsh::BorshSerialize', + 'borsh::BorshDeserialize', + 'num_derive::FromPrimitive', + ]); + }); + + test('it defaults to a set of traits for structs', () => { + // Given an account node. + const node = accountNode({ + data: structTypeNode([ + structFieldTypeNode({ name: 'x', type: numberTypeNode('u64') }), + structFieldTypeNode({ name: 'y', type: numberTypeNode('u64') }), + ]), + name: 'Coordinates', + }); + + // When we get the traits from the node using the default options. + const { render, imports } = getTraitsFromNode(node); + + // Then we expect the following traits to be rendered. + expect(render).toBe( + `#[derive(BorshSerialize, BorshDeserialize, Clone, Debug, Eq, PartialEq)]\n` + + `#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]\n`, + ); + + // And the following imports to be used. + expect([...imports.imports]).toStrictEqual(['borsh::BorshSerialize', 'borsh::BorshDeserialize']); + }); + + test('it does not use default traits if they are overridden', () => { + // Given a defined type node that should use custom traits. + const node = accountNode({ + data: structTypeNode([ + structFieldTypeNode({ name: 'x', type: numberTypeNode('u64') }), + structFieldTypeNode({ name: 'y', type: numberTypeNode('u64') }), + ]), + name: 'Coordinates', + }); + + // When we get the traits from the node using the + // default options with the overrides attribute. + const { render, imports } = getTraitsFromNode(node, { + overrides: { coordinates: ['My', 'special::Traits'] }, + }); + + // Then we expect the following traits to be rendered. + expect(render).toBe(`#[derive(My, Traits)]\n`); + + // And the following imports to be used. + expect([...imports.imports]).toStrictEqual(['special::Traits']); + }); + + test('it still uses feature flags for overridden traits', () => { + // Given a defined type node that should use custom traits. + const node = accountNode({ + data: structTypeNode([ + structFieldTypeNode({ name: 'x', type: numberTypeNode('u64') }), + structFieldTypeNode({ name: 'y', type: numberTypeNode('u64') }), + ]), + name: 'Coordinates', + }); + + // When we get the traits from the node using custom traits + // such that some are part of the feature flag defaults. + const { render } = getTraitsFromNode(node, { + overrides: { coordinates: ['My', 'special::Traits', 'serde::Serialize'] }, + }); + + // Then we expect the following traits to be rendered. + expect(render).toBe(`#[derive(My, Traits)]\n#[cfg_attr(feature = "serde", derive(serde::Serialize))]\n`); + }); +}); + +const RESET_OPTIONS: Required = { + baseDefaults: [], + dataEnumDefaults: [], + featureFlags: {}, + overrides: {}, + scalarEnumDefaults: [], + structDefaults: [], + useFullyQualifiedName: false, +}; + +describe('base traits', () => { + test('it uses both the base and data enum traits', () => { + // Given a data enum defined type. + const node = definedTypeNode({ + name: 'Command', + type: enumTypeNode([ + enumStructVariantTypeNode( + 'Play', + structTypeNode([structFieldTypeNode({ name: 'guess', type: numberTypeNode('u16') })]), + ), + enumEmptyVariantTypeNode('Quit'), + ]), + }); + + // When we get the traits from the node using custom base and data enum defaults. + const { render } = getTraitsFromNode(node, { + ...RESET_OPTIONS, + baseDefaults: ['MyBaseTrait'], + dataEnumDefaults: ['MyDataEnumTrait'], + }); + + // Then we expect both the base and data enum traits to be rendered. + expect(render).toBe(`#[derive(MyBaseTrait, MyDataEnumTrait)]\n`); + }); + + test('it uses both the base and scalar enum traits', () => { + // Given a scalar enum defined type. + const node = definedTypeNode({ + name: 'Feedback', + type: enumTypeNode([enumEmptyVariantTypeNode('Good'), enumEmptyVariantTypeNode('Bad')]), + }); + + // When we get the traits from the node using custom base and scalar enum defaults. + const { render } = getTraitsFromNode(node, { + ...RESET_OPTIONS, + baseDefaults: ['MyBaseTrait'], + scalarEnumDefaults: ['MyScalarEnumTrait'], + }); + + // Then we expect both the base and scalar enum traits to be rendered. + expect(render).toBe(`#[derive(MyBaseTrait, MyScalarEnumTrait)]\n`); + }); + + test('it uses both the base and struct traits', () => { + // Given an account node. + const node = accountNode({ + data: structTypeNode([ + structFieldTypeNode({ name: 'x', type: numberTypeNode('u64') }), + structFieldTypeNode({ name: 'y', type: numberTypeNode('u64') }), + ]), + name: 'Coordinates', + }); + + // When we get the traits from the node using custom base and struct defaults. + const { render } = getTraitsFromNode(node, { + ...RESET_OPTIONS, + baseDefaults: ['MyBaseTrait'], + structDefaults: ['MyStructTrait'], + }); + + // Then we expect both the base and struct traits to be rendered. + expect(render).toBe(`#[derive(MyBaseTrait, MyStructTrait)]\n`); + }); + + test('it never uses traits for type aliases', () => { + // Given a defined type node that is not an enum or struct. + const node = definedTypeNode({ + name: 'Score', + type: numberTypeNode('u64'), + }); + + // When we get the traits from the node such that we have base defaults. + const { render } = getTraitsFromNode(node, { + ...RESET_OPTIONS, + baseDefaults: ['MyBaseTrait'], + }); + + // Then we expect no traits to be rendered. + expect(render).toBe(''); + }); + + test('it identifies feature flags under all default traits', () => { + // Given a scalar enum defined type. + const node = definedTypeNode({ + name: 'Feedback', + type: enumTypeNode([enumEmptyVariantTypeNode('Good'), enumEmptyVariantTypeNode('Bad')]), + }); + + // When we get the traits from the node such that: + // - We provide custom base and scalar enum defaults. + // - We provide custom feature flags for traits in both categories. + const { render } = getTraitsFromNode(node, { + ...RESET_OPTIONS, + baseDefaults: ['MyBaseTrait', 'MyNonFeatureTrait'], + featureFlags: { + base: ['MyBaseTrait'], + enum: ['MyScalarEnumTrait'], + }, + scalarEnumDefaults: ['MyScalarEnumTrait'], + }); + + // Then we expect both the base and enum traits to be rendered as separate feature flags. + expect(render).toBe( + `#[derive(MyNonFeatureTrait)]\n` + + `#[cfg_attr(feature = "base", derive(MyBaseTrait))]\n` + + `#[cfg_attr(feature = "enum", derive(MyScalarEnumTrait))]\n`, + ); + }); + + test('it renders traits correctly when they are all under feature flags', () => { + // Given a scalar enum defined type. + const node = definedTypeNode({ + name: 'Feedback', + type: enumTypeNode([enumEmptyVariantTypeNode('Good'), enumEmptyVariantTypeNode('Bad')]), + }); + + // When we get the traits from the node such that + // all traits are under feature flags. + const { render } = getTraitsFromNode(node, { + ...RESET_OPTIONS, + baseDefaults: ['MyBaseTrait'], + featureFlags: { + base: ['MyBaseTrait'], + enum: ['MyScalarEnumTrait'], + }, + scalarEnumDefaults: ['MyScalarEnumTrait'], + }); + + // Then we expect the following traits to be rendered. + expect(render).toBe( + `#[cfg_attr(feature = "base", derive(MyBaseTrait))]\n#[cfg_attr(feature = "enum", derive(MyScalarEnumTrait))]\n`, + ); + }); +}); + +describe('overridden traits', () => { + test('it replaces all default traits with the overridden traits', () => { + // Given a scalar enum defined type. + const node = definedTypeNode({ + name: 'Feedback', + type: enumTypeNode([enumEmptyVariantTypeNode('Good'), enumEmptyVariantTypeNode('Bad')]), + }); + + // When we get the traits from the node such that: + // - We provide custom base and enum defaults. + // - We override the feedback type with custom traits. + const { render } = getTraitsFromNode(node, { + ...RESET_OPTIONS, + baseDefaults: ['MyBaseTrait'], + overrides: { feedback: ['MyFeedbackTrait'] }, + scalarEnumDefaults: ['MyScalarEnumTrait'], + }); + + // Then we expect only the feedback traits to be rendered. + expect(render).toBe(`#[derive(MyFeedbackTrait)]\n`); + }); + + test('it finds traits to override when using pascal case', () => { + // Given a scalar enum defined type. + const node = definedTypeNode({ + name: 'Feedback', + type: enumTypeNode([enumEmptyVariantTypeNode('Good'), enumEmptyVariantTypeNode('Bad')]), + }); + + // When we get the traits from the node such that + // we use PascalCase for the type name. + const { render } = getTraitsFromNode(node, { + ...RESET_OPTIONS, + overrides: { Feedback: ['MyFeedbackTrait'] }, + }); + + // Then we still expect the custom feedback traits to be rendered. + expect(render).toBe(`#[derive(MyFeedbackTrait)]\n`); + }); + + test('it identifies feature flags under all overridden traits', () => { + // Given a scalar enum defined type. + const node = definedTypeNode({ + name: 'Feedback', + type: enumTypeNode([enumEmptyVariantTypeNode('Good'), enumEmptyVariantTypeNode('Bad')]), + }); + + // When we get the traits from the node such that: + // - We override the feedback type with custom traits. + // - We provide custom feature flags for these some of these custom traits. + const { render } = getTraitsFromNode(node, { + ...RESET_OPTIONS, + featureFlags: { custom: ['MyFeedbackTrait'] }, + overrides: { feedback: ['MyFeedbackTrait', 'MyNonFeatureTrait'] }, + }); + + // Then we expect some of the overridden traits to be rendered under feature flags. + expect(render).toBe(`#[derive(MyNonFeatureTrait)]\n#[cfg_attr(feature = "custom", derive(MyFeedbackTrait))]\n`); + }); +}); + +describe('fully qualified name traits', () => { + test('it can use fully qualified names for traits instead of importing them', () => { + // Given a scalar enum defined type. + const node = definedTypeNode({ + name: 'Feedback', + type: enumTypeNode([enumEmptyVariantTypeNode('Good'), enumEmptyVariantTypeNode('Bad')]), + }); + + // When we get the traits from the node such that we use fully qualified names. + const { render, imports } = getTraitsFromNode(node, { + ...RESET_OPTIONS, + baseDefaults: ['fruits::Apple', 'fruits::Banana', 'vegetables::Carrot'], + useFullyQualifiedName: true, + }); + + // Then we expect the fully qualified names to be used for the traits. + expect(render).toBe(`#[derive(fruits::Apple, fruits::Banana, vegetables::Carrot)]\n`); + + // And no imports should be used. + expect([...imports.imports]).toStrictEqual([]); + }); +});