diff --git a/misk-admin/src/test/kotlin/misk/web/actions/TestWebActionModule.kt b/misk-admin/src/test/kotlin/misk/web/actions/TestWebActionModule.kt index 3ce6459b908..5d1c5b3a901 100644 --- a/misk-admin/src/test/kotlin/misk/web/actions/TestWebActionModule.kt +++ b/misk-admin/src/test/kotlin/misk/web/actions/TestWebActionModule.kt @@ -35,6 +35,7 @@ class TestWebActionModule : KAbstractModule() { install(WebActionModule.create()) install(WebActionModule.create()) install(WebActionModule.create()) + install(WebActionModule.create()) multibind() .toInstance(AccessAnnotationEntry(services = listOf("payments"))) @@ -90,3 +91,15 @@ class GrpcAction @Inject constructor() : TestWebActionModule.ShippingGetDestinat return Warehouse.Builder().warehouse_id(7777L).build() } } + +data class DataClassRequest(val token: String, val count: Int, val entries: List) + +data class DataClassEntry(val id: Long, val note: String?) + +class DataClassRequestAction @Inject constructor() : WebAction { + @Post("/data_class_request") + @RequestContentType(MediaTypes.APPLICATION_JSON) + @ResponseContentType(MediaTypes.TEXT_PLAIN_UTF8) + @Unauthenticated + fun handle(@RequestBody request: DataClassRequest) = "request: $request".toResponseBody() +} diff --git a/misk-admin/src/test/kotlin/misk/web/metadata/MiskWebFormBuilderTest.kt b/misk-admin/src/test/kotlin/misk/web/metadata/MiskWebFormBuilderTest.kt index 93252138262..d2e9036db74 100644 --- a/misk-admin/src/test/kotlin/misk/web/metadata/MiskWebFormBuilderTest.kt +++ b/misk-admin/src/test/kotlin/misk/web/metadata/MiskWebFormBuilderTest.kt @@ -21,10 +21,152 @@ internal class MiskWebFormBuilderTest { } @Test - fun `handles non-wire messages`() { - assertThat(miskWebFormBuilder.calculateTypes(String::class.createType())).isEmpty() + fun `handles plain data class with primitive fields`() { + val types = miskWebFormBuilder.calculateTypes(SimpleRequest::class.createType()) + + assertThat(types).containsKey(SimpleRequest::class.java.canonicalName) + val type = types[SimpleRequest::class.java.canonicalName]!! + assertThat(type.fields) + .containsExactlyInAnyOrder( + Field("id", "Long", false, emptyList()), + Field("name", "String", false, emptyList()), + Field("active", "Boolean", false, emptyList()), + ) + } + + @Test + fun `handles nullable fields on data class`() { + val types = miskWebFormBuilder.calculateTypes(NullableRequest::class.createType()) + val type = types[NullableRequest::class.java.canonicalName]!! + assertThat(type.fields) + .containsExactlyInAnyOrder( + Field("optionalName", "String", false, emptyList()), + Field("optionalCount", "Int", false, emptyList()), + ) + } + + @Test + fun `handles list of primitive on data class`() { + val types = miskWebFormBuilder.calculateTypes(StringListRequest::class.createType()) + val type = types[StringListRequest::class.java.canonicalName]!! + assertThat(type.fields).containsExactly(Field("tokens", "String", repeated = true, emptyList())) + } + + @Test + fun `handles list of nested data class with recursion`() { + val types = miskWebFormBuilder.calculateTypes(NestedListRequest::class.createType()) + + assertThat(types).containsKey(NestedListRequest::class.java.canonicalName) + assertThat(types).containsKey(NestedItem::class.java.canonicalName) + + val outer = types[NestedListRequest::class.java.canonicalName]!! + assertThat(outer.fields) + .containsExactly(Field("entries", NestedItem::class.java.canonicalName!!, repeated = true, emptyList())) + + val nested = types[NestedItem::class.java.canonicalName]!! + assertThat(nested.fields) + .containsExactlyInAnyOrder( + Field("itemId", "Long", false, emptyList()), + Field("note", "String", false, emptyList()), + ) + } + + @Test + fun `handles map on data class`() { + val types = miskWebFormBuilder.calculateTypes(MapRequest::class.createType()) + val type = types[MapRequest::class.java.canonicalName]!! + // Mirrors Wire-message behavior: map keys are skipped, values are emitted as a repeated field. + assertThat(type.fields).containsExactly(Field("counts", "Int", repeated = true, emptyList())) } + @Test + fun `handles nested data classes`() { + val types = miskWebFormBuilder.calculateTypes(OuterRequest::class.createType()) + + assertThat(types).containsKey(OuterRequest::class.java.canonicalName) + assertThat(types).containsKey(InnerRequest::class.java.canonicalName) + + val outer = types[OuterRequest::class.java.canonicalName]!! + assertThat(outer.fields) + .containsExactly(Field("inner", InnerRequest::class.java.canonicalName!!, repeated = false, emptyList())) + + val inner = types[InnerRequest::class.java.canonicalName]!! + assertThat(inner.fields).containsExactly(Field("value", "String", false, emptyList())) + } + + @Test + fun `does not re-walk a visited type`() { + // Self-referential structure: a data class with a list of itself. The visited-set guard + // prevents infinite recursion and ensures the type only appears once in the output map. + val types = miskWebFormBuilder.calculateTypes(RecursiveRequest::class.createType()) + + assertThat(types).hasSize(1) + assertThat(types).containsKey(RecursiveRequest::class.java.canonicalName) + val type = types[RecursiveRequest::class.java.canonicalName]!! + assertThat(type.fields) + .contains(Field("children", RecursiveRequest::class.java.canonicalName!!, repeated = true, emptyList())) + } + + @Test + fun `handles enum fields on data class`() { + val types = miskWebFormBuilder.calculateTypes(EnumRequest::class.createType()) + val type = types[EnumRequest::class.java.canonicalName]!! + assertThat(type.fields) + .containsExactly( + Field( + name = "color", + type = "Enum<${SimpleColor::class.java.canonicalName},RED,GREEN,BLUE>", + repeated = false, + annotations = emptyList(), + ) + ) + } + + @Test + fun `handles data class containing wire message`() { + val types = miskWebFormBuilder.calculateTypes(MixedRequest::class.createType()) + + assertThat(types).containsKey(MixedRequest::class.java.canonicalName) + // Nested Wire message gets walked via the WireField path, populating its proto fields too. + assertThat(types).containsKey(KotlinProtoShipment::class.qualifiedName) + assertThat(types).containsKey(KotlinProtoWarehouse::class.qualifiedName) + + val mixed = types[MixedRequest::class.java.canonicalName]!! + assertThat(mixed.fields) + .contains( + Field("name", "String", false, emptyList()), + Field("shipment", KotlinProtoShipment::class.qualifiedName!!, repeated = false, emptyList()), + ) + } + + private data class SimpleRequest(val id: Long, val name: String, val active: Boolean) + + private data class NullableRequest(val optionalName: String?, val optionalCount: Int?) + + private data class StringListRequest(val tokens: List) + + private data class NestedItem(val itemId: Long, val note: String?) + + private data class NestedListRequest(val entries: List) + + private data class MapRequest(val counts: Map) + + private data class InnerRequest(val value: String) + + private data class OuterRequest(val inner: InnerRequest) + + private data class RecursiveRequest(val name: String, val children: List) + + private enum class SimpleColor { + RED, + GREEN, + BLUE, + } + + private data class EnumRequest(val color: SimpleColor) + + private data class MixedRequest(val name: String, val shipment: KotlinProtoShipment) + @Test fun `handles java wire messages`() { val types = miskWebFormBuilder.calculateTypes(Shipment::class.createType()) diff --git a/misk-admin/src/test/kotlin/misk/web/metadata/webaction/WebActionMetadataActionTest.kt b/misk-admin/src/test/kotlin/misk/web/metadata/webaction/WebActionMetadataActionTest.kt index 149575368dc..45544e11af5 100644 --- a/misk-admin/src/test/kotlin/misk/web/metadata/webaction/WebActionMetadataActionTest.kt +++ b/misk-admin/src/test/kotlin/misk/web/metadata/webaction/WebActionMetadataActionTest.kt @@ -5,8 +5,12 @@ import com.squareup.protos.test.parsing.Warehouse import jakarta.inject.Inject import misk.testing.MiskTest import misk.testing.MiskTestModule +import misk.web.MiskWebFormBuilder import misk.web.actions.CustomCapabilityAccessAction import misk.web.actions.CustomServiceAccessAction +import misk.web.actions.DataClassEntry +import misk.web.actions.DataClassRequest +import misk.web.actions.DataClassRequestAction import misk.web.actions.GrpcAction import misk.web.mediatype.MediaTypes import misk.web.metadata.MetadataTestingModule @@ -64,4 +68,30 @@ class WebActionMetadataActionTest { assertThat(metadata.returnType).isEqualTo(Warehouse::class.qualifiedName) assertThat(metadata.types).isNotEmpty } + + @Test + fun `data class request body is introspected for form metadata`() { + val response = webActionMetadataAction.getAll() + val metadata = response.webActionMetadata.find { it.name == DataClassRequestAction::class.simpleName }!! + + val requestTypeKey = DataClassRequest::class.qualifiedName!! + val entryTypeKey = DataClassEntry::class.qualifiedName!! + + assertThat(metadata.types).containsKeys(requestTypeKey, entryTypeKey) + + val requestType = metadata.types[requestTypeKey]!! + assertThat(requestType.fields) + .containsExactlyInAnyOrder( + MiskWebFormBuilder.Field("token", "String", repeated = false, annotations = emptyList()), + MiskWebFormBuilder.Field("count", "Int", repeated = false, annotations = emptyList()), + MiskWebFormBuilder.Field("entries", entryTypeKey, repeated = true, annotations = emptyList()), + ) + + val entryType = metadata.types[entryTypeKey]!! + assertThat(entryType.fields) + .containsExactlyInAnyOrder( + MiskWebFormBuilder.Field("id", "Long", repeated = false, annotations = emptyList()), + MiskWebFormBuilder.Field("note", "String", repeated = false, annotations = emptyList()), + ) + } } diff --git a/misk/src/main/kotlin/misk/web/MiskWebFormBuilder.kt b/misk/src/main/kotlin/misk/web/MiskWebFormBuilder.kt index 13e7bd6abb5..2dfc62ef453 100644 --- a/misk/src/main/kotlin/misk/web/MiskWebFormBuilder.kt +++ b/misk/src/main/kotlin/misk/web/MiskWebFormBuilder.kt @@ -21,22 +21,20 @@ import okio.ByteString /** * Provides a mapping from field name to Type definition given a KType. Useful for processes that want to have a schema * definition of a type. For example: used by the WebActions admin dashboard tab to show a statically typed form - * containing request fields for developers to fill out. Currently only supports Wire request type messages; non-Wire - * messages return an empty mapping. + * containing request fields for developers to fill out. + * + * Wire-generated [Message] types are introspected via their [WireField]-annotated fields. Plain Kotlin + * `data class` request types are introspected by walking [declaredMemberProperties] using Kotlin reflection + * so generic type parameters (e.g. `List`, `Map`) are first-class. */ class MiskWebFormBuilder @JvmOverloads constructor(private val documentationProvider: ProtoDocumentationProvider? = null) { fun calculateTypes(requestType: KType?): Map { - // Type maps can only be calculated for wire messages if (requestType == null) { return mapOf() } - - val requestClass = requestType.classifier as KClass<*> - if (Message::class !in requestClass.superclasses) { - return mapOf() - } + val requestClass = requestType.classifier as? KClass<*> ?: return mapOf() val typesMap = mutableMapOf() val stack = LinkedList>() @@ -44,36 +42,67 @@ constructor(private val documentationProvider: ProtoDocumentationProvider? = nul while (stack.isNotEmpty()) { val clazz = stack.pop() + val canonicalName = clazz.java.canonicalName ?: continue - // No need to re-process a given type. - // This acts as the visited set of our type graph traversal. - if (typesMap.containsKey(clazz.java.canonicalName!!)) { + // Acts as the visited set of our type graph traversal. + if (typesMap.containsKey(canonicalName)) { continue } - val fields = mutableListOf() - - for (property in clazz.declaredMemberProperties) { - val field = property.javaField - // Use the WireField annotation to identify fields of our proto. - if (field?.annotations?.any { it is WireField } == true) { - handleField( - fieldType = TypeLiteral.get(field.genericType), - fieldName = field.name, - fields = fields, - stack = stack, - annotations = property.annotations.filter { it !is WireField }, - ) - } + val isWireMessage = Message::class in clazz.superclasses + // Only Wire messages and Kotlin data classes are meaningfully introspectable here. Skipping + // other types preserves the original Wire-only contract for unsupported request types and + // also avoids Kotlin reflection errors on synthetic classes (e.g. `Function0` lambdas). + if (!isWireMessage && !clazz.isData) { + continue } + val fields = + if (isWireMessage) { + collectWireFields(clazz, stack) + } else { + collectKotlinFields(clazz, stack) + } + val documentationUrl = documentationProvider?.get(clazz.getProtobufType()) - typesMap[clazz.java.canonicalName!!] = Type(fields.toList(), documentationUrl) + typesMap[canonicalName] = Type(fields, documentationUrl) } return typesMap } + private fun collectWireFields(clazz: KClass<*>, stack: LinkedList>): List { + val fields = mutableListOf() + for (property in clazz.declaredMemberProperties) { + val field = property.javaField + // Use the WireField annotation to identify fields of our proto. + if (field?.annotations?.any { it is WireField } == true) { + handleField( + fieldType = TypeLiteral.get(field.genericType), + fieldName = field.name, + fields = fields, + stack = stack, + annotations = property.annotations.filter { it !is WireField }, + ) + } + } + return fields + } + + private fun collectKotlinFields(clazz: KClass<*>, stack: LinkedList>): List { + val fields = mutableListOf() + for (property in clazz.declaredMemberProperties) { + handleKotlinField( + fieldType = property.returnType, + fieldName = property.name, + fields = fields, + stack = stack, + annotations = property.annotations, + ) + } + return fields + } + private fun handleField( fieldType: TypeLiteral<*>, fieldName: String, @@ -115,6 +144,52 @@ constructor(private val documentationProvider: ProtoDocumentationProvider? = nul } } + private fun handleKotlinField( + fieldType: KType, + fieldName: String, + fields: MutableList, + stack: LinkedList>, + repeated: Boolean = false, + annotations: List, + ) { + val fieldClass = fieldType.classifier as? KClass<*> ?: return + val javaClass = fieldClass.javaObjectType + val maybePrimitiveType = maybeCreatePrimitiveField(javaClass, fieldName, repeated, annotations) + when { + maybePrimitiveType != null -> fields.add(maybePrimitiveType) + + fieldClass.java.isEnum -> { + fields.add(createEnumField(fieldClass.java, fieldName, repeated, annotations)) + } + + fieldClass.isSubclassOf(Collection::class) -> { + val argType = fieldType.arguments.firstOrNull()?.type ?: return + handleKotlinField(argType, fieldName, fields, stack, repeated = true, annotations = annotations) + } + + fieldClass.isSubclassOf(Map::class) -> { + val args = fieldType.arguments + if (args.size < 2) return + // key type can never be a nested message so we skip it + val valueType = args[1].type ?: return + handleKotlinField(valueType, fieldName, fields, stack, repeated = true, annotations = annotations) + } + + else -> { + val canonicalName = fieldClass.java.canonicalName ?: return + fields.add( + Field( + name = fieldName, + type = canonicalName, + repeated = repeated, + annotations = annotations.toStrings(), + ) + ) + stack.push(fieldClass) + } + } + } + companion object { /** Create misk-web [Field]s for primitives and enum types. Returns null if the type cannot be mapped. */ fun maybeCreatePrimitiveField(