Skip to content

Commit 91a499b

Browse files
authored
[rust-axum] Fix polymorphic type discriminators (#22222)
* [rust-axum] Make discriminator field name camelCase * [rust-axum] Give polymorphic enum values a serde alias using the mapping keys if available * update samples * update test file checksum
1 parent eda2e67 commit 91a499b

File tree

8 files changed

+413
-31
lines changed

8 files changed

+413
-31
lines changed

bin/utils/test_file_list.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,6 @@
5959
sha256: 45cdaba3d2adc212cd4f0184ad475419a95e2326254c2ef84175e210c922b2f3
6060
# rust axum test files
6161
- filename: "samples/server/petstore/rust-axum/output/rust-axum-oneof/tests/oneof_with_discriminator.rs"
62-
sha256: 2d4f5a069fdcb3057bb078d5e75b3de63cd477b97725e457079df24bd2c30600
62+
sha256: b2093528aac971193f2863a70f46eea45cf8bda79120b133a614599e80d8b46d
6363
- filename: "samples/server/petstore/rust-axum/output/openapi-v3/tests/oneof_untagged.rs"
6464
sha256: 1d3fb01f65e98290b1d3eece28014c7d3e3f2fdf18e7110249d3c591cc4642ab

modules/openapi-generator/src/main/java/org/openapitools/codegen/languages/RustAxumServerCodegen.java

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -648,38 +648,40 @@ public CodegenOperation fromOperation(String path, String httpMethod, Operation
648648
}
649649

650650
private void postProcessPolymorphism(final List<ModelMap> allModels) {
651-
final HashMap<String, List<String>> discriminatorsForModel = new HashMap<>();
651+
final HashMap<String, List<CodegenDiscriminator>> discriminatorsForModel = new HashMap<>();
652652

653653
for (final ModelMap mo : allModels) {
654654
final CodegenModel cm = mo.getModel();
655655

656656
final CodegenComposedSchemas cs = cm.getComposedSchemas();
657657
if (cs != null) {
658658
final List<CodegenProperty> csOneOf = cs.getOneOf();
659+
CodegenDiscriminator discriminator = cm.getDiscriminator();
660+
659661
if (csOneOf != null) {
660-
processPolymorphismDataType(csOneOf);
662+
processPolymorphismDataType(csOneOf, discriminator);
661663
cs.setOneOf(csOneOf);
662664
cm.setComposedSchemas(cs);
663665
}
664666

665667
final List<CodegenProperty> csAnyOf = cs.getAnyOf();
666668
if (csAnyOf != null) {
667-
processPolymorphismDataType(csAnyOf);
669+
processPolymorphismDataType(csAnyOf, discriminator);
668670
cs.setAnyOf(csAnyOf);
669671
cm.setComposedSchemas(cs);
670672
}
671673
}
672674

673675
if (cm.discriminator != null) {
674676
for (final String model : cm.oneOf) {
675-
final List<String> discriminators = discriminatorsForModel.getOrDefault(model, new ArrayList<>());
676-
discriminators.add(cm.discriminator.getPropertyName());
677+
final List<CodegenDiscriminator> discriminators = discriminatorsForModel.getOrDefault(model, new ArrayList<>());
678+
discriminators.add(cm.discriminator);
677679
discriminatorsForModel.put(model, discriminators);
678680
}
679681

680682
for (final String model : cm.anyOf) {
681-
final List<String> discriminators = discriminatorsForModel.getOrDefault(model, new ArrayList<>());
682-
discriminators.add(cm.discriminator.getPropertyName());
683+
final List<CodegenDiscriminator> discriminators = discriminatorsForModel.getOrDefault(model, new ArrayList<>());
684+
discriminators.add(cm.discriminator);
683685
discriminatorsForModel.put(model, discriminators);
684686
}
685687
}
@@ -689,11 +691,11 @@ private void postProcessPolymorphism(final List<ModelMap> allModels) {
689691
for (ModelMap mo : allModels) {
690692
final CodegenModel cm = mo.getModel();
691693

692-
final List<String> discriminators = discriminatorsForModel.get(cm.getSchemaName());
694+
final List<CodegenDiscriminator> discriminators = discriminatorsForModel.get(cm.getSchemaName());
693695
if (discriminators != null) {
694696
// If the discriminator field is not a defined attribute in the variant structure, create it.
695697
if (!discriminating(discriminators, cm)) {
696-
final String discriminator = discriminators.get(0);
698+
final CodegenDiscriminator discriminator = discriminators.get(0);
697699

698700
CodegenProperty property = new CodegenProperty();
699701

@@ -710,17 +712,18 @@ private void postProcessPolymorphism(final List<ModelMap> allModels) {
710712
property.isDiscriminator = true;
711713

712714
// Attributes based on the discriminator value
713-
property.baseName = discriminator;
714-
property.name = discriminator;
715-
property.nameInCamelCase = camelize(discriminator);
715+
property.baseName = discriminator.getPropertyBaseName();
716+
property.name = discriminator.getPropertyName();
717+
property.nameInCamelCase = camelize(discriminator.getPropertyName());
716718
property.nameInPascalCase = property.nameInCamelCase.substring(0, 1).toUpperCase(Locale.ROOT) + property.nameInCamelCase.substring(1);
717-
property.nameInSnakeCase = underscore(discriminator).toUpperCase(Locale.ROOT);
719+
property.nameInSnakeCase = underscore(discriminator.getPropertyName()).toUpperCase(Locale.ROOT);
718720
property.getter = String.format(Locale.ROOT, "get%s", property.nameInPascalCase);
719721
property.setter = String.format(Locale.ROOT, "set%s", property.nameInPascalCase);
720722
property.defaultValueWithParam = String.format(Locale.ROOT, " = data.%s;", property.name);
721723

722724
// Attributes based on the model name
723725
property.defaultValue = String.format(Locale.ROOT, "r#\"%s\"#.to_string()", cm.getSchemaName());
726+
property.discriminatorValue = getDiscriminatorValue(cm.getClassname(), discriminator);
724727
property.jsonSchema = String.format(Locale.ROOT, "{ \"default\":\"%s\"; \"type\":\"string\" }", cm.getSchemaName());
725728

726729
cm.vars.add(property);
@@ -743,14 +746,27 @@ private void postProcessPolymorphism(final List<ModelMap> allModels) {
743746
}
744747
}
745748

746-
private static boolean discriminating(final List<String> discriminatorsForModel, final CodegenModel cm) {
749+
private static String getDiscriminatorValue(String modelName, CodegenDiscriminator discriminator) {
750+
if (discriminator == null || discriminator.getMappedModels() == null) {
751+
return modelName;
752+
}
753+
return discriminator
754+
.getMappedModels()
755+
.stream()
756+
.filter(m -> m.getModelName().equals(modelName) && m.getMappingName() != null)
757+
.map(CodegenDiscriminator.MappedModel::getMappingName)
758+
.findFirst()
759+
.orElse(modelName);
760+
}
761+
762+
private static boolean discriminating(final List<CodegenDiscriminator> discriminatorsForModel, final CodegenModel cm) {
747763
resetDiscriminatorProperty(cm);
748764

749765
// Discriminator will be presented as enum tag -> One and only one tag is allowed
750766
int countString = 0;
751767
int countNonString = 0;
752768
for (final CodegenProperty var : cm.vars) {
753-
if (discriminatorsForModel.stream().anyMatch(discriminator -> var.baseName.equals(discriminator) || var.name.equals(discriminator))) {
769+
if (discriminatorsForModel.stream().anyMatch(discriminator -> var.baseName.equals(discriminator.getPropertyBaseName()) || var.name.equals(discriminator.getPropertyName()))) {
754770
if (var.isString) {
755771
var.isDiscriminator = true;
756772
++countString;
@@ -773,7 +789,7 @@ private static void resetDiscriminatorProperty(final CodegenModel cm) {
773789
}
774790
}
775791

776-
private static void processPolymorphismDataType(final List<CodegenProperty> cp) {
792+
private static void processPolymorphismDataType(final List<CodegenProperty> cp, CodegenDiscriminator discriminator) {
777793
final HashSet<String> dedupDataTypeWithEnum = new HashSet<>();
778794
final HashMap<String, Integer> dedupDataType = new HashMap<>();
779795

@@ -783,6 +799,7 @@ private static void processPolymorphismDataType(final List<CodegenProperty> cp)
783799
// Mainly needed for primitive types.
784800
model.datatypeWithEnum = camelize(model.dataType.replaceAll("(?:\\w+::)+(\\w+)", "$1")
785801
.replace("<", "Of").replace(">", "")).replace(" ", "").replace(",", "");
802+
model.discriminatorValue = getDiscriminatorValue(model.datatypeWithEnum, discriminator);
786803
if (!dedupDataTypeWithEnum.add(model.datatypeWithEnum)) {
787804
model.datatypeWithEnum += ++idx;
788805
}

modules/openapi-generator/src/main/resources/rust-axum/models.mustache

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,9 @@ impl std::str::FromStr for {{{classname}}} {
796796
pub enum {{{classname}}} {
797797
{{#composedSchemas}}
798798
{{#anyOf}}
799+
{{#discriminator}}
800+
#[serde(alias = "{{{discriminatorValue}}}")]
801+
{{/discriminator}}
799802
{{{datatypeWithEnum}}}({{{dataType}}}),
800803
{{/anyOf}}
801804
{{/composedSchemas}}
@@ -871,6 +874,9 @@ impl From<{{{dataType}}}> for {{{classname}}} {
871874
pub enum {{{classname}}} {
872875
{{#composedSchemas}}
873876
{{#oneOf}}
877+
{{#discriminator}}
878+
#[serde(alias = "{{{discriminatorValue}}}")]
879+
{{/discriminator}}
874880
{{{datatypeWithEnum}}}({{{dataType}}}),
875881
{{/oneOf}}
876882
{{/composedSchemas}}
@@ -1071,7 +1077,7 @@ pub struct {{{classname}}} {
10711077
{{#isString}}
10721078
impl {{{classname}}} {
10731079
fn _name_for_{{{name}}}() -> String {
1074-
String::from("{{{classname}}}")
1080+
String::from("{{#discriminatorValue}}{{{discriminatorValue}}}{{/discriminatorValue}}{{^discriminatorValue}}{{classname}}{{/discriminatorValue}}")
10751081
}
10761082
10771083
fn _serialize_{{{name}}}<S>(_: &String, s: S) -> Result<S::Ok, S::Error>

modules/openapi-generator/src/test/resources/3_0/petstore.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -738,4 +738,4 @@ components:
738738
type:
739739
type: string
740740
message:
741-
type: string
741+
type: string

modules/openapi-generator/src/test/resources/3_0/rust-axum/rust-axum-oneof.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,14 @@ components:
7676
additionalProperties: false
7777
discriminator:
7878
propertyName: op
79+
mapping:
80+
yo: "#/components/schemas/YoMessage"
7981
oneOf:
8082
- "$ref": "#/components/schemas/Hello"
8183
- "$ref": "#/components/schemas/Greeting"
8284
- "$ref": "#/components/schemas/Goodbye"
8385
- "$ref": "#/components/schemas/SomethingCompletelyDifferent"
86+
- "$ref": "#/components/schemas/YoMessage"
8487
title: Message
8588
Hello:
8689
type: object
@@ -141,3 +144,17 @@ components:
141144
type: object
142145
type: array
143146
- type: object
147+
YoMessage:
148+
type: object
149+
title: Yo
150+
properties:
151+
d:
152+
type: object
153+
properties:
154+
nickname:
155+
type: string
156+
required:
157+
- nickname
158+
required:
159+
- op
160+
- d

0 commit comments

Comments
 (0)