diff --git a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBFieldResolver.java b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBFieldResolver.java index a1a35f98e5..6d49edeffb 100644 --- a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBFieldResolver.java +++ b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBFieldResolver.java @@ -20,10 +20,14 @@ package com.amazonaws.athena.connectors.docdb; import com.amazonaws.athena.connector.lambda.data.FieldResolver; +import com.mongodb.DBRef; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.bson.Document; +import java.util.Map; +import java.util.function.Function; + /** * Used to resolve DocDB complex structures to Apache Arrow Types. * @@ -36,6 +40,12 @@ public class DocDBFieldResolver private DocDBFieldResolver() {} + static final Map> dbRefExtractor = Map.of( + "_id", dbRef -> dbRef.getId().toString(), + "_db", DBRef::getDatabaseName, + "_ref", DBRef::getCollectionName + ); + @Override public Object getFieldValue(Field field, Object value) { @@ -47,6 +57,9 @@ else if (value instanceof Document) { Object rawVal = ((Document) value).get(field.getName()); return TypeUtils.coerce(field, rawVal); } + else if (value instanceof DBRef) { + return TypeUtils.coerce(field, dbRefExtractor.get(field.getName()).apply((DBRef) value)); + } throw new RuntimeException("Expected LIST or Document type but found " + minorType); } } diff --git a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/SchemaUtils.java b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/SchemaUtils.java index 40fe3fb5fe..ec1d15912c 100644 --- a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/SchemaUtils.java +++ b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/SchemaUtils.java @@ -22,6 +22,7 @@ import com.amazonaws.athena.connector.lambda.data.FieldBuilder; import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; import com.amazonaws.athena.connector.lambda.domain.TableName; +import com.mongodb.DBRef; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCursor; import com.mongodb.client.MongoDatabase; @@ -253,6 +254,13 @@ else if (value instanceof Document) { } return new Field(key, FieldType.nullable(Types.MinorType.STRUCT.getType()), children); } + else if (value instanceof DBRef) { + List children = new ArrayList<>(); + children.add(new Field("_db", FieldType.nullable(Types.MinorType.VARCHAR.getType()), null)); + children.add(new Field("_ref", FieldType.nullable(Types.MinorType.VARCHAR.getType()), null)); + children.add(new Field("_id", FieldType.nullable(Types.MinorType.VARCHAR.getType()), null)); + return new Field(key, FieldType.nullable(Types.MinorType.STRUCT.getType()), children); + } String className = (value == null || value.getClass() == null) ? "null" : value.getClass().getName(); logger.warn("Unknown type[" + className + "] for field[" + key + "], defaulting to varchar."); diff --git a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java index ed87b500c3..37b9b3d286 100644 --- a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java +++ b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java @@ -50,6 +50,7 @@ import com.amazonaws.services.secretsmanager.AWSSecretsManager; import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; +import com.mongodb.DBRef; import com.mongodb.client.FindIterable; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; @@ -59,6 +60,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Schema; import org.bson.Document; +import org.bson.types.ObjectId; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -441,6 +443,74 @@ public void nestedStructTest() assertEquals(expectedString, BlockUtils.rowToString(response.getRecords(), 0)); } + @Test + public void dbRefTest() + throws Exception + { + ObjectId id = ObjectId.get(); + + List documents = new ArrayList<>(); + Document result = new Document(); + documents.add(result); + result.put("DbRef", new DBRef("otherDb", "otherColl", id)); + + Document simpleStruct = new Document(); + simpleStruct.put("SomeSimpleStruct", "someSimpleStruct"); + result.put("SimpleStruct", simpleStruct); + + when(mockCollection.find()).thenReturn(mockIterable); + when(mockIterable.limit(anyInt())).thenReturn(mockIterable); + Mockito.lenient().when(mockIterable.maxScan(anyInt())).thenReturn(mockIterable); + when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable); + when(mockIterable.iterator()).thenReturn(new StubbingCursor(documents.iterator())); + + GetTableRequest req = new GetTableRequest(IDENTITY, QUERY_ID, DEFAULT_CATALOG, TABLE_NAME); + GetTableResponse res = mdHandler.doGetTable(allocator, req); + logger.info("doGetTable - {}", res); + + when(mockCollection.find(nullable(Document.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { + logger.info("doReadRecordsNoSpill: query[{}]", invocationOnMock.getArguments()[0]); + return mockIterable; + }); + when(mockIterable.projection(nullable(Document.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { + logger.info("doReadRecordsNoSpill: projection[{}]", invocationOnMock.getArguments()[0]); + return mockIterable; + }); + when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable); + when(mockIterable.iterator()).thenReturn(new StubbingCursor(documents.iterator())); + + + Map constraintsMap = new HashMap<>(); + S3SpillLocation splitLoc = S3SpillLocation.newBuilder() + .withBucket(UUID.randomUUID().toString()) + .withSplitId(UUID.randomUUID().toString()) + .withQueryId(UUID.randomUUID().toString()) + .withIsDirectory(true) + .build(); + + ReadRecordsRequest request = new ReadRecordsRequest(IDENTITY, + DEFAULT_CATALOG, + "queryId-" + System.currentTimeMillis(), + TABLE_NAME, + res.getSchema(), + Split.newBuilder(splitLoc, keyFactory.create()).add(DOCDB_CONN_STR, CONNECTION_STRING).build(), + new Constraints(constraintsMap), + 100_000_000_000L, //100GB don't expect this to spill + 100_000_000_000L + ); + + RecordResponse rawResponse = handler.doReadRecords(allocator, request); + + assertTrue(rawResponse instanceof ReadRecordsResponse); + + ReadRecordsResponse response = (ReadRecordsResponse) rawResponse; + logger.info("doReadRecordsNoSpill: rows[{}]", response.getRecordCount()); + logger.info("doReadRecordsNoSpill: {}", BlockUtils.rowToString(response.getRecords(), 0)); + assertTrue(response.getRecordCount() == 1); + String expectedString = "[DbRef : {[_db : otherDb],[_ref : otherColl],[_id : " + id.toHexString() + "]}], [SimpleStruct : {[SomeSimpleStruct : someSimpleStruct]}]"; + assertEquals(expectedString, BlockUtils.rowToString(response.getRecords(), 0)); + } + private class ByteHolder { private byte[] bytes; diff --git a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/SchemaUtilsTest.java b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/SchemaUtilsTest.java index 61bf31427d..8a3841f1ac 100644 --- a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/SchemaUtilsTest.java +++ b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/SchemaUtilsTest.java @@ -20,6 +20,7 @@ package com.amazonaws.athena.connectors.docdb; import com.amazonaws.athena.connector.lambda.domain.TableName; +import com.mongodb.DBRef; import com.mongodb.client.FindIterable; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; @@ -28,6 +29,7 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.bson.Document; +import org.bson.types.ObjectId; import org.junit.Test; import java.util.ArrayList; @@ -191,4 +193,46 @@ public void emptyListTest() assertEquals(Types.MinorType.LIST, Types.getMinorTypeForArrowType(fields.get("col4").getType())); assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(fields.get("col4").getChildren().get(0).getType())); } + + @Test + public void dbRefTet() + { + List docs = new ArrayList<>(); + Document doc1 = new Document(); + doc1.put("col1", 1); + doc1.put("col2", new DBRef("otherColl", ObjectId.get())); + docs.add(doc1); + + Document doc2 = new Document(); + doc2.put("col1", 1); + doc2.put("col2", new DBRef("otherDb", "otherColl", ObjectId.get())); + docs.add(doc2); + + MongoClient mockClient = mock(MongoClient.class); + MongoDatabase mockDatabase = mock(MongoDatabase.class); + MongoCollection mockCollection = mock(MongoCollection.class); + FindIterable mockIterable = mock(FindIterable.class); + when(mockClient.getDatabase(any())).thenReturn(mockDatabase); + when(mockDatabase.getCollection(any())).thenReturn(mockCollection); + when(mockCollection.find()).thenReturn(mockIterable); + when(mockIterable.limit(anyInt())).thenReturn(mockIterable); + when(mockIterable.maxScan(anyInt())).thenReturn(mockIterable); + when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable); + when(mockIterable.iterator()).thenReturn(new StubbingCursor(docs.iterator())); + + Schema schema = SchemaUtils.inferSchema(mockClient, new TableName("test", "test"), 10); + assertEquals(2, schema.getFields().size()); + + Map fields = new HashMap<>(); + schema.getFields().stream().forEach(next -> fields.put(next.getName(), next)); + + assertEquals(Types.MinorType.INT, Types.getMinorTypeForArrowType(fields.get("col1").getType())); + assertEquals(Types.MinorType.STRUCT, Types.getMinorTypeForArrowType(fields.get("col2").getType())); + assertEquals("_db", fields.get("col2").getChildren().get(0).getName()); + assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(fields.get("col2").getChildren().get(0).getType())); + assertEquals("_ref", fields.get("col2").getChildren().get(1).getName()); + assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(fields.get("col2").getChildren().get(1).getType())); + assertEquals("_id", fields.get("col2").getChildren().get(2).getName()); + assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(fields.get("col2").getChildren().get(2).getType())); + } }