diff --git a/pom.xml b/pom.xml
index 7b224e8..3d2fe15 100644
--- a/pom.xml
+++ b/pom.xml
@@ -110,6 +110,13 @@
${tng.archunit.version}
+
+
+ com.tngtech.archunit
+ archunit-junit5-api
+ ${tng.archunit.version}
+
+
com.google.guava
guava
diff --git a/src/main/java/com/societegenerale/commons/plugin/service/InvokableRules.java b/src/main/java/com/societegenerale/commons/plugin/service/InvokableRules.java
index 349b9cd..3569459 100644
--- a/src/main/java/com/societegenerale/commons/plugin/service/InvokableRules.java
+++ b/src/main/java/com/societegenerale/commons/plugin/service/InvokableRules.java
@@ -3,8 +3,11 @@
import java.lang.reflect.Field;
import java.lang.reflect.Member;
import java.lang.reflect.Method;
+import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.Deque;
+import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
@@ -13,20 +16,20 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.societegenerale.commons.plugin.Log;
+import com.societegenerale.commons.plugin.utils.ReflectionUtils;
import com.tngtech.archunit.core.domain.JavaClasses;
+import com.tngtech.archunit.junit.ArchTests;
import com.tngtech.archunit.lang.ArchRule;
import static com.societegenerale.commons.plugin.utils.ReflectionUtils.getValue;
import static com.societegenerale.commons.plugin.utils.ReflectionUtils.invoke;
import static com.societegenerale.commons.plugin.utils.ReflectionUtils.loadClassWithContextClassLoader;
-import static com.societegenerale.commons.plugin.utils.ReflectionUtils.newInstance;
import static java.lang.System.lineSeparator;
import static java.util.Arrays.stream;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toSet;
class InvokableRules {
- private final Class> rulesLocation;
private final Set archRuleFields;
private final Set archRuleMethods;
@@ -36,11 +39,17 @@ private InvokableRules(String rulesClassName, List ruleChecks, Log log)
this.log=log;
- rulesLocation = loadClassWithContextClassLoader(rulesClassName);
+ Class> definedRulesClass = loadClassWithContextClassLoader(rulesClassName);
- Set allFieldsWhichAreArchRules = getAllFieldsWhichAreArchRules(rulesLocation.getDeclaredFields());
- Set allMethodsWhichAreArchRules = getAllMethodsWhichAreArchRules(rulesLocation.getDeclaredMethods());
- validateRuleChecks(Sets.union(allMethodsWhichAreArchRules, allFieldsWhichAreArchRules), ruleChecks);
+ Set> rulesClasses = getAllClassesWhichAreArchTests(definedRulesClass);
+ rulesClasses.add(definedRulesClass);
+ Set allFieldsWhichAreArchRules = new HashSet<>();
+ Set allMethodsWhichAreArchRules = new HashSet<>();
+ for (Class> rulesClass : rulesClasses) {
+ allFieldsWhichAreArchRules.addAll(getAllFieldsWhichAreArchRules(rulesClass.getDeclaredFields()));
+ allMethodsWhichAreArchRules.addAll(getAllMethodsWhichAreArchRules(rulesClass.getDeclaredMethods()));
+ }
+ validateRuleChecks(definedRulesClass, Sets.union(allMethodsWhichAreArchRules, allFieldsWhichAreArchRules), ruleChecks);
Predicate isChosenCheck = ruleChecks.isEmpty() ? check -> true : ruleChecks::contains;
@@ -64,7 +73,7 @@ private void logBuiltInvokableRules(String rulesClassName) {
}
- private void validateRuleChecks(Set extends Member> allFieldsAndMethods, Collection ruleChecks) {
+ private void validateRuleChecks(Class> rulesLocation, Set extends Member> allFieldsAndMethods, Collection ruleChecks) {
Set allFieldAndMethodNames = allFieldsAndMethods.stream().map(Member::getName).collect(toSet());
Set illegalChecks = Sets.difference(ImmutableSet.copyOf(ruleChecks), allFieldAndMethodNames);
@@ -91,9 +100,26 @@ private Set getAllFieldsWhichAreArchRules(Field[] fields) {
.collect(toSet());
}
- InvocationResult invokeOn(JavaClasses importedClasses) {
+ private Set> getAllClassesWhichAreArchTests(Class> startClass) {
+ Set> allClassesWhichAreArchTests = new HashSet<>();
+ Deque> stack = new ArrayDeque<>();
+ stack.push(startClass);
+ while (!stack.isEmpty()) {
+ Class> currentClass = stack.pop();
+ stream(currentClass.getDeclaredFields())
+ .filter(f -> ArchTests.class.isAssignableFrom(f.getType()))
+ .map(f -> getValue(f, null))
+ .map(ArchTests.class::cast)
+ .map(ArchTests::getDefinitionLocation)
+ .forEach(childClass -> {
+ allClassesWhichAreArchTests.add(childClass);
+ stack.push(childClass);
+ });
+ }
+ return allClassesWhichAreArchTests;
+ }
- Object instance = newInstance(rulesLocation);
+ InvocationResult invokeOn(JavaClasses importedClasses) {
if(log.isInfoEnabled()) {
log.info("applying rules on "+importedClasses.size()+" classe(s). To see the details, enable debug logs");
@@ -105,11 +131,11 @@ InvocationResult invokeOn(JavaClasses importedClasses) {
InvocationResult result = new InvocationResult();
for (Method method : archRuleMethods) {
- checkForFailure(() -> invoke(method, instance, importedClasses))
+ checkForFailure(() -> invoke(method, null, importedClasses))
.ifPresent(result::add);
}
for (Field field : archRuleFields) {
- ArchRule rule = getValue(field, instance);
+ ArchRule rule = getValue(field, null);
checkForFailure(() -> rule.check(importedClasses))
.ifPresent(result::add);
}
diff --git a/src/test/java/com/societegenerale/commons/plugin/rules/classesForTests/DoubleIncludedCustomRule.java b/src/test/java/com/societegenerale/commons/plugin/rules/classesForTests/DoubleIncludedCustomRule.java
new file mode 100644
index 0000000..09c5f85
--- /dev/null
+++ b/src/test/java/com/societegenerale/commons/plugin/rules/classesForTests/DoubleIncludedCustomRule.java
@@ -0,0 +1,10 @@
+package com.societegenerale.commons.plugin.rules.classesForTests;
+
+import com.tngtech.archunit.junit.ArchTest;
+import com.tngtech.archunit.junit.ArchTests;
+
+public class DoubleIncludedCustomRule {
+
+ @ArchTest
+ static final ArchTests DOUBLE_INCLUDED = ArchTests.in(IncludedCustomRule.class);
+}
diff --git a/src/test/java/com/societegenerale/commons/plugin/rules/classesForTests/IncludedCustomRule.java b/src/test/java/com/societegenerale/commons/plugin/rules/classesForTests/IncludedCustomRule.java
new file mode 100644
index 0000000..5f5ba54
--- /dev/null
+++ b/src/test/java/com/societegenerale/commons/plugin/rules/classesForTests/IncludedCustomRule.java
@@ -0,0 +1,10 @@
+package com.societegenerale.commons.plugin.rules.classesForTests;
+
+import com.tngtech.archunit.junit.ArchTest;
+import com.tngtech.archunit.junit.ArchTests;
+
+public class IncludedCustomRule {
+
+ @ArchTest
+ static final ArchTests INCLUDED = ArchTests.in(DummyCustomRule.class);
+}
diff --git a/src/test/java/com/societegenerale/commons/plugin/service/InvokableRulesTest.java b/src/test/java/com/societegenerale/commons/plugin/service/InvokableRulesTest.java
new file mode 100644
index 0000000..120f9ee
--- /dev/null
+++ b/src/test/java/com/societegenerale/commons/plugin/service/InvokableRulesTest.java
@@ -0,0 +1,75 @@
+package com.societegenerale.commons.plugin.service;
+
+import com.societegenerale.commons.plugin.Log;
+import com.societegenerale.commons.plugin.model.RootClassFolder;
+import com.societegenerale.commons.plugin.rules.classesForTests.DoubleIncludedCustomRule;
+import com.societegenerale.commons.plugin.rules.classesForTests.DummyCustomRule;
+import com.societegenerale.commons.plugin.rules.classesForTests.IncludedCustomRule;
+import com.societegenerale.commons.plugin.utils.ArchUtils;
+import com.tngtech.archunit.core.domain.JavaClasses;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+
+class InvokableRulesTest {
+
+ @BeforeAll
+ @SuppressWarnings("InstantiationOfUtilityClass")
+ static void instantiateArchUtils() {
+ new ArchUtils(mock(Log.class));
+ }
+
+ @Test
+ void shouldInvokeAllRulesDefinedAsFields() {
+ assertThat(invokeAndGetMessage(DummyCustomRule.class))
+ .contains("Rule 'classes should be annotated with @Test' was violated")
+ .contains("Rule 'classes should reside in a package 'myPackage'' was violated");
+ }
+
+ @Test
+ void shouldInvokeSpecificRuleDefinedAsField() {
+ assertThat(invokeAndGetMessage(DummyCustomRule.class, "annotatedWithTest"))
+ .contains("Rule 'classes should be annotated with @Test' was violated")
+ .doesNotContain("Rule 'classes should reside in a package 'myPackage'' was violated");
+ }
+
+ @Test
+ void shouldInvokeAllRulesIncludedViaField() {
+ assertThat(invokeAndGetMessage(IncludedCustomRule.class))
+ .contains("Rule 'classes should be annotated with @Test' was violated")
+ .contains("Rule 'classes should reside in a package 'myPackage'' was violated");
+ }
+
+ @Test
+ void shouldInvokeSpecificRuleIncludedViaField() {
+ assertThat(invokeAndGetMessage(IncludedCustomRule.class, "annotatedWithTest"))
+ .contains("Rule 'classes should be annotated with @Test' was violated")
+ .doesNotContain("Rule 'classes should reside in a package 'myPackage'' was violated");
+ }
+
+ @Test
+ void shouldInvokeAllRulesIncludedViaFieldThatItselfIncludes() {
+ assertThat(invokeAndGetMessage(DoubleIncludedCustomRule.class))
+ .contains("Rule 'classes should be annotated with @Test' was violated")
+ .contains("Rule 'classes should reside in a package 'myPackage'' was violated");
+ }
+
+ @Test
+ void shouldInvokeSpecificRuleIncludedViaFieldThatItselfIncludes() {
+ assertThat(invokeAndGetMessage(DoubleIncludedCustomRule.class,"annotatedWithTest"))
+ .contains("Rule 'classes should be annotated with @Test' was violated")
+ .doesNotContain("Rule 'classes should reside in a package 'myPackage'' was violated");
+ }
+
+
+ private static String invokeAndGetMessage(Class> rulesClass, String... checks) {
+ JavaClasses javaClasses = ArchUtils.importAllClassesInPackage(new RootClassFolder(""), "");
+ InvokableRules invokableRules = InvokableRules.of(rulesClass.getName(), Arrays.asList(checks), mock(Log.class));
+ InvokableRules.InvocationResult invocationResult = invokableRules.invokeOn(javaClasses);
+ return invocationResult.getMessage();
+ }
+}