Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions misk-admin/src/test/kotlin/misk/web/actions/TestWebActionModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class TestWebActionModule : KAbstractModule() {
install(WebActionModule.create<CustomCapabilityAccessAction>())
install(WebActionModule.create<RequestTypeAction>())
install(WebActionModule.create<GrpcAction>())
install(WebActionModule.create<DataClassRequestAction>())

multibind<AccessAnnotationEntry>()
.toInstance(AccessAnnotationEntry<CustomServiceAccess>(services = listOf("payments")))
Expand Down Expand Up @@ -90,3 +91,15 @@ class GrpcAction @Inject constructor() : TestWebActionModule.ShippingGetDestinat
return Warehouse.Builder().warehouse_id(7777L).build()
}
}

data class DataClassRequest(val token: String, val count: Int, val entries: List<DataClassEntry>)

data class DataClassEntry(val id: Long, val note: String?)

class DataClassRequestAction @Inject constructor() : WebAction {
@Post("/data_class_request")
@RequestContentType(MediaTypes.APPLICATION_JSON)
@ResponseContentType(MediaTypes.TEXT_PLAIN_UTF8)
@Unauthenticated
fun handle(@RequestBody request: DataClassRequest) = "request: $request".toResponseBody()
}
146 changes: 144 additions & 2 deletions misk-admin/src/test/kotlin/misk/web/metadata/MiskWebFormBuilderTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,152 @@ internal class MiskWebFormBuilderTest {
}

