1818
1919import java .lang .annotation .Annotation ;
2020import java .lang .reflect .Constructor ;
21+ import java .util .ArrayList ;
2122import java .util .Collections ;
2223import java .util .HashSet ;
2324import java .util .LinkedHashSet ;
25+ import java .util .List ;
2426import java .util .Set ;
2527import java .util .stream .Collectors ;
2628
5355import org .springframework .core .type .AnnotationMetadata ;
5456import org .springframework .test .context .ContextCustomizer ;
5557import org .springframework .test .context .MergedContextConfiguration ;
58+ import org .springframework .test .context .TestContextAnnotationUtils ;
5659import org .springframework .util .ReflectionUtils ;
5760
5861/**
6669 */
6770class ImportsContextCustomizer implements ContextCustomizer {
6871
69- private static final String TEST_CLASS_NAME_ATTRIBUTE = "testClassName " ;
72+ private static final String TEST_CLASS_NAMES_ATTRIBUTE = "testClassNames " ;
7073
71- private final String testClassName ;
74+ private final String [] testClassNames ;
7275
7376 private final ContextCustomizerKey key ;
7477
7578 ImportsContextCustomizer (Class <?> testClass ) {
76- this .testClassName = testClass . getName ( );
79+ this .testClassNames = collectClassNames ( testClass );
7780 this .key = new ContextCustomizerKey (testClass );
7881 }
7982
83+ private static String [] collectClassNames (Class <?> source ) {
84+ List <String > classNames = new ArrayList <>();
85+ collectClassNames (source , classNames );
86+ return classNames .toArray (new String [0 ]);
87+ }
88+
89+ private static void collectClassNames (Class <?> source , List <String > classNames ) {
90+ classNames .add (source .getName ());
91+ if (TestContextAnnotationUtils .searchEnclosingClass (source )) {
92+ collectClassNames (source .getEnclosingClass (), classNames );
93+ }
94+ }
95+
8096 @ Override
8197 public void customizeContext (ConfigurableApplicationContext context ,
8298 MergedContextConfiguration mergedContextConfiguration ) {
@@ -90,13 +106,13 @@ private void registerCleanupPostProcessor(BeanDefinitionRegistry registry, Annot
90106 BeanDefinition definition = registerBean (registry , reader , ImportsCleanupPostProcessor .BEAN_NAME ,
91107 ImportsCleanupPostProcessor .class );
92108 definition .setRole (BeanDefinition .ROLE_INFRASTRUCTURE );
93- definition .getConstructorArgumentValues ().addIndexedArgumentValue (0 , this .testClassName );
109+ definition .getConstructorArgumentValues ().addIndexedArgumentValue (0 , this .testClassNames );
94110 }
95111
96112 private void registerImportsConfiguration (BeanDefinitionRegistry registry , AnnotatedBeanDefinitionReader reader ) {
97113 BeanDefinition definition = registerBean (registry , reader , ImportsConfiguration .BEAN_NAME ,
98114 ImportsConfiguration .class );
99- definition .setAttribute (TEST_CLASS_NAME_ATTRIBUTE , this .testClassName );
115+ definition .setAttribute (TEST_CLASS_NAMES_ATTRIBUTE , this .testClassNames );
100116 }
101117
102118 private BeanDefinitionRegistry getBeanDefinitionRegistry (ApplicationContext context ) {
@@ -169,8 +185,8 @@ public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
169185 @ Override
170186 public String [] selectImports (AnnotationMetadata importingClassMetadata ) {
171187 BeanDefinition definition = this .beanFactory .getBeanDefinition (ImportsConfiguration .BEAN_NAME );
172- Object testClassName = definition .getAttribute (TEST_CLASS_NAME_ATTRIBUTE );
173- return (testClassName != null ) ? new String [] { ( String ) testClassName } : NO_IMPORTS ;
188+ Object testClassNames = definition .getAttribute (TEST_CLASS_NAMES_ATTRIBUTE );
189+ return (testClassNames != null ) ? ( String []) testClassNames : NO_IMPORTS ;
174190 }
175191
176192 }
@@ -184,10 +200,10 @@ static class ImportsCleanupPostProcessor implements BeanDefinitionRegistryPostPr
184200
185201 static final String BEAN_NAME = ImportsCleanupPostProcessor .class .getName ();
186202
187- private final String testClassName ;
203+ private final String [] testClassNames ;
188204
189- ImportsCleanupPostProcessor (String testClassName ) {
190- this .testClassName = testClassName ;
205+ ImportsCleanupPostProcessor (String [] testClassNames ) {
206+ this .testClassNames = testClassNames ;
191207 }
192208
193209 @ Override
@@ -196,15 +212,15 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)
196212
197213 @ Override
198214 public void postProcessBeanDefinitionRegistry (BeanDefinitionRegistry registry ) throws BeansException {
215+ for (String testClassName : this .testClassNames ) {
216+ removeBean (testClassName , registry );
217+ }
218+ removeBean (ImportsConfiguration .BEAN_NAME , registry );
219+ }
220+
221+ private void removeBean (String beanName , BeanDefinitionRegistry registry ) {
199222 try {
200- String [] names = registry .getBeanDefinitionNames ();
201- for (String name : names ) {
202- BeanDefinition definition = registry .getBeanDefinition (name );
203- if (this .testClassName .equals (definition .getBeanClassName ())) {
204- registry .removeBeanDefinition (name );
205- }
206- }
207- registry .removeBeanDefinition (ImportsConfiguration .BEAN_NAME );
223+ registry .removeBeanDefinition (beanName );
208224 }
209225 catch (NoSuchBeanDefinitionException ex ) {
210226 // Ignore
0 commit comments