Skip to content

Commit d0f82cc

Browse files
committed
Add gaps identified by the reviewer and add tests to prevent regression
1 parent b12d7f1 commit d0f82cc

5 files changed

Lines changed: 328 additions & 11 deletions

File tree

codegen/core/src/main/java/software/amazon/smithy/python/codegen/writer/ImportDeclarations.java

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,34 @@ private ImportDeclarations addImportToMap(
108108
* Retroactively aliases an existing import to avoid name collisions.
109109
*
110110
* <p>This is called by {@link PythonWriter} at {@code toString()} time when a collision
111-
* is detected between an imported name and a locally-defined name. The import statement
112-
* is rewritten from {@code from namespace import name} to {@code from namespace import name as alias}.
111+
* is detected. The import statement is rewritten from {@code from namespace import name}
112+
* to {@code from namespace import name as alias}.
113113
*
114-
* @param namespace The module namespace of the import.
114+
* <p>The lookup mirrors the dispatch logic in {@link #addImport} so that aliasing works
115+
* for all three import maps:
116+
* <ul>
117+
* <li>{@code stdlibImports} — keyed by absolute namespace.</li>
118+
* <li>{@code externalImports} — keyed by absolute namespace.</li>
119+
* <li>{@code localImports} — keyed by relativized namespace for same-package imports.</li>
120+
* </ul>
121+
*
122+
* @param namespace The module namespace of the import (absolute).
115123
* @param name The original imported name.
116124
* @param alias The alias to use instead.
117125
*/
118126
void aliasImport(String namespace, String name, String alias) {
127+
// stdlib imports are stored with absolute namespaces.
128+
aliasImportInMap(namespace, name, alias, stdlibImports);
129+
130+
// external imports are stored with absolute namespaces.
119131
aliasImportInMap(namespace, name, alias, externalImports);
120-
aliasImportInMap(namespace, name, alias, localImports);
132+
133+
// local imports may be stored with relativized namespaces (mirroring addImport).
134+
if (namespace.startsWith(settings.moduleName())) {
135+
var isTestModule = this.localNamespace.startsWith("tests");
136+
var ns = isTestModule ? namespace : relativize(namespace);
137+
aliasImportInMap(ns, name, alias, localImports);
138+
}
121139
}
122140

123141
private void aliasImportInMap(

codegen/core/src/main/java/software/amazon/smithy/python/codegen/writer/PythonWriter.java

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -479,15 +479,29 @@ public String apply(Object type, String indent) {
479479
/**
480480
* Returns a placeholder token for the symbol and registers it in the symbol table.
481481
*
482-
* <p>Only symbols that are imported from another module (have a namespace but no
483-
* definition file) get a placeholder. Symbols without a namespace (builtins like
484-
* int, str, bool) and symbols with a definition file (locally-defined types) are
485-
* returned as-is since they don't need collision detection.
482+
* <p>Symbols that do not participate in collision detection are returned as-is:
483+
* <ul>
484+
* <li>Builtins (no namespace) such as {@code int}, {@code str}, {@code bool}:
485+
* they produce no import statement, so there is no import we could alias
486+
* to resolve a collision. A sanity check across all public AWS service
487+
* models confirms no class-generating shape name collides with a Python
488+
* built-in emitted by the codegen, so builtins are not special-cased here.</li>
489+
* <li>Symbols defined in the current writer's file (namespace equals the
490+
* writer's package name): they produce no import and cannot collide
491+
* with themselves.</li>
492+
* </ul>
493+
*
494+
* <p>All other symbols (framework types and generated symbols imported from
495+
* other files in the same package) are registered in the symbol table so
496+
* collisions can be detected at {@link PythonWriter#toString()} time.
486497
*/
487498
private String resolvePlaceholder(Symbol symbol) {
488-
if (symbol.getNamespace().isEmpty() || !symbol.getDefinitionFile().isEmpty()) {
489-
// No namespace means builtin. Has definition file means locally-defined.
490-
// Neither case needs collision detection.
499+
if (symbol.getNamespace().isEmpty()) {
500+
// Builtin — no import statement is produced, so there is no alias to apply.
501+
return symbol.getName();
502+
}
503+
if (symbol.getNamespace().equals(fullPackageName)) {
504+
// Defined in the current writer's file — no import needed.
491505
return symbol.getName();
492506
}
493507

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package software.amazon.smithy.python.codegen;
6+
7+
import static org.junit.jupiter.api.Assertions.assertEquals;
8+
9+
import org.junit.jupiter.api.Test;
10+
import software.amazon.smithy.model.Model;
11+
import software.amazon.smithy.model.shapes.MemberShape;
12+
import software.amazon.smithy.model.shapes.ShapeId;
13+
import software.amazon.smithy.model.shapes.UnionShape;
14+
15+
public class PythonSymbolProviderTest {
16+
17+
private static final String NS = "smithy.example";
18+
19+
@Test
20+
public void testUnionMemberVariantNameCollidingWithShapeUsesUnderscoreSeparator() {
21+
Model model = loadModel("""
22+
$version: "2"
23+
namespace smithy.example
24+
25+
service TestService {
26+
version: "2024-01-01"
27+
operations: [TestOp]
28+
}
29+
30+
operation TestOp {
31+
input: TestOpInput
32+
}
33+
34+
structure TestOpInput {
35+
principal: Principal
36+
}
37+
38+
union Principal {
39+
user: PrincipalUser
40+
}
41+
42+
structure PrincipalUser {
43+
name: String
44+
}
45+
""");
46+
PythonSymbolProvider provider = createProvider(model);
47+
var userMember = model.expectShape(ShapeId.from(NS + "#Principal$user"), MemberShape.class);
48+
49+
assertEquals("Principal_User", provider.toSymbol(userMember).getName());
50+
}
51+
52+
@Test
53+
public void testUnionUnknownVariantNameCollidingWithShapeUsesUnderscoreSeparator() {
54+
Model model = loadModel("""
55+
$version: "2"
56+
namespace smithy.example
57+
58+
service TestService {
59+
version: "2024-01-01"
60+
operations: [TestOp]
61+
}
62+
63+
operation TestOp {
64+
input: TestOpInput
65+
}
66+
67+
structure TestOpInput {
68+
value: MyUnion
69+
other: MyUnionUnknown
70+
}
71+
72+
union MyUnion {
73+
foo: String
74+
}
75+
76+
structure MyUnionUnknown {
77+
message: String
78+
}
79+
""");
80+
PythonSymbolProvider provider = createProvider(model);
81+
var union = model.expectShape(ShapeId.from(NS + "#MyUnion"), UnionShape.class);
82+
83+
assertEquals("MyUnion_Unknown",
84+
provider.toSymbol(union).expectProperty(SymbolProperties.UNION_UNKNOWN).getName());
85+
}
86+
87+
private static Model loadModel(String smithyIdl) {
88+
return Model.assembler().addUnparsedModel("test.smithy", smithyIdl).assemble().unwrap();
89+
}
90+
91+
private static PythonSymbolProvider createProvider(Model model) {
92+
PythonSettings settings = PythonSettings.builder()
93+
.service(ShapeId.from(NS + "#TestService"))
94+
.moduleName("test_client")
95+
.moduleVersion("0.0.1")
96+
.build();
97+
return new PythonSymbolProvider(model, settings);
98+
}
99+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package software.amazon.smithy.python.codegen.writer;
6+
7+
import static org.junit.jupiter.api.Assertions.assertFalse;
8+
import static org.junit.jupiter.api.Assertions.assertTrue;
9+
import static org.mockito.Mockito.mock;
10+
import static org.mockito.Mockito.when;
11+
12+
import org.junit.jupiter.api.Test;
13+
import software.amazon.smithy.python.codegen.PythonSettings;
14+
15+
public class ImportDeclarationsTest {
16+
17+
private static final String MODULE = "aws_sdk_example";
18+
private static final String LOCAL_NAMESPACE = MODULE + ".config";
19+
20+
@Test
21+
public void testAliasImportRewritesStdlibImport() {
22+
ImportDeclarations imports = createImports();
23+
imports.addStdlibImport("decimal", "Decimal");
24+
25+
imports.aliasImport("decimal", "Decimal", "_Decimal");
26+
27+
String out = imports.toString();
28+
assertTrue(normalize(out).contains("from decimal import Decimal as _Decimal"));
29+
}
30+
31+
@Test
32+
public void testAliasImportRewritesExternalImport() {
33+
ImportDeclarations imports = createImports();
34+
imports.addImport("smithy_core.documents", "Document", "Document");
35+
36+
imports.aliasImport("smithy_core.documents", "Document", "_Document");
37+
38+
String out = imports.toString();
39+
assertTrue(normalize(out).contains("from smithy_core.documents import Document as _Document"));
40+
}
41+
42+
@Test
43+
public void testAliasImportRewritesLocalImport() {
44+
ImportDeclarations imports = createImports();
45+
imports.addImport(MODULE + ".models", "Foo", "Foo");
46+
47+
imports.aliasImport(MODULE + ".models", "Foo", "_Foo");
48+
49+
String out = imports.toString();
50+
assertTrue(normalize(out).contains("from .models import Foo as _Foo"));
51+
}
52+
53+
@Test
54+
public void testAliasImportDoesNotOverwritePreExistingAlias() {
55+
ImportDeclarations imports = createImports();
56+
imports.addImport("smithy_core.documents", "Document", "OriginalAlias");
57+
58+
imports.aliasImport("smithy_core.documents", "Document", "_ShouldNotApply");
59+
60+
String out = imports.toString();
61+
String normalized = normalize(out);
62+
assertTrue(normalized.contains("from smithy_core.documents import Document as OriginalAlias"));
63+
assertFalse(normalized.contains("_ShouldNotApply"));
64+
}
65+
66+
private static ImportDeclarations createImports() {
67+
PythonSettings settings = mock(PythonSettings.class);
68+
when(settings.moduleName()).thenReturn(MODULE);
69+
return new ImportDeclarations(settings, LOCAL_NAMESPACE);
70+
}
71+
72+
private static String normalize(String output) {
73+
return output.replaceAll("\\(\\s+", "")
74+
.replaceAll(",\\s*\\)", "")
75+
.replaceAll("\\s+", " ");
76+
}
77+
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package software.amazon.smithy.python.codegen.writer;
6+
7+
import static org.junit.jupiter.api.Assertions.assertFalse;
8+
import static org.junit.jupiter.api.Assertions.assertTrue;
9+
import static org.mockito.Mockito.mock;
10+
import static org.mockito.Mockito.when;
11+
12+
import org.junit.jupiter.api.Test;
13+
import software.amazon.smithy.codegen.core.Symbol;
14+
import software.amazon.smithy.python.codegen.PythonSettings;
15+
16+
public class PythonWriterTest {
17+
18+
private static final String CURRENT_PACKAGE = "aws_sdk_example.models";
19+
20+
@Test
21+
public void testFrameworkSymbolCollidingWithLocalIsAliased() {
22+
PythonWriter writer = createWriter(CURRENT_PACKAGE);
23+
Symbol framework = frameworkSymbol("smithy_core.documents", "Document");
24+
Symbol local = generatedSymbol(CURRENT_PACKAGE, "Document");
25+
writer.addLocallyDefinedSymbol(local);
26+
27+
writer.write("value: $T", framework);
28+
String out = writer.toString();
29+
30+
assertTrue(normalize(out).contains("from smithy_core.documents import Document as _Document"));
31+
assertTrue(out.contains("value: _Document"));
32+
}
33+
34+
@Test
35+
public void testTwoFrameworkSymbolsWithSameSimpleNameGetModuleAliases() {
36+
PythonWriter writer = createWriter("aws_sdk_example.auth");
37+
Symbol a = frameworkSymbol("smithy_core.auth", "AuthOption");
38+
Symbol b = frameworkSymbol("smithy_core.interfaces.auth", "AuthOption");
39+
40+
writer.write("x: $T", a);
41+
writer.write("y: $T", b);
42+
String out = writer.toString();
43+
String normalized = normalize(out);
44+
45+
assertTrue(normalized.contains(
46+
"from smithy_core.auth import AuthOption as _smithy_core_auth_AuthOption"));
47+
assertTrue(normalized.contains(
48+
"from smithy_core.interfaces.auth import "
49+
+ "AuthOption as _smithy_core_interfaces_auth_AuthOption"));
50+
assertTrue(out.contains("x: _smithy_core_auth_AuthOption"));
51+
assertTrue(out.contains("y: _smithy_core_interfaces_auth_AuthOption"));
52+
}
53+
54+
@Test
55+
public void testCrossFileGeneratedSymbolCollidingWithFrameworkImportIsAliased() {
56+
PythonWriter writer = createWriter("aws_sdk_example.config");
57+
Symbol framework = frameworkSymbol("smithy_core.aio.interfaces", "EndpointResolver");
58+
Symbol crossFileGenerated = generatedSymbol("aws_sdk_example.models", "EndpointResolver");
59+
60+
writer.write("x: $T", framework);
61+
writer.write("y: $T", crossFileGenerated);
62+
String out = writer.toString();
63+
String normalized = normalize(out);
64+
65+
assertTrue(normalized.contains(
66+
"from smithy_core.aio.interfaces import "
67+
+ "EndpointResolver as _smithy_core_aio_interfaces_EndpointResolver"));
68+
assertTrue(normalized.contains(
69+
"from .models import EndpointResolver as _aws_sdk_example_models_EndpointResolver"));
70+
}
71+
72+
@Test
73+
public void testSymbolDefinedInCurrentWriterFileIsNotRegistered() {
74+
PythonWriter writer = createWriter(CURRENT_PACKAGE);
75+
Symbol selfReference = generatedSymbol(CURRENT_PACKAGE, "MyStruct");
76+
writer.addLocallyDefinedSymbol(selfReference);
77+
78+
writer.write("value: $T", selfReference);
79+
String out = writer.toString();
80+
81+
assertFalse(out.contains("import MyStruct"));
82+
assertFalse(out.contains("_MyStruct"));
83+
assertTrue(out.contains("value: MyStruct"));
84+
}
85+
86+
private static PythonWriter createWriter(String fullPackageName) {
87+
PythonSettings settings = mock(PythonSettings.class);
88+
when(settings.moduleName()).thenReturn(fullPackageName.split("\\.")[0]);
89+
return new PythonWriter(settings, fullPackageName);
90+
}
91+
92+
private static Symbol frameworkSymbol(String namespace, String name) {
93+
return Symbol.builder().name(name).namespace(namespace, ".").build();
94+
}
95+
96+
private static Symbol generatedSymbol(String namespace, String name) {
97+
return Symbol.builder()
98+
.name(name)
99+
.namespace(namespace, ".")
100+
.definitionFile("./src/" + namespace.replace('.', '/') + ".py")
101+
.build();
102+
}
103+
104+
private static String normalize(String output) {
105+
return output.replaceAll("\\(\\s+", "")
106+
.replaceAll(",\\s*\\)", "")
107+
.replaceAll("\\s+", " ");
108+
}
109+
}

0 commit comments

Comments
 (0)