Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for dbref #1035

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package com.amazonaws.athena.connectors.docdb;

import com.amazonaws.athena.connector.lambda.data.FieldResolver;
import com.mongodb.DBRef;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.bson.Document;
Expand Down Expand Up @@ -47,6 +48,19 @@ else if (value instanceof Document) {
Object rawVal = ((Document) value).get(field.getName());
return TypeUtils.coerce(field, rawVal);
}
else if (value instanceof DBRef) {
Object rawVal = null;
if (field.getName().equals("_id")) {
rawVal = ((DBRef) value).getId();
}
if (field.getName().equals("_db")) {
rawVal = ((DBRef) value).getDatabaseName();
}
if (field.getName().equals("_ref")) {
rawVal = ((DBRef) value).getCollectionName();
}
return TypeUtils.coerce(field, rawVal);
}
throw new RuntimeException("Expected LIST or Document type but found " + minorType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.amazonaws.athena.connector.lambda.data.FieldBuilder;
import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.mongodb.DBRef;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.MongoDatabase;
Expand Down Expand Up @@ -253,6 +254,13 @@ else if (value instanceof Document) {
}
return new Field(key, FieldType.nullable(Types.MinorType.STRUCT.getType()), children);
}
else if (value instanceof DBRef) {
List<Field> children = new ArrayList<>();
children.add(new Field("_db", FieldType.nullable(Types.MinorType.VARCHAR.getType()), null));
children.add(new Field("_ref", FieldType.nullable(Types.MinorType.VARCHAR.getType()), null));
children.add(new Field("_id", FieldType.nullable(Types.MinorType.VARCHAR.getType()), null));
return new Field(key, FieldType.nullable(Types.MinorType.STRUCT.getType()), children);
}

String className = (value == null || value.getClass() == null) ? "null" : value.getClass().getName();
logger.warn("Unknown type[" + className + "] for field[" + key + "], defaulting to varchar.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.google.common.collect.ImmutableList;
import com.google.common.io.ByteStreams;
import com.mongodb.DBRef;
import com.mongodb.client.FindIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
Expand All @@ -59,6 +60,7 @@
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.bson.Document;
import org.bson.types.ObjectId;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -441,6 +443,74 @@ public void nestedStructTest()
assertEquals(expectedString, BlockUtils.rowToString(response.getRecords(), 0));
}

/**
 * End-to-end test for DBRef support: verifies that a document containing a
 * DBRef value survives schema inference (doGetTable) and record reading
 * (doReadRecords) as a STRUCT column with _db, _ref, and _id children.
 */
@Test
public void dbRefTest()
throws Exception
{
// Capture the ObjectId so the expected row string below can reference its hex form.
ObjectId id = ObjectId.get();

List<Document> documents = new ArrayList<>();
Document result = new Document();
documents.add(result);
// DBRef pointing at collection "otherColl" in database "otherDb".
result.put("DbRef", new DBRef("otherDb", "otherColl", id));

// A plain nested document alongside the DBRef, to confirm both struct kinds coexist.
Document simpleStruct = new Document();
simpleStruct.put("SomeSimpleStruct", "someSimpleStruct");
result.put("SimpleStruct", simpleStruct);

// Stub the limited scan used by schema inference. NOTE: the cursor stubbed
// here is consumed by doGetTable below, and is re-stubbed afterwards for
// the read path — keep the statement ordering intact.
when(mockCollection.find()).thenReturn(mockIterable);
when(mockIterable.limit(anyInt())).thenReturn(mockIterable);
Mockito.lenient().when(mockIterable.maxScan(anyInt())).thenReturn(mockIterable);
when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable);
when(mockIterable.iterator()).thenReturn(new StubbingCursor(documents.iterator()));

// Infer the table schema (including the DBRef-derived STRUCT field).
GetTableRequest req = new GetTableRequest(IDENTITY, QUERY_ID, DEFAULT_CATALOG, TABLE_NAME);
GetTableResponse res = mdHandler.doGetTable(allocator, req);
logger.info("doGetTable - {}", res);

// Re-stub the query/projection path for the record read; the previous
// StubbingCursor was already exhausted by doGetTable.
when(mockCollection.find(nullable(Document.class))).thenAnswer((InvocationOnMock invocationOnMock) -> {
logger.info("doReadRecordsNoSpill: query[{}]", invocationOnMock.getArguments()[0]);
return mockIterable;
});
when(mockIterable.projection(nullable(Document.class))).thenAnswer((InvocationOnMock invocationOnMock) -> {
logger.info("doReadRecordsNoSpill: projection[{}]", invocationOnMock.getArguments()[0]);
return mockIterable;
});
when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable);
when(mockIterable.iterator()).thenReturn(new StubbingCursor(documents.iterator()));


// No constraints: read everything.
Map<String, ValueSet> constraintsMap = new HashMap<>();
S3SpillLocation splitLoc = S3SpillLocation.newBuilder()
.withBucket(UUID.randomUUID().toString())
.withSplitId(UUID.randomUUID().toString())
.withQueryId(UUID.randomUUID().toString())
.withIsDirectory(true)
.build();

