Skip to content

Commit aed6ed6

Browse files
authored
Add an import helper and import interface-injected classes when possible (#47)
1 parent a8a57ea commit aed6ed6

File tree

16 files changed

+375
-15
lines changed

16 files changed

+375
-15
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package net.neoforged.jst.api;
2+
3+
import com.intellij.psi.PsiClass;
4+
import com.intellij.psi.PsiField;
5+
import com.intellij.psi.PsiFile;
6+
import com.intellij.psi.PsiImportStatementBase;
7+
import com.intellij.psi.PsiImportStaticStatement;
8+
import com.intellij.psi.PsiJavaFile;
9+
import com.intellij.psi.PsiMethod;
10+
import com.intellij.psi.PsiModifier;
11+
import com.intellij.psi.PsiPackage;
12+
import org.jetbrains.annotations.Nullable;
13+
import org.jetbrains.annotations.VisibleForTesting;
14+
15+
import java.util.HashMap;
16+
import java.util.HashSet;
17+
import java.util.Map;
18+
import java.util.Objects;
19+
import java.util.Set;
20+
import java.util.stream.Collectors;
21+
22+
/**
23+
* Helper class used to import classes while processing a source file.
24+
* @see ImportHelper#get(PsiFile)
25+
*/
26+
public class ImportHelper implements PostProcessReplacer {
27+
private final PsiJavaFile psiFile;
28+
private final Map<String, String> importedNames = new HashMap<>();
29+
30+
private final Set<String> successfulImports = new HashSet<>();
31+
32+
public ImportHelper(PsiJavaFile psiFile) {
33+
this.psiFile = psiFile;
34+
35+
if (psiFile.getPackageStatement() != null) {
36+
var resolved = psiFile.getPackageStatement().getPackageReference().resolve();
37+
// We cannot import a class with the name of a class in the package of the file
38+
if (resolved instanceof PsiPackage pkg) {
39+
for (PsiClass cls : pkg.getClasses()) {
40+
importedNames.put(cls.getName(), cls.getQualifiedName());
41+
}
42+
}
43+
}
44+
45+
if (psiFile.getImportList() != null) {
46+
for (PsiImportStatementBase stmt : psiFile.getImportList().getImportStatements()) {
47+
var res = stmt.resolve();
48+
if (res instanceof PsiPackage pkg) {
49+
// Wildcard package imports will reserve all names of top-level classes in the package
50+
for (PsiClass cls : pkg.getClasses()) {
51+
importedNames.put(cls.getName(), cls.getQualifiedName());
52+
}
53+
} else if (res instanceof PsiClass cls) {
54+
importedNames.put(cls.getName(), cls.getQualifiedName());
55+
}
56+
}
57+
58+
for (PsiImportStaticStatement stmt : psiFile.getImportList().getImportStaticStatements()) {
59+
var res = stmt.resolve();
60+
if (res instanceof PsiMethod method) {
61+
importedNames.put(method.getName(), method.getName());
62+
} else if (res instanceof PsiField fld) {
63+
importedNames.put(fld.getName(), fld.getName());
64+
} else if (res instanceof PsiClass cls && stmt.isOnDemand()) {
65+
// On-demand imports are static wildcard imports which will reserve the names of
66+
// - all static methods available through the imported class
67+
for (PsiMethod met : cls.getAllMethods()) {
68+
if (met.getModifierList().hasModifierProperty(PsiModifier.STATIC)) {
69+
importedNames.put(met.getName(), met.getName());
70+
}
71+
}
72+
73+
// - all fields available through the imported class
74+
for (PsiField fld : cls.getAllFields()) {
75+
if (fld.getModifierList() != null && fld.getModifierList().hasModifierProperty(PsiModifier.STATIC)) {
76+
importedNames.put(fld.getName(), fld.getName());
77+
}
78+
}
79+
80+
// - all inner classes available through the imported class directly
81+
for (PsiClass c : cls.getAllInnerClasses()) {
82+
importedNames.put(c.getName(), c.getQualifiedName());
83+
}
84+
85+
// Note: to avoid possible issues, none of the above check for visibility. We prefer to be more conservative to make sure the output sources compile
86+
}
87+
}
88+
}
89+
}
90+
91+
@VisibleForTesting
92+
public boolean canImport(String name) {
93+
return !importedNames.containsKey(name);
94+
}
95+
96+
/**
97+
* Attempts to import the given fully qualified class name, returning a reference to it which is either
98+
* its short name (if an import is successful) or the qualified name if not.
99+
*/
100+
public String importClass(String cls) {
101+
var clsByDot = cls.split("\\.");
102+
// We do not try to import classes in the default package or classes already imported
103+
if (clsByDot.length == 1 || successfulImports.contains(cls)) {
104+
return clsByDot[clsByDot.length - 1];
105+
}
106+
// We also do not want to import classes under java.lang.*
107+
else if (clsByDot.length == 3 && clsByDot[0].equals("java") && clsByDot[1].equals("lang")) {
108+
return clsByDot[2];
109+
}
110+
111+
var name = clsByDot[clsByDot.length - 1];
112+
113+
if (Objects.equals(importedNames.get(name), cls)) {
114+
return name;
115+
}
116+
117+
if (canImport(name)) {
118+
successfulImports.add(cls);
119+
return name;
120+
}
121+
122+
return cls;
123+
}
124+
125+
@Override
126+
public void process(Replacements replacements) {
127+
if (successfulImports.isEmpty()) return;
128+
129+
var insertion = successfulImports.stream()
130+
.sorted()
131+
.map(s -> "import " + s + ";")
132+
.collect(Collectors.joining("\n"));
133+
134+
if (psiFile.getImportList() != null && psiFile.getImportList().getLastChild() != null) {
135+
var lastImport = psiFile.getImportList().getLastChild();
136+
replacements.insertAfter(lastImport, "\n\n" + insertion);
137+
} else {
138+
replacements.insertBefore(psiFile.getClasses()[0], insertion + "\n\n");
139+
}
140+
}
141+
142+
@Nullable
143+
public static ImportHelper get(PsiFile file) {
144+
return file instanceof PsiJavaFile j ? get(j) : null;
145+
}
146+
147+
public static ImportHelper get(PsiJavaFile file) {
148+
return PostProcessReplacer.getOrCreateReplacer(file, ImportHelper.class, k -> new ImportHelper(file));
149+
}
150+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package net.neoforged.jst.api;
2+
3+
import com.intellij.openapi.util.Key;
4+
import com.intellij.psi.PsiFile;
5+
import org.jetbrains.annotations.UnmodifiableView;
6+
7+
import java.util.Collections;
8+
import java.util.IdentityHashMap;
9+
import java.util.Map;
10+
import java.util.function.Function;
11+
12+
/**
13+
* A replacer linked to a {@link PsiFile} will run and collect replacements after all {@link SourceTransformer transformers} have processed the file.
14+
*/
15+
public interface PostProcessReplacer {
16+
Key<Map<Class<?>, PostProcessReplacer>> REPLACERS = Key.create("jst.post_process_replacers");
17+
18+
/**
19+
* Process replacements in the file after {@link SourceTransformer transformers} have processed it.
20+
*/
21+
void process(Replacements replacements);
22+
23+
@UnmodifiableView
24+
static Map<Class<?>, PostProcessReplacer> getReplacers(PsiFile file) {
25+
var rep = file.getUserData(REPLACERS);
26+
return rep == null ? Map.of() : Collections.unmodifiableMap(rep);
27+
}
28+
29+
static <T extends PostProcessReplacer> T getOrCreateReplacer(PsiFile file, Class<T> type, Function<PsiFile, T> creator) {
30+
var rep = file.getUserData(REPLACERS);
31+
if (rep == null) {
32+
rep = new IdentityHashMap<>();
33+
file.putUserData(REPLACERS, rep);
34+
}
35+
//noinspection unchecked
36+
return (T)rep.computeIfAbsent(type, k -> creator.apply(file));
37+
}
38+
}

cli/src/main/java/net/neoforged/jst/cli/SourceFileProcessor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import net.neoforged.jst.api.FileSink;
77
import net.neoforged.jst.api.FileSource;
88
import net.neoforged.jst.api.Logger;
9+
import net.neoforged.jst.api.PostProcessReplacer;
910
import net.neoforged.jst.api.Replacement;
1011
import net.neoforged.jst.api.Replacements;
1112
import net.neoforged.jst.api.SourceTransformer;
@@ -153,6 +154,10 @@ private byte[] transformSource(VirtualFile contentRoot, FileEntry entry, List<So
153154
transformer.visitFile(psiFile, replacements);
154155
}
155156

157+
for (PostProcessReplacer rep : PostProcessReplacer.getReplacers(psiFile).values()) {
158+
rep.process(replacements);
159+
}
160+
156161
var readOnlyReplacements = Collections.unmodifiableList(replacementsList);
157162
boolean success = true;
158163
for (var transformer : transformers) {
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package net.neoforged.jst.cli;
2+
3+
import com.intellij.openapi.vfs.VirtualFile;
4+
import com.intellij.psi.PsiClass;
5+
import com.intellij.psi.PsiElement;
6+
import com.intellij.psi.PsiJavaFile;
7+
import com.intellij.psi.PsiMethod;
8+
import com.intellij.psi.util.PsiTreeUtil;
9+
import net.neoforged.jst.api.ImportHelper;
10+
import net.neoforged.jst.api.Logger;
11+
import net.neoforged.jst.api.Replacements;
12+
import net.neoforged.jst.cli.intellij.IntelliJEnvironmentImpl;
13+
import org.intellij.lang.annotations.Language;
14+
import org.junit.jupiter.api.AfterAll;
15+
import org.junit.jupiter.api.BeforeAll;
16+
import org.junit.jupiter.api.Test;
17+
18+
import java.io.IOException;
19+
20+
import static org.junit.jupiter.api.Assertions.assertEquals;
21+
import static org.assertj.core.api.Assertions.*;
22+
import static org.junit.jupiter.api.Assertions.assertFalse;
23+
import static org.junit.jupiter.api.Assertions.assertTrue;
24+
25+
public class ImportHelperTest {
26+
static IntelliJEnvironmentImpl ijEnv;
27+
28+
@BeforeAll
29+
static void setUp() throws IOException {
30+
ijEnv = new IntelliJEnvironmentImpl(new Logger(null, null));
31+
ijEnv.addCurrentJdkToClassPath();
32+
}
33+
34+
@AfterAll
35+
static void tearDown() throws IOException {
36+
ijEnv.close();
37+
}
38+
39+
@Test
40+
public void testSimpleImports() {
41+
var helper = getImportHelper("""
42+
import java.util.Collection;
43+
import java.lang.annotation.Retention;
44+
import java.util.concurrent.atomic.AtomicReference;""");
45+
46+
assertFalse(helper.canImport("Collection"), "Collection can wrongly be imported");
47+
assertFalse(helper.canImport("Retention"), "Retention can wrongly be imported");
48+
assertFalse(helper.canImport("AtomicReference"), "AtomicReference can wrongly be imported");
49+
50+
assertTrue(helper.canImport("MyRandomClass"), "Cannot import a non-reserved name");
51+
}
52+
53+
@Test
54+
public void testWildcardImports() {
55+
var helper = getImportHelper("""
56+
import java.util.concurrent.*;""");
57+
58+
assertFalse(helper.canImport("Future"), "Future can wrongly be imported");
59+
assertFalse(helper.canImport("Executor"), "Executor can wrongly be imported");
60+
61+
assertTrue(helper.canImport("ThisWillNotExist"), "Cannot import a non-reserved name");
62+
}
63+
64+
@Test
65+
public void testStaticImports() {
66+
var helper = getImportHelper("""
67+
import static java.util.Spliterators.emptyDoubleSpliterator;
68+
import static java.util.Collections.*;""");
69+
70+
assertFalse(helper.canImport("emptyDoubleSpliterator"), "emptyDoubleSpliterator can wrongly be imported");
71+
72+
assertFalse(helper.canImport("min"), "min can wrongly be imported");
73+
assertFalse(helper.canImport("checkedSortedMap"), "checkedSortedMap can wrongly be imported");
74+
assertFalse(helper.canImport("EMPTY_LIST"), "EMPTY_LIST can wrongly be imported");
75+
76+
assertTrue(helper.canImport("ThisWillNotExist"), "Cannot import a non-reserved name");
77+
}
78+
79+
@Test
80+
void testReplace() {
81+
var file = parseSingleFile("""
82+
package java.lang.annotation;
83+
84+
import java.util.*;
85+
86+
class MyClass {
87+
}""");
88+
89+
var helper = ImportHelper.get(file);
90+
91+
assertEquals("HelloWorld", helper.importClass("com.hello.world.HelloWorld"));
92+
93+
assertEquals("Annotation", helper.importClass("java.lang.annotation.Annotation"));
94+
assertEquals("com.hello.world.Annotation", helper.importClass("com.hello.world.Annotation"));
95+
96+
assertEquals("List", helper.importClass("java.util.List"));
97+
assertEquals("com.hello.world.List", helper.importClass("com.hello.world.List"));
98+
99+
assertEquals("Thing", helper.importClass("a.b.c.Thing"));
100+
101+
var rep = new Replacements();
102+
helper.process(rep);
103+
104+
assertThat(rep.apply(file.getText()))
105+
.isEqualToNormalizingNewlines("""
106+
package java.lang.annotation;
107+
108+
import java.util.*;
109+
110+
import a.b.c.Thing;
111+
import com.hello.world.HelloWorld;
112+
113+
class MyClass {
114+
}""");
115+
}
116+
117+
private ImportHelper getImportHelper(@Language("JAVA") String javaCode) {
118+
var file = parseSingleFile(javaCode);
119+
return new ImportHelper(file);
120+
}
121+
122+
private PsiJavaFile parseSingleFile(@Language("JAVA") String javaCode) {
123+
return parseSingleElement(javaCode, PsiJavaFile.class);
124+
}
125+
126+
private <T extends PsiElement> T parseSingleElement(@Language("JAVA") String javaCode, Class<T> type) {
127+
var file = ijEnv.parseFileFromMemory("Test.java", javaCode);
128+
129+
var elements = PsiTreeUtil.collectElementsOfType(file, type);
130+
assertEquals(1, elements.size());
131+
return elements.iterator().next();
132+
}
133+
}

interfaceinjection/src/main/java/net/neoforged/jst/interfaceinjection/InjectInterfacesVisitor.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import com.intellij.psi.PsiWhiteSpace;
99
import com.intellij.psi.util.ClassUtil;
1010
import com.intellij.util.containers.MultiMap;
11+
import net.neoforged.jst.api.ImportHelper;
1112
import net.neoforged.jst.api.Replacements;
1213
import org.jetbrains.annotations.NotNull;
1314
import org.jetbrains.annotations.Nullable;
@@ -60,6 +61,8 @@ private void inject(PsiClass psiClass, Collection<String> targets) {
6061
return;
6162
}
6263

64+
var imports = ImportHelper.get(psiClass.getContainingFile());
65+
6366
var implementsList = psiClass.isInterface() ? psiClass.getExtendsList() : psiClass.getImplementsList();
6467
var implementedInterfaces = Arrays.stream(implementsList.getReferencedTypes())
6568
.map(PsiClassType::resolve)
@@ -71,8 +74,8 @@ private void inject(PsiClass psiClass, Collection<String> targets) {
7174
.distinct()
7275
.map(stubs::createStub)
7376
.filter(iface -> !implementedInterfaces.contains(iface.interfaceDeclaration()))
74-
.map(StubStore.InterfaceInformation::toString)
75-
.map(this::decorate)
77+
.map(iface -> possiblyImport(imports, iface))
78+
.map(iface -> decorate(imports, iface))
7679
.sorted(Comparator.naturalOrder())
7780
.collect(Collectors.joining(", "));
7881

@@ -94,10 +97,15 @@ private void inject(PsiClass psiClass, Collection<String> targets) {
9497
}
9598
}
9699

97-
private String decorate(String iface) {
100+
private String possiblyImport(@Nullable ImportHelper helper, StubStore.InterfaceInformation info) {
101+
var interfaceName = helper == null ? info.interfaceDeclaration() : helper.importClass(info.interfaceDeclaration());
102+
return info.generics().isBlank() ? interfaceName : (interfaceName + "<" + info.generics() + ">");
103+
}
104+
105+
private String decorate(@Nullable ImportHelper helper, String iface) {
98106
if (marker == null) {
99107
return iface;
100108
}
101-
return "@" + marker + " " + iface;
109+
return "@" + (helper == null ? marker : helper.importClass(marker)) + " " + iface;
102110
}
103111
}
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
package net;
22

3-
public class Example implements Runnable, com.example.InjectedInterface {
3+
import com.example.InjectedInterface;
4+
5+
public class Example implements Runnable, InjectedInterface {
46
}

0 commit comments

Comments
 (0)