Skip to content

Commit 2c6220c

Browse files
authored
feat: pydantic discriminator for unions (#2316)
1 parent 90beedd commit 2c6220c

File tree

8 files changed

+141
-38
lines changed

8 files changed

+141
-38
lines changed

.sonarcloud.properties

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
# Disable specific duplicate code since it would introduce more complexity to reduce it.
22
sonar.cpd.exclusions=src/generators/**/*.ts
33
sonar.exclusions=modelina-website/next.config.js
4+
sonar.sources=src,modelina-cli/src
5+
sonar.tests=test,modelina-cli/test

modelina-cli/src/helpers/python.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ export function buildPythonGenerator(flags: any): BuilderReturnType {
2323
if (packageName === undefined) {
2424
throw new Error('In order to generate models to Python, we need to know which package they are under. Add `--packageName=PACKAGENAME` to set the desired package name.');
2525
}
26-
27-
const presets = [];
26+
27+
const presets = [];
2828
if (pyDantic) {
2929
presets.push(PYTHON_PYDANTIC_PRESET);
3030
}

src/generators/python/PythonConstrainer.ts

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ import { defaultModelNameConstraints } from './constrainer/ModelNameConstrainer'
77
import { defaultPropertyKeyConstraints } from './constrainer/PropertyKeyConstrainer';
88
import { defaultConstantConstraints } from './constrainer/ConstantConstrainer';
99
import { PythonOptions, PythonTypeMapping } from './PythonGenerator';
10+
import { PYTHON_PYDANTIC_PRESET } from './presets';
11+
import {
12+
ConstrainedObjectModel,
13+
ConstrainedReferenceModel
14+
} from '../../models';
1015

1116
export const PythonDefaultTypeMapping: PythonTypeMapping = {
1217
Object({ constrainedModel }): string {
@@ -46,12 +51,38 @@ export const PythonDefaultTypeMapping: PythonTypeMapping = {
4651
//Returning name here because all enum models have been split out
4752
return constrainedModel.name;
4853
},
49-
Union({ constrainedModel }): string {
54+
Union({ constrainedModel, options, dependencyManager }): string {
55+
dependencyManager.addDependency('from typing import Union');
5056
const unionTypes = constrainedModel.union.map((unionModel) => {
5157
return unionModel.type;
5258
});
5359
const uniqueSet = new Set(unionTypes);
54-
return [...uniqueSet].join(' | ');
60+
// support Python pre 3.10
61+
let union = `Union[${[...uniqueSet].join(', ')}]`;
62+
63+
if (
64+
constrainedModel.options.discriminator &&
65+
options.presets?.includes(PYTHON_PYDANTIC_PRESET)
66+
) {
67+
const discriminator =
68+
constrainedModel.options.discriminator.discriminator;
69+
const reference = constrainedModel.union[0];
70+
if (reference instanceof ConstrainedReferenceModel) {
71+
const ref = reference.ref;
72+
if (ref instanceof ConstrainedObjectModel) {
73+
const properties = ref.properties;
74+
const property = Object.values(properties).find((property) => {
75+
return property.unconstrainedPropertyName === discriminator;
76+
});
77+
if (property !== undefined) {
78+
dependencyManager.addDependency('from typing import Annotated');
79+
union = `Annotated[${union}, Field(discriminator='${property.propertyName}')]`;
80+
}
81+
}
82+
}
83+
}
84+
85+
return union;
5586
},
5687
Dictionary({ constrainedModel }): string {
5788
return `dict[${constrainedModel.key.type}, ${constrainedModel.value.type}]`;

src/generators/python/presets/Pydantic.ts

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import {
22
ConstrainedDictionaryModel,
3-
ConstrainedObjectPropertyModel,
4-
ConstrainedUnionModel
3+
ConstrainedObjectPropertyModel
54
} from '../../../models';
65
import { PythonOptions } from '../PythonGenerator';
76
import { ClassPresetType, PythonPreset } from '../PythonPreset';
87

98
const PYTHON_PYDANTIC_CLASS_PRESET: ClassPresetType<PythonOptions> = {
109
async self({ renderer, model }) {
1110
renderer.dependencyManager.addDependency(
12-
'from typing import Optional, Any, Union'
11+
'from typing import Optional, Any'
1312
);
1413
renderer.dependencyManager.addDependency(
1514
'from pydantic import BaseModel, Field'
@@ -26,18 +25,19 @@ const PYTHON_PYDANTIC_CLASS_PRESET: ClassPresetType<PythonOptions> = {
2625
let type = property.property.type;
2726
const propertyName = property.propertyName;
2827

29-
if (property.property instanceof ConstrainedUnionModel) {
30-
const unionTypes = property.property.union.map(
31-
(unionModel) => unionModel.type
32-
);
33-
type = `Union[${unionTypes.join(', ')}]`;
34-
}
35-
3628
const isOptional =
3729
!property.required || property.property.options.isNullable === true;
3830
if (isOptional) {
3931
type = `Optional[${type}]`;
4032
}
33+
if (
34+
property.property.options.const &&
35+
model.options.discriminator?.discriminator ===
36+
property.unconstrainedPropertyName
37+
) {
38+
renderer.dependencyManager.addDependency('from typing import Literal');
39+
type = `Literal['${property.property.options.const.originalInput}']`;
40+
}
4141
type = renderer.renderPropertyType({
4242
modelType: model.type,
4343
propertyType: type
@@ -54,7 +54,14 @@ const PYTHON_PYDANTIC_CLASS_PRESET: ClassPresetType<PythonOptions> = {
5454
decoratorArgs.push('default=None');
5555
}
5656
if (property.property.options.const) {
57-
decoratorArgs.push(`default=${property.property.options.const.value}`);
57+
let value = property.property.options.const.value;
58+
if (
59+
model.options.discriminator?.discriminator ===
60+
property.unconstrainedPropertyName
61+
) {
62+
value = property.property.options.const.originalInput;
63+
}
64+
decoratorArgs.push(`default='${value}'`);
5865
decoratorArgs.push('frozen=True');
5966
}
6067
if (

test/generators/python/PythonConstrainer.spec.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ describe('PythonConstrainer', () => {
196196
constrainedModel: model,
197197
...defaultOptions
198198
});
199-
expect(type).toEqual('str');
199+
expect(type).toEqual('Union[str]');
200200
});
201201
test('should render multiple types', () => {
202202
const unionModel1 = new ConstrainedStringModel(
@@ -219,7 +219,7 @@ describe('PythonConstrainer', () => {
219219
constrainedModel: model,
220220
...defaultOptions
221221
});
222-
expect(type).toEqual('str | str2');
222+
expect(type).toEqual('Union[str, str2]');
223223
});
224224
test('should render unique types', () => {
225225
const unionModel1 = new ConstrainedStringModel(
@@ -242,7 +242,7 @@ describe('PythonConstrainer', () => {
242242
constrainedModel: model,
243243
...defaultOptions
244244
});
245-
expect(type).toEqual('str');
245+
expect(type).toEqual('Union[str]');
246246
});
247247
});
248248

test/generators/python/__snapshots__/PythonGenerator.spec.ts.snap

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ exports[`PythonGenerator Class should handle self reference models 1`] = `
1212
if 'map_model' in input:
1313
self._map_model: dict[str, Address] = input['map_model']
1414
if 'union_model' in input:
15-
self._union_model: Address | str = input['union_model']
15+
self._union_model: Union[Address, str] = input['union_model']
1616
1717
@property
1818
def self_model(self) -> Address:
@@ -43,10 +43,10 @@ exports[`PythonGenerator Class should handle self reference models 1`] = `
4343
self._map_model = map_model
4444
4545
@property
46-
def union_model(self) -> Address | str:
46+
def union_model(self) -> Union[Address, str]:
4747
return self._union_model
4848
@union_model.setter
49-
def union_model(self, union_model: Address | str):
49+
def union_model(self, union_model: Union[Address, str]):
5050
self._union_model = union_model
5151
"
5252
`;
@@ -76,10 +76,10 @@ exports[`PythonGenerator Class should render \`class\` type 1`] = `
7676
if 'marriage' in input:
7777
self._marriage: bool = input['marriage']
7878
if 'members' in input:
79-
self._members: str | float | bool = input['members']
80-
self._array_type: List[str | float | Any] = input['array_type']
79+
self._members: Union[str, float, bool] = input['members']
80+
self._array_type: List[Union[str, float, Any]] = input['array_type']
8181
if 'additional_properties' in input:
82-
self._additional_properties: dict[str, Any | str] = input['additional_properties']
82+
self._additional_properties: dict[str, Union[Any, str]] = input['additional_properties']
8383
8484
@property
8585
def street_name(self) -> str:
@@ -117,24 +117,24 @@ exports[`PythonGenerator Class should render \`class\` type 1`] = `
117117
self._marriage = marriage
118118
119119
@property
120-
def members(self) -> str | float | bool:
120+
def members(self) -> Union[str, float, bool]:
121121
return self._members
122122
@members.setter
123-
def members(self, members: str | float | bool):
123+
def members(self, members: Union[str, float, bool]):
124124
self._members = members
125125
126126
@property
127-
def array_type(self) -> List[str | float | Any]:
127+
def array_type(self) -> List[Union[str, float, Any]]:
128128
return self._array_type
129129
@array_type.setter
130-
def array_type(self, array_type: List[str | float | Any]):
130+
def array_type(self, array_type: List[Union[str, float, Any]]):
131131
self._array_type = array_type
132132
133133
@property
134-
def additional_properties(self) -> dict[str, Any | str]:
134+
def additional_properties(self) -> dict[str, Union[Any, str]]:
135135
return self._additional_properties
136136
@additional_properties.setter
137-
def additional_properties(self, additional_properties: dict[str, Any | str]):
137+
def additional_properties(self, additional_properties: dict[str, Union[Any, str]]):
138138
self._additional_properties = additional_properties
139139
"
140140
`;

test/generators/python/presets/Pydantic.spec.ts

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,44 @@ describe('PYTHON_PYDANTIC_PRESET', () => {
104104
},
105105
components: {
106106
messages: {
107+
Garage: {
108+
payload: {
109+
$ref: '#/components/schemas/Garage'
110+
}
111+
},
107112
Vehicle: {
108113
payload: {
109-
oneOf: [
110-
{ $ref: '#/components/schemas/Car' },
111-
{ $ref: '#/components/schemas/Truck' }
112-
]
114+
$ref: '#/components/schemas/Vehicle'
113115
}
114116
}
115117
},
116118
schemas: {
119+
Garage: {
120+
title: 'Garage',
121+
type: 'object',
122+
properties: {
123+
favorite: {
124+
$ref: '#/components/schemas/Vehicle'
125+
},
126+
vehicles: {
127+
type: 'array',
128+
items: {
129+
$ref: '#/components/schemas/Vehicle'
130+
}
131+
}
132+
}
133+
},
117134
Vehicle: {
118135
title: 'Vehicle',
119136
type: 'object',
137+
oneOf: [
138+
{ $ref: '#/components/schemas/Car' },
139+
{ $ref: '#/components/schemas/Truck' }
140+
]
141+
},
142+
VehicleBase: {
143+
title: 'VehicleBase',
144+
type: 'object',
120145
discriminator: 'vehicleType',
121146
properties: {
122147
vehicleType: {
@@ -132,7 +157,7 @@ describe('PYTHON_PYDANTIC_PRESET', () => {
132157
},
133158
Car: {
134159
allOf: [
135-
{ $ref: '#/components/schemas/Vehicle' },
160+
{ $ref: '#/components/schemas/VehicleBase' },
136161
{
137162
type: 'object',
138163
properties: {
@@ -145,7 +170,7 @@ describe('PYTHON_PYDANTIC_PRESET', () => {
145170
},
146171
Truck: {
147172
allOf: [
148-
{ $ref: '#/components/schemas/Vehicle' },
173+
{ $ref: '#/components/schemas/VehicleBase' },
149174
{
150175
type: 'object',
151176
properties: {

test/generators/python/presets/__snapshots__/Pydantic.spec.ts.snap

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,46 @@ Array [
4545
exports[`PYTHON_PYDANTIC_PRESET should render default value for discriminator when using polymorphism 1`] = `
4646
Array [
4747
"",
48+
"class Garage(BaseModel):
49+
favorite: Optional[Annotated[Union[Car, Truck], Field(discriminator='vehicle_type')]] = Field(default=None)
50+
vehicles: Optional[List[Annotated[Union[Car, Truck], Field(discriminator='vehicle_type')]]] = Field(default=None)
51+
additional_properties: Optional[dict[str, Any]] = Field(default=None, exclude=True)
52+
53+
@model_serializer(mode='wrap')
54+
def custom_serializer(self, handler):
55+
serialized_self = handler(self)
56+
additional_properties = getattr(self, \\"additional_properties\\")
57+
if additional_properties is not None:
58+
for key, value in additional_properties.items():
59+
# Never overwrite existing values, to avoid clashes
60+
if not hasattr(serialized_self, key):
61+
serialized_self[key] = value
62+
63+
return serialized_self
64+
65+
@model_validator(mode='before')
66+
@classmethod
67+
def unwrap_additional_properties(cls, data):
68+
if not isinstance(data, dict):
69+
data = data.model_dump()
70+
json_properties = list(data.keys())
71+
known_object_properties = ['favorite', 'vehicles', 'additional_properties']
72+
unknown_object_properties = [element for element in json_properties if element not in known_object_properties]
73+
# Ignore attempts that validate regular models, only when unknown input is used we add unwrap extensions
74+
if len(unknown_object_properties) == 0:
75+
return data
76+
77+
known_json_properties = ['favorite', 'vehicles', 'additionalProperties']
78+
additional_properties = data.get('additional_properties', {})
79+
for obj_key in unknown_object_properties:
80+
if not known_json_properties.__contains__(obj_key):
81+
additional_properties[obj_key] = data.pop(obj_key, None)
82+
data['additional_properties'] = additional_properties
83+
return data
84+
85+
",
4886
"class Car(BaseModel):
49-
vehicle_type: VehicleType = Field(default=VehicleType.CAR, frozen=True, alias='''vehicleType''')
87+
vehicle_type: Literal['Car'] = Field(default='Car', frozen=True, alias='''vehicleType''')
5088
length: Optional[float] = Field(default=None)
5189
additional_properties: Optional[dict[str, Any]] = Field(default=None, exclude=True)
5290
@@ -87,7 +125,7 @@ Array [
87125
CAR = \\"Car\\"
88126
TRUCK = \\"Truck\\"",
89127
"class Truck(BaseModel):
90-
vehicle_type: VehicleType = Field(default=VehicleType.TRUCK, frozen=True, alias='''vehicleType''')
128+
vehicle_type: Literal['Truck'] = Field(default='Truck', frozen=True, alias='''vehicleType''')
91129
length: Optional[float] = Field(default=None)
92130
additional_properties: Optional[dict[str, Any]] = Field(default=None, exclude=True)
93131

0 commit comments

Comments
 (0)