@Test
fun `handles non-wire messages`() {
assertThat(miskWebFormBuilder.calculateTypes(String::class.createType())).isEmpty()
fun `handles plain data class with primitive fields`() {
val types = miskWebFormBuilder.calculateTypes(SimpleRequest::class.createType())

assertThat(types).containsKey(SimpleRequest::class.java.canonicalName)
val type = types[SimpleRequest::class.java.canonicalName]!!
assertThat(type.fields)
.containsExactlyInAnyOrder(
Field("id", "Long", false, emptyList()),
Field("name", "String", false, emptyList()),
Field("active", "Boolean", false, emptyList()),
)
}

@Test
fun `handles nullable fields on data class`() {
val types = miskWebFormBuilder.calculateTypes(NullableRequest::class.createType())
val type = types[NullableRequest::class.java.canonicalName]!!
assertThat(type.fields)
.containsExactlyInAnyOrder(
Field("optionalName", "String", false, emptyList()),
Field("optionalCount", "Int", false, emptyList()),
)
}

@Test
fun `handles list of primitive on data class`() {
val types = miskWebFormBuilder.calculateTypes(StringListRequest::class.createType())
val type = types[StringListRequest::class.java.canonicalName]!!
assertThat(type.fields).containsExactly(Field("tokens", "String", repeated = true, emptyList()))
}

@Test
fun `handles list of nested data class with recursion`() {
val types = miskWebFormBuilder.calculateTypes(NestedListRequest::class.createType())

assertThat(types).containsKey(NestedListRequest::class.java.canonicalName)
assertThat(types).containsKey(NestedItem::class.java.canonicalName)

val outer = types[NestedListRequest::class.java.canonicalName]!!
assertThat(outer.fields)
.containsExactly(Field("entries", NestedItem::class.java.canonicalName!!, repeated = true, emptyList()))

val nested = types[NestedItem::class.java.canonicalName]!!
assertThat(nested.fields)
.containsExactlyInAnyOrder(
Field("itemId", "Long", false, emptyList()),
Field("note", "String", false, emptyList()),
)
}

@Test
fun `handles map on data class`() {
val types = miskWebFormBuilder.calculateTypes(MapRequest::class.createType())
val type = types[MapRequest::class.java.canonicalName]!!
// Mirrors Wire-message behavior: map keys are skipped, values are emitted as a repeated field.
assertThat(type.fields).containsExactly(Field("counts", "Int", repeated = true, emptyList()))
}

@Test
fun `handles nested data classes`() {
val types = miskWebFormBuilder.calculateTypes(OuterRequest::class.createType())

assertThat(types).containsKey(OuterRequest::class.java.canonicalName)
assertThat(types).containsKey(InnerRequest::class.java.canonicalName)

val outer = types[OuterRequest::class.java.canonicalName]!!
assertThat(outer.fields)
.containsExactly(Field("inner", InnerRequest::class.java.canonicalName!!, repeated = false, emptyList()))

val inner = types[InnerRequest::class.java.canonicalName]!!
assertThat(inner.fields).containsExactly(Field("value", "String", false, emptyList()))
}

@Test
fun `does not re-walk a visited type`() {
// Self-referential structure: a data class with a list of itself. The visited-set guard
// prevents infinite recursion and ensures the type only appears once in the output map.
val types = miskWebFormBuilder.calculateTypes(RecursiveRequest::class.createType())

assertThat(types).hasSize(1)
assertThat(types).containsKey(RecursiveRequest::class.java.canonicalName)
val type = types[RecursiveRequest::class.java.canonicalName]!!
assertThat(type.fields)
.contains(Field("children", RecursiveRequest::class.java.canonicalName!!, repeated = true, emptyList()))
}

@Test
fun `handles enum fields on data class`() {
val types = miskWebFormBuilder.calculateTypes(EnumRequest::class.createType())
val type = types[EnumRequest::class.java.canonicalName]!!
assertThat(type.fields)
.containsExactly(
Field(
name = "color",
type = "Enum<${SimpleColor::class.java.canonicalName},RED,GREEN,BLUE>",
repeated = false,
annotations = emptyList(),
)
)
}

@Test
fun `handles data class containing wire message`() {
val types = miskWebFormBuilder.calculateTypes(MixedRequest::class.createType())

assertThat(types).containsKey(MixedRequest::class.java.canonicalName)
// Nested Wire message gets walked via the WireField path, populating its proto fields too.
assertThat(types).containsKey(KotlinProtoShipment::class.qualifiedName)
assertThat(types).containsKey(KotlinProtoWarehouse::class.qualifiedName)

val mixed = types[MixedRequest::class.java.canonicalName]!!
assertThat(mixed.fields)
.contains(
Field("name", "String", false, emptyList()),
Field("shipment", KotlinProtoShipment::class.qualifiedName!!, repeated = false, emptyList()),
)
}

private data class SimpleRequest(val id: Long, val name: String, val active: Boolean)

private data class NullableRequest(val optionalName: String?, val optionalCount: Int?)

private data class StringListRequest(val tokens: List<String>)

private data class NestedItem(val itemId: Long, val note: String?)

private data class NestedListRequest(val entries: List<NestedItem>)

private data class MapRequest(val counts: Map<String, Int>)

private data class InnerRequest(val value: String)

private data class OuterRequest(val inner: InnerRequest)

private data class RecursiveRequest(val name: String, val children: List<RecursiveRequest>)

private enum class SimpleColor {
RED,
GREEN,
BLUE,
}

private data class EnumRequest(val color: SimpleColor)

private data class MixedRequest(val name: String, val shipment: KotlinProtoShipment)

@Test
fun `handles java wire messages`() {
val types = miskWebFormBuilder.calculateTypes(Shipment::class.createType())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ import com.squareup.protos.test.parsing.Warehouse
import jakarta.inject.Inject
import misk.testing.MiskTest
import misk.testing.MiskTestModule
import misk.web.MiskWebFormBuilder
import misk.web.actions.CustomCapabilityAccessAction
import misk.web.actions.CustomServiceAccessAction
import misk.web.actions.DataClassEntry
import misk.web.actions.DataClassRequest
import misk.web.actions.DataClassRequestAction
import misk.web.actions.GrpcAction
import misk.web.mediatype.MediaTypes
import misk.web.metadata.MetadataTestingModule
Expand Down Expand Up @@ -64,4 +68,30 @@ class WebActionMetadataActionTest {
assertThat(metadata.returnType).isEqualTo(Warehouse::class.qualifiedName)
assertThat(metadata.types).isNotEmpty
}

@Test
fun `data class request body is introspected for form metadata`() {
val response = webActionMetadataAction.getAll()
val metadata = response.webActionMetadata.find { it.name == DataClassRequestAction::class.simpleName }!!

val requestTypeKey = DataClassRequest::class.qualifiedName!!
val entryTypeKey = DataClassEntry::class.qualifiedName!!

assertThat(metadata.types).containsKeys(requestTypeKey, entryTypeKey)

val requestType = metadata.types[requestTypeKey]!!
assertThat(requestType.fields)
.containsExactlyInAnyOrder(
MiskWebFormBuilder.Field("token", "String", repeated = false, annotations = emptyList()),
MiskWebFormBuilder.Field("count", "Int", repeated = false, annotations = emptyList()),
MiskWebFormBuilder.Field("entries", entryTypeKey, repeated = true, annotations = emptyList()),
)

val entryType = metadata.types[entryTypeKey]!!
assertThat(entryType.fields)
.containsExactlyInAnyOrder(
MiskWebFormBuilder.Field("id", "Long", repeated = false, annotations = emptyList()),
MiskWebFormBuilder.Field("note", "String", repeated = false, annotations = emptyList()),
)
}
}
127 changes: 101 additions & 26 deletions misk/src/main/kotlin/misk/web/MiskWebFormBuilder.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,59 +21,88 @@ import okio.ByteString
/**
* Provides a mapping from field name to Type definition given a KType. Useful for processes that want to have a schema
* definition of a type. For example: used by the WebActions admin dashboard tab to show a statically typed form
* containing request fields for developers to fill out. Currently only supports Wire request type messages; non-Wire
* messages return an empty mapping.
* containing request fields for developers to fill out.
*
* Wire-generated [Message] types are introspected via their [WireField]-annotated fields. Plain Kotlin
* `data class` request types are introspected by walking [declaredMemberProperties] using Kotlin reflection
* so generic type parameters (e.g. `List<Foo>`, `Map<String, Bar>`) are first-class.
*/
class MiskWebFormBuilder
@JvmOverloads
constructor(private val documentationProvider: ProtoDocumentationProvider? = null) {
fun calculateTypes(requestType: KType?): Map<String, Type> {
// Type maps can only be calculated for wire messages
if (requestType == null) {
return mapOf()
}

val requestClass = requestType.classifier as KClass<*>
if (Message::class !in requestClass.superclasses) {
return mapOf()
}
val requestClass = requestType.classifier as? KClass<*> ?: return mapOf()

val typesMap = mutableMapOf<String, Type>()
val stack = LinkedList<KClass<*>>()
stack.push(requestClass)

while (stack.isNotEmpty()) {
val clazz = stack.pop()
val canonicalName = clazz.java.canonicalName ?: continue

// No need to re-process a given type.
// This acts as the visited set of our type graph traversal.
if (typesMap.containsKey(clazz.java.canonicalName!!)) {
// Acts as the visited set of our type graph traversal.
if (typesMap.containsKey(canonicalName)) {
continue
}

val fields = mutableListOf<Field>()

for (property in clazz.declaredMemberProperties) {
val field = property.javaField
// Use the WireField annotation to identify fields of our proto.
if (field?.annotations?.any { it is WireField } == true) {
handleField(
fieldType = TypeLiteral.get(field.genericType),
fieldName = field.name,
fields = fields,
stack = stack,
annotations = property.annotations.filter { it !is WireField },
)
}
val isWireMessage = Message::class in clazz.superclasses
// Only Wire messages and Kotlin data classes are meaningfully introspectable here. Skipping
// other types preserves the original Wire-only contract for unsupported request types and
// also avoids Kotlin reflection errors on synthetic classes (e.g. `Function0` lambdas).
if (!isWireMessage && !clazz.isData) {
continue
}

val fields =
if (isWireMessage) {
collectWireFields(clazz, stack)
} else {
collectKotlinFields(clazz, stack)
}

val documentationUrl = documentationProvider?.get(clazz.getProtobufType())
typesMap[clazz.java.canonicalName!!] = Type(fields.toList(), documentationUrl)
typesMap[canonicalName] = Type(fields, documentationUrl)
}

return typesMap
}

private fun collectWireFields(clazz: KClass<*>, stack: LinkedList<KClass<*>>): List<Field> {
val fields = mutableListOf<Field>()
for (property in clazz.declaredMemberProperties) {
val field = property.javaField
// Use the WireField annotation to identify fields of our proto.
if (field?.annotations?.any { it is WireField } == true) {
handleField(
fieldType = TypeLiteral.get(field.genericType),
fieldName = field.name,
fields = fields,
stack = stack,
annotations = property.annotations.filter { it !is WireField },
)
}
}
return fields
}

private fun collectKotlinFields(clazz: KClass<*>, stack: LinkedList<KClass<*>>): List<Field> {
val fields = mutableListOf<Field>()
for (property in clazz.declaredMemberProperties) {
handleKotlinField(
fieldType = property.returnType,
fieldName = property.name,
fields = fields,
stack = stack,
annotations = property.annotations,
)
}
return fields
}

private fun handleField(
fieldType: TypeLiteral<*>,
fieldName: String,
Expand Down Expand Up @@ -115,6 +144,52 @@ constructor(private val documentationProvider: ProtoDocumentationProvider? = nul
}
}

private fun handleKotlinField(
fieldType: KType,
fieldName: String,
fields: MutableList<Field>,
stack: LinkedList<KClass<*>>,
repeated: Boolean = false,
annotations: List<Annotation>,
) {
val fieldClass = fieldType.classifier as? KClass<*> ?: return
val javaClass = fieldClass.javaObjectType
val maybePrimitiveType = maybeCreatePrimitiveField(javaClass, fieldName, repeated, annotations)
when {
maybePrimitiveType != null -> fields.add(maybePrimitiveType)

fieldClass.java.isEnum -> {
fields.add(createEnumField(fieldClass.java, fieldName, repeated, annotations))
}

fieldClass.isSubclassOf(Collection::class) -> {
val argType = fieldType.arguments.firstOrNull()?.type ?: return
handleKotlinField(argType, fieldName, fields, stack, repeated = true, annotations = annotations)
}

fieldClass.isSubclassOf(Map::class) -> {
val args = fieldType.arguments
if (args.size < 2) return
// key type can never be a nested message so we skip it
val valueType = args[1].type ?: return
handleKotlinField(valueType, fieldName, fields, stack, repeated = true, annotations = annotations)
}

else -> {
val canonicalName = fieldClass.java.canonicalName ?: return
fields.add(
Field(
name = fieldName,
type = canonicalName,
repeated = repeated,
annotations = annotations.toStrings(),
)
)
stack.push(fieldClass)
}
}
}

companion object {
/** Create misk-web [Field]s for primitives and enum types. Returns null if the type cannot be mapped. */
fun maybeCreatePrimitiveField(
Expand Down