ReadRecordsRequest request = new ReadRecordsRequest(IDENTITY,
DEFAULT_CATALOG,
"queryId-" + System.currentTimeMillis(),
TABLE_NAME,
res.getSchema(),
Split.newBuilder(splitLoc, keyFactory.create()).add(DOCDB_CONN_STR, CONNECTION_STRING).build(),
new Constraints(constraintsMap),
100_000_000_000L, //100GB don't expect this to spill
100_000_000_000L
);

RecordResponse rawResponse = handler.doReadRecords(allocator, request);

// With spill thresholds this large the response must be inline, not spilled.
assertTrue(rawResponse instanceof ReadRecordsResponse);

ReadRecordsResponse response = (ReadRecordsResponse) rawResponse;
logger.info("doReadRecordsNoSpill: rows[{}]", response.getRecordCount());
logger.info("doReadRecordsNoSpill: {}", BlockUtils.rowToString(response.getRecords(), 0));
assertTrue(response.getRecordCount() == 1);
// The DBRef must surface as a struct of (_db, _ref, _id) with _id rendered as hex.
String expectedString = "[DbRef : {[_db : otherDb],[_ref : otherColl],[_id : " + id.toHexString() + "]}], [SimpleStruct : {[SomeSimpleStruct : someSimpleStruct]}]";
assertEquals(expectedString, BlockUtils.rowToString(response.getRecords(), 0));
}

private class ByteHolder
{
private byte[] bytes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package com.amazonaws.athena.connectors.docdb;

import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.mongodb.DBRef;
import com.mongodb.client.FindIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
Expand All @@ -28,6 +29,7 @@
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.bson.Document;
import org.bson.types.ObjectId;
import org.junit.Test;

import java.util.ArrayList;
Expand Down Expand Up @@ -191,4 +193,46 @@ public void emptyListTest()
assertEquals(Types.MinorType.LIST, Types.getMinorTypeForArrowType(fields.get("col4").getType()));
assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(fields.get("col4").getChildren().get(0).getType()));
}

/**
 * Verifies that schema inference maps a DBRef value to a STRUCT field whose
 * children are VARCHAR fields named _db, _ref, and _id (in that order), for
 * DBRefs created both without (2-arg) and with (3-arg) an explicit database
 * name.
 *
 * <p>Renamed from {@code dbRefTet} — the original name was a typo; @Test
 * methods have no callers, so the rename is safe.</p>
 */
@Test
public void dbRefTest()
{
    List<Document> docs = new ArrayList<>();

    // DBRef without an explicit database name (2-arg constructor).
    Document doc1 = new Document();
    doc1.put("col1", 1);
    doc1.put("col2", new DBRef("otherColl", ObjectId.get()));
    docs.add(doc1);

    // DBRef with an explicit database name (3-arg constructor).
    Document doc2 = new Document();
    doc2.put("col1", 1);
    doc2.put("col2", new DBRef("otherDb", "otherColl", ObjectId.get()));
    docs.add(doc2);

    // Mock just enough of the driver for inferSchema's limited scan.
    MongoClient mockClient = mock(MongoClient.class);
    MongoDatabase mockDatabase = mock(MongoDatabase.class);
    MongoCollection mockCollection = mock(MongoCollection.class);
    FindIterable mockIterable = mock(FindIterable.class);
    when(mockClient.getDatabase(any())).thenReturn(mockDatabase);
    when(mockDatabase.getCollection(any())).thenReturn(mockCollection);
    when(mockCollection.find()).thenReturn(mockIterable);
    when(mockIterable.limit(anyInt())).thenReturn(mockIterable);
    when(mockIterable.maxScan(anyInt())).thenReturn(mockIterable);
    when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable);
    when(mockIterable.iterator()).thenReturn(new StubbingCursor(docs.iterator()));

    Schema schema = SchemaUtils.inferSchema(mockClient, new TableName("test", "test"), 10);
    assertEquals(2, schema.getFields().size());

    Map<String, Field> fields = new HashMap<>();
    schema.getFields().stream().forEach(next -> fields.put(next.getName(), next));

    assertEquals(Types.MinorType.INT, Types.getMinorTypeForArrowType(fields.get("col1").getType()));
    // The DBRef column becomes a STRUCT of three VARCHAR children: _db, _ref, _id.
    assertEquals(Types.MinorType.STRUCT, Types.getMinorTypeForArrowType(fields.get("col2").getType()));
    assertEquals("_db", fields.get("col2").getChildren().get(0).getName());
    assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(fields.get("col2").getChildren().get(0).getType()));
    assertEquals("_ref", fields.get("col2").getChildren().get(1).getName());
    assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(fields.get("col2").getChildren().get(1).getType()));
    assertEquals("_id", fields.get("col2").getChildren().get(2).getName());
    assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(fields.get("col2").getChildren().get(2).getType()));
}
}