Skip to content

Commit e205b06

Browse files
committed
Add @MethodConversion converter
1 parent eb8fa49 commit e205b06

4 files changed

Lines changed: 421 additions & 0 deletions

File tree

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright © 2025-present Stefano Cordio
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package io.github.scordio.junit.converters;
17+
18+
import org.jspecify.annotations.Nullable;
19+
import org.junit.jupiter.api.extension.ParameterContext;
20+
import org.junit.jupiter.params.converter.ArgumentConversionException;
21+
import org.junit.jupiter.params.converter.ArgumentConverter;
22+
import org.junit.jupiter.params.support.AnnotationConsumer;
23+
import org.junit.jupiter.params.support.FieldContext;
24+
import org.junit.platform.commons.support.HierarchyTraversalMode;
25+
import org.junit.platform.commons.support.ModifierSupport;
26+
import org.junit.platform.commons.support.ReflectionSupport;
27+
28+
import java.lang.reflect.Method;
29+
import java.util.List;
30+
import java.util.Objects;
31+
import java.util.function.Predicate;
32+
import java.util.stream.Collectors;
33+
34+
class MethodArgumentConverter implements ArgumentConverter, AnnotationConsumer<MethodConversion> {
35+
36+
private static final Predicate<Method> IS_STATIC = ModifierSupport::isStatic;
37+
38+
private @Nullable MethodConversion annotation;
39+
40+
@Override
41+
public void accept(MethodConversion annotation) {
42+
this.annotation = annotation;
43+
}
44+
45+
@Override
46+
public @Nullable Object convert(@Nullable Object source, ParameterContext context) {
47+
return convert(source, context.getParameter().getType(), context.getDeclaringExecutable().getDeclaringClass());
48+
}
49+
50+
@Override
51+
public @Nullable Object convert(@Nullable Object source, FieldContext context) {
52+
return convert(source, context.getField().getType(), context.getField().getDeclaringClass());
53+
}
54+
55+
private @Nullable Object convert(@Nullable Object source, Class<?> targetType, Class<?> declaringClass) {
56+
Objects.requireNonNull(source, "'null' is not supported");
57+
Objects.requireNonNull(annotation, "'annotation' must not be null");
58+
Method conversionMethod = annotation.value().isEmpty()
59+
? findConversionMethod(source.getClass(), targetType, declaringClass)
60+
: findConversionMethod(annotation.value(), source.getClass(), targetType, declaringClass);
61+
return ReflectionSupport.invokeMethod(conversionMethod, null, source);
62+
}
63+
64+
private static Method findConversionMethod(Class<?> sourceType, Class<?> targetType, Class<?> declaringClass) {
65+
Predicate<Method> filter = IS_STATIC.and(accepts(sourceType)).and(produces(targetType));
66+
List<Method> methods = ReflectionSupport.findMethods(declaringClass, filter, HierarchyTraversalMode.BOTTOM_UP);
67+
68+
if (methods.isEmpty()) {
69+
throw new ArgumentConversionException(
70+
String.format("No conversion method found compatible with source type %s and target type %s",
71+
sourceType.getName(), targetType.getName()));
72+
}
73+
74+
if (methods.size() > 1) {
75+
List<String> signatures = methods.stream()
76+
.map(method -> String.format("%s %s(%s)", method.getReturnType().getName(), method.getName(),
77+
method.getParameterTypes()[0].getName()))
78+
.collect(Collectors.toList());
79+
80+
throw new ArgumentConversionException(
81+
String.format("Too many conversion methods compatible with source type %s and target type %s: %s",
82+
sourceType.getName(), targetType.getName(), signatures));
83+
}
84+
85+
return methods.get(0);
86+
}
87+
88+
private static Method findConversionMethod(String name, Class<?> sourceType, Class<?> targetType,
89+
Class<?> declaringClass) {
90+
Predicate<Method> filter = IS_STATIC.and(hasName(name)).and(accepts(sourceType)).and(produces(targetType));
91+
List<Method> methods = ReflectionSupport.findMethods(declaringClass, filter, HierarchyTraversalMode.BOTTOM_UP);
92+
93+
if (methods.isEmpty()) {
94+
throw new ArgumentConversionException(
95+
String.format("No conversion method found with the following signature: static %s %s(%s)",
96+
targetType.getName(), name, sourceType.getName()));
97+
}
98+
99+
return methods.get(0);
100+
}
101+
102+
private static Predicate<Method> accepts(Class<?> sourceType) {
103+
return method -> {
104+
Class<?>[] parameterTypes = method.getParameterTypes();
105+
return parameterTypes.length == 1 && ReflectionUtils.isAssignableTo(sourceType, parameterTypes[0]);
106+
};
107+
}
108+
109+
private static Predicate<Method> produces(Class<?> targetType) {
110+
return method -> targetType.isAssignableFrom(method.getReturnType());
111+
}
112+
113+
private static Predicate<Method> hasName(String name) {
114+
return method -> method.getName().equals(name);
115+
}
116+
117+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright © 2025-present Stefano Cordio
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package io.github.scordio.junit.converters;
17+
18+
import org.junit.jupiter.params.converter.ConvertWith;
19+
20+
import java.lang.annotation.Documented;
21+
import java.lang.annotation.ElementType;
22+
import java.lang.annotation.Retention;
23+
import java.lang.annotation.RetentionPolicy;
24+
import java.lang.annotation.Target;
25+
26+
/**
27+
* {@code @MethodConversion} is a {@link ConvertWith} composed annotation that converts
28+
* arguments using a {@linkplain #value() static method} declared in the test class.
29+
*
30+
* @since 0.2.0
31+
*/
32+
@Target({ ElementType.ANNOTATION_TYPE, ElementType.PARAMETER, ElementType.FIELD })
33+
@Retention(RetentionPolicy.RUNTIME)
34+
@Documented
35+
@ConvertWith(MethodArgumentConverter.class)
36+
@SuppressWarnings("exports")
37+
public @interface MethodConversion {
38+
39+
/**
40+
* The name of the conversion method within the test class to use for the conversion.
41+
* <p>
42+
* If no name is declared, the converter will look for a single static method within
43+
* the test class whose parameter type matches the source type and whose return type
44+
* matches the target type. The search traverses the class hierarchy with
45+
* {@linkplain org.junit.platform.commons.support.HierarchyTraversalMode#BOTTOM_UP
46+
* bottom-up} semantics.
47+
* @return the name of the conversion method to use
48+
*/
49+
String value() default "";
50+
51+
}
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package io.github.scordio.junit.converters;
2+
3+
import java.util.Collections;
4+
import java.util.IdentityHashMap;
5+
import java.util.Map;
6+
import java.util.Objects;
7+
8+
/**
9+
* Methods borrowed from {@link org.junit.platform.commons.util.ReflectionUtils}
10+
*/
11+
class ReflectionUtils {
12+
13+
private static final Map<Class<?>, Class<?>> primitiveToWrapperMap;
14+
15+
static {
16+
@SuppressWarnings("IdentityHashMapUsage")
17+
Map<Class<?>, Class<?>> primitivesToWrappers = new IdentityHashMap<>(9);
18+
19+
primitivesToWrappers.put(boolean.class, Boolean.class);
20+
primitivesToWrappers.put(byte.class, Byte.class);
21+
primitivesToWrappers.put(char.class, Character.class);
22+
primitivesToWrappers.put(short.class, Short.class);
23+
primitivesToWrappers.put(int.class, Integer.class);
24+
primitivesToWrappers.put(long.class, Long.class);
25+
primitivesToWrappers.put(float.class, Float.class);
26+
primitivesToWrappers.put(double.class, Double.class);
27+
28+
primitiveToWrapperMap = Collections.unmodifiableMap(primitivesToWrappers);
29+
}
30+
31+
static boolean isAssignableTo(Class<?> sourceType, Class<?> targetType) {
32+
Objects.requireNonNull(sourceType, "source type must not be null");
33+
Objects.requireNonNull(targetType, "target type must not be null");
34+
35+
if (sourceType.isPrimitive()) {
36+
throw new IllegalArgumentException("source type must not be a primitive type");
37+
}
38+
39+
if (targetType.isAssignableFrom(sourceType)) {
40+
return true;
41+
}
42+
43+
if (targetType.isPrimitive()) {
44+
return sourceType == primitiveToWrapperMap.get(targetType) || isWideningConversion(sourceType, targetType);
45+
}
46+
47+
return false;
48+
}
49+
50+
private static boolean isWideningConversion(Class<?> sourceType, Class<?> targetType) {
51+
if (!targetType.isPrimitive()) {
52+
throw new IllegalArgumentException("targetType must be primitive");
53+
}
54+
55+
boolean isPrimitive = sourceType.isPrimitive();
56+
boolean isWrapper = primitiveToWrapperMap.containsValue(sourceType);
57+
58+
// Neither a primitive nor a wrapper?
59+
if (!isPrimitive && !isWrapper) {
60+
return false;
61+
}
62+
63+
if (isPrimitive) {
64+
sourceType = primitiveToWrapperMap.get(sourceType);
65+
}
66+
67+
// @formatter:off
68+
if (sourceType == Byte.class) {
69+
return
70+
targetType == short.class ||
71+
targetType == int.class ||
72+
targetType == long.class ||
73+
targetType == float.class ||
74+
targetType == double.class;
75+
}
76+
77+
if (sourceType == Short.class || sourceType == Character.class) {
78+
return
79+
targetType == int.class ||
80+
targetType == long.class ||
81+
targetType == float.class ||
82+
targetType == double.class;
83+
}
84+
85+
if (sourceType == Integer.class) {
86+
return
87+
targetType == long.class ||
88+
targetType == float.class ||
89+
targetType == double.class;
90+
}
91+
92+
if (sourceType == Long.class) {
93+
return
94+
targetType == float.class ||
95+
targetType == double.class;
96+
}
97+
98+
if (sourceType == Float.class) {
99+
return
100+
targetType == double.class;
101+
}
102+
// @formatter:on
103+
104+
return false;
105+
}
106+
107+
}

0 commit comments

Comments
 (0)