@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.query.aggregation.datasketches.hll;

import org.apache.datasketches.hll.HllSketch;
import org.apache.datasketches.hll.TgtHllType;
import org.apache.druid.java.util.common.StringEncoding;
import org.apache.druid.segment.serde.ComplexMetricSerde;
import org.apache.druid.segment.serde.ComplexMetrics;
import org.junit.Assert;
import org.junit.Test;

public class HllSketchBuildComplexMetricSerdeTest
{

@Test
public void testComplexSerdeToBytesOnRealtimeSegmentSketch()
{
ComplexMetrics.registerSerde(HllSketchModule.BUILD_TYPE_NAME, new HllSketchBuildComplexMetricSerde());
ComplexMetricSerde serde = ComplexMetrics.getSerdeForType(HllSketchModule.BUILD_TYPE_NAME);
Assert.assertNotNull(serde);
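// Build an on-heap sketch (lgConfigK = 14, target type HLL_8) and update it with test values.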
HllSketch sketchNew = new HllSketch(14, TgtHllType.HLL_8);
HllSketchBuildUtil.updateSketch(sketchNew, StringEncoding.UTF16LE, new int[]{1, 2});

HllSketchHolder sketchHolder = HllSketchHolder.of(sketchNew);

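// toBytes() must also accept an already-serialized byte[]: the new
// ObjectStrategy#objectToBytes default returns such input as-is instead of
// casting it to the sketch type, so the second call yields the same array.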
byte[] bytes = serde.toBytes(sketchHolder);
Assert.assertEquals(bytes, serde.toBytes(bytes));

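// Round-trip through fromBytes() and compare configuration, serialized sizes, and estimates.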
HllSketchHolder fromBytesHolder = (HllSketchHolder) serde.fromBytes(bytes, 0, bytes.length);

Assert.assertEquals(sketchHolder.getSketch().getLgConfigK(), fromBytesHolder.getSketch().getLgConfigK());
Assert.assertEquals(sketchHolder.getSketch().getTgtHllType(), fromBytesHolder.getSketch().getTgtHllType());
Assert.assertEquals(
sketchHolder.getSketch().getCompactSerializationBytes(),
fromBytesHolder.getSketch().getCompactSerializationBytes()
);
Assert.assertEquals(
sketchHolder.getSketch().getUpdatableSerializationBytes(),
fromBytesHolder.getSketch().getUpdatableSerializationBytes()
);
Assert.assertEquals(sketchHolder.getSketch().getEstimate(), fromBytesHolder.getSketch().getEstimate(), 0);
Assert.assertEquals(sketchHolder.getSketch().getLowerBound(1), fromBytesHolder.getSketch().getLowerBound(1), 0);
Assert.assertEquals(sketchHolder.getSketch().getUpperBound(1), fromBytesHolder.getSketch().getUpperBound(1), 0);

// After fromBytes(), the sketch is backed by Memory, so isMemory() flips to true;
// the original on-heap sketch stays false.
Assert.assertFalse(sketchHolder.getSketch().isMemory());
Assert.assertTrue(fromBytesHolder.getSketch().isMemory());
}
}
@@ -32,6 +32,7 @@
import org.apache.druid.error.DruidException;
import org.apache.druid.error.DruidExceptionMatcher;
import org.apache.druid.frame.util.DurableStorageUtils;
import org.apache.druid.hll.HyperLogLogCollector;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Intervals;
@@ -64,6 +65,7 @@
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.cardinality.CardinalityAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
@@ -90,13 +92,15 @@
import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import org.junit.internal.matchers.ThrowableMessageMatcher;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -105,6 +109,8 @@
import java.util.List;
import java.util.Map;

import static org.apache.druid.sql.calcite.util.CalciteTests.SOME_DATASOURCE;

public class MSQSelectTest extends MSQTestBase
{

@@ -244,6 +250,78 @@ public void testSelectOnFoo(String contextName, Map<String, Object> context)
.verifyResults();
}

@MethodSource("data")
@ParameterizedTest(name = "{index}:with context {0}")
public void testSelectOnFooWithComplex(String contextName, Map<String, Object> context)
{
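// Selects a complex-typed column (hyperUnique) from a realtime segment, exercising
// serialization of complex values as byte[] via ComplexMetricSerde#toBytes.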
RowSignature resultSignature = RowSignature.builder()
.add("cnt", ColumnType.LONG)
.add("unique_dim1", HyperUniquesAggregatorFactory.TYPE)
.build();
final Map<String, Object> contextToPut = new HashMap<>(context);
contextToPut.put("includeSegmentSource", "REALTIME");

final Map<String, Object> queryContext = new HashMap<>(context);
queryContext.put("includeSegmentSource", "REALTIME");

testSelectQuery()
.setSql("select cnt,unique_dim1 from some_datasource")
.setExpectedMSQSpec(
LegacyMSQSpec.builder()
.query(
newScanQueryBuilder()
.dataSource(SOME_DATASOURCE)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("cnt", "unique_dim1")
.columnTypes(List.of(ColumnType.LONG, HyperUniquesAggregatorFactory.TYPE))
.context(contextToPut)
.build()
)
.columnMappings(ColumnMappings.identity(resultSignature))
.tuningConfig(MSQTuningConfig.defaultConfig())
.destination(isDurableStorageDestination(contextName, context)
? DurableStorageMSQDestination.INSTANCE
: TaskReportMSQDestination.INSTANCE)
.build()
)
.setQueryContext(queryContext)
.setExpectedRowSignature(resultSignature)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().totalFiles(1),
0, 0, "input0"
)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().rows(2).frames(1),
0, 0, "output"
)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with()
.rows(2)
.frames(1),
0, 0, "shuffle"
)
.setExpectedResultRows(ImmutableList.of(
new Object[]{7L, "\"AQAAAQAAAAEkAQ==\""},
new Object[]{8L, "\"AQAAAgAAAAAoYAEkAQ==\""}
))
.setExpectedLookupLoadingSpec(LookupLoadingSpec.NONE)
.verifyResults();

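// Sanity-check the expected base64 payloads: folded into fresh collectors they
// should estimate cardinality 1 ("abc") and 2 ("abc" and "cde").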
HyperLogLogCollector expectedCollector1 = HyperLogLogCollector.makeLatestCollector()
.fold(ByteBuffer.wrap(StringUtils.decodeBase64String(
"AQAAAQAAAAEkAQ==")));

HyperLogLogCollector expectedCollector2 = HyperLogLogCollector.makeLatestCollector()
.fold(ByteBuffer.wrap(StringUtils.decodeBase64String(
"AQAAAgAAAAAoYAEkAQ==")));

Assertions.assertEquals(1, expectedCollector1.estimateCardinality(), 0.01);
Assertions.assertEquals(2, expectedCollector2.estimateCardinality(), 0.01);
}

@MethodSource("data")
@ParameterizedTest(name = "{index}:with context {0}")
public void testSelectOnFoo2(String contextName, Map<String, Object> context)
@@ -40,6 +40,7 @@
import org.apache.druid.client.ImmutableSegmentLoadInfo;
import org.apache.druid.collections.ReferenceCountingResourceHolder;
import org.apache.druid.collections.ResourceHolder;
import org.apache.druid.data.input.InputRow;
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.data.input.impl.LongDimensionSchema;
import org.apache.druid.data.input.impl.StringDimensionSchema;
@@ -60,6 +61,7 @@
import org.apache.druid.guice.annotations.EscalatedGlobal;
import org.apache.druid.guice.annotations.Self;
import org.apache.druid.hll.HyperLogLogCollector;
import org.apache.druid.hll.HyperLogLogHash;
import org.apache.druid.indexing.common.SegmentCacheManagerFactory;
import org.apache.druid.indexing.common.task.CompactionTask;
import org.apache.druid.indexing.common.task.IndexTask;
@@ -72,6 +74,7 @@
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.emitter.EmittingLogger;
@@ -143,9 +146,12 @@
import org.apache.druid.segment.IndexIO;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.QueryableIndexCursorFactory;
import org.apache.druid.segment.RowAdapters;
import org.apache.druid.segment.Segment;
import org.apache.druid.segment.TestSegmentUtils;
import org.apache.druid.segment.column.ColumnConfig;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.loading.DataSegmentPusher;
@@ -244,6 +250,7 @@
import static org.apache.druid.sql.calcite.util.CalciteTests.DATASOURCE1;
import static org.apache.druid.sql.calcite.util.CalciteTests.DATASOURCE2;
import static org.apache.druid.sql.calcite.util.CalciteTests.RESTRICTED_DATASOURCE;
import static org.apache.druid.sql.calcite.util.CalciteTests.SOME_DATASOURCE;
import static org.apache.druid.sql.calcite.util.CalciteTests.WIKIPEDIA;
import static org.apache.druid.sql.calcite.util.TestDataBuilder.ROWS1;
import static org.apache.druid.sql.calcite.util.TestDataBuilder.ROWS2;
@@ -452,7 +459,7 @@ public void tearDown2()
// which depends on the object mapper that the injector will provide, once it
// is built, but which has not yet been built when we build the SQL engine.
@BeforeEach
-  public void setUp2() throws Exception
+  public void setUp2()
{
groupByBuffers = TestGroupByBuffers.createDefault();

@@ -678,6 +685,7 @@ protected Supplier<ResourceHolder<CompleteSegment>> getSupplierForSegment(
{
if (segmentManager.getSegment(segmentId) == null) {
final QueryableIndex index;
Optional<Segment> segment = Optional.empty();
switch (segmentId.getDataSource()) {
case DATASOURCE1:
case RESTRICTED_DATASOURCE: // RESTRICTED_DATASOURCE share the same index as DATASOURCE1.
@@ -728,41 +736,86 @@ protected Supplier<ResourceHolder<CompleteSegment>> getSupplierForSegment(
case WIKIPEDIA:
index = TestDataBuilder.makeWikipediaIndex(newTempFolder());
break;
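// SOME_DATASOURCE is served as a row-based in-memory (realtime-style) segment with
// no QueryableIndex; unique_dim1 carries pre-built HLL collectors.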
case SOME_DATASOURCE:
index = null;
HyperLogLogCollector collector1 = HyperLogLogCollector.makeLatestCollector();
collector1.add(HyperLogLogHash.getDefault().hash("abc"));
HyperLogLogCollector collector2 = HyperLogLogCollector.makeLatestCollector();
collector2.add(HyperLogLogHash.getDefault().hash("abc"));
collector2.add(HyperLogLogHash.getDefault().hash("cde"));
final List<ImmutableMap<String, Object>> rtRawRows = ImmutableList.of(
ImmutableMap.<String, Object>builder()
.put("t", "2000-01-01")
.put("m1", "1.0")
.put("m2", "1.0")
.put("dim1", "")
.put("dim2", ImmutableList.of("a"))
.put("dim3", ImmutableList.of("a", "b"))
.put("cnt", 7)
.put("unique_dim1", collector1.toByteArray())
.build(),
ImmutableMap.<String, Object>builder()
.put("t", "2000-01-02")
.put("m1", "2.0")
.put("m2", "2.0")
.put("dim1", "10.1")
.put("dim2", ImmutableList.of())
.put("dim3", ImmutableList.of("b", "c"))
.put("cnt", 8)
.put("unique_dim1", collector2.toByteArray())
.build()
);
final List<InputRow> rtRows =
rtRawRows.stream().map(TestDataBuilder::createRow).collect(Collectors.toList());

segment = Optional.of(new TestSegmentUtils.InMemoryTestSegment<>(
segmentId,
Sequences.simple(rtRows),
RowAdapters.standardRow(),
RowSignature.builder()
.add("cnt", ColumnType.LONG)
.add("unique_dim1", HyperUniquesAggregatorFactory.TYPE)
.build()
)
);
break;
default:
throw new ISE("Cannot query segment %s in test runner", segmentId);
}
-      Segment segment = new Segment()
-      {
-        @Override
-        public SegmentId getId()
-        {
-          return segmentId;
-        }
-
-        @Override
-        public Interval getDataInterval()
-        {
-          return segmentId.getInterval();
-        }
-
-        @Nullable
-        @Override
-        public <T> T as(@Nonnull Class<T> clazz)
-        {
-          if (CursorFactory.class.equals(clazz)) {
-            return (T) new QueryableIndexCursorFactory(index);
-          } else if (QueryableIndex.class.equals(clazz)) {
-            return (T) index;
-          }
-          return null;
-        }
-
-        @Override
-        public void close()
-        {
-        }
-      };
-      segmentManager.addSegment(segment);
+      if (segment.isEmpty()) {
+        segment = Optional.of(new Segment()
+        {
+          @Override
+          public SegmentId getId()
+          {
+            return segmentId;
+          }
+
+          @Override
+          public Interval getDataInterval()
+          {
+            return segmentId.getInterval();
+          }
+
+          @Nullable
+          @Override
+          public <T> T as(@Nonnull Class<T> clazz)
+          {
+            if (CursorFactory.class.equals(clazz)) {
+              return (T) new QueryableIndexCursorFactory(index);
+            } else if (QueryableIndex.class.equals(clazz)) {
+              return (T) index;
+            }
+            return null;
+          }
+
+          @Override
+          public void close()
+          {
+          }
+        });
+      }
+      segmentManager.addSegment(segment.get());
}
DataSegment dataSegment = DataSegment.builder()
.dataSource(segmentId.getDataSource())
@@ -52,6 +52,20 @@ public interface ObjectStrategy<T> extends Comparator<T>
@Nullable
byte[] toBytes(@Nullable T val);

/**
 * Variant of {@link #toBytes} that avoids a ClassCastException when the value passed in is
 * already a byte[]: such a value is returned as-is, while any other value is cast to T and
 * serialized as before.
 */
@Nullable
default byte[] objectToBytes(@Nullable Object val)
{
if (val instanceof byte[]) {
return (byte[]) val;
} else {
return toBytes((T) val);
}
}

/**
* Whether {@link #compare} is valid or not.
*/
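As an illustration of the pass-through this default enables, here is a minimal standalone sketch (hypothetical class, using String as the strategy's T; not part of this diff): a byte[] argument is returned unchanged, so serializing twice is idempotent, while any other value takes the typed toBytes path as before.

// Hypothetical stand-in for an ObjectStrategy<String>; illustrative only.
final class ObjectToBytesSketch
{
  // Typed serializer, analogous to ObjectStrategy#toBytes(T).
  static byte[] toBytes(String val)
  {
    return val.getBytes(java.nio.charset.StandardCharsets.UTF_8);
  }

  // Mirrors the new objectToBytes() default: byte[] passes through untouched,
  // everything else is cast to the strategy's type and serialized.
  static byte[] objectToBytes(Object val)
  {
    if (val instanceof byte[]) {
      return (byte[]) val;
    }
    return toBytes((String) val);
  }

  public static void main(String[] args)
  {
    byte[] first = objectToBytes("abc");  // typed path
    byte[] second = objectToBytes(first); // pass-through path
    System.out.println(first == second);  // true: idempotent on byte[]
  }
}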
@@ -94,7 +94,7 @@
public byte[] toBytes(@Nullable Object val)
{
if (val != null) {
-      byte[] bytes = getObjectStrategy().toBytes(val);
+      byte[] bytes = getObjectStrategy().objectToBytes(val);
return bytes != null ? bytes : ByteArrays.EMPTY_ARRAY;
} else {
return ByteArrays.EMPTY_ARRAY;
@@ -116,7 +116,7 @@
if (start > 0) {
bb.position(start);
}
return getObjectStrategy().fromByteBuffer(bb, numBytes);
}

(CodeQL check notice: "Deprecated method or constructor invocation" — invoking ComplexMetricSerde.getObjectStrategy should be avoided because it has been deprecated.)
