Skip to content

Commit 491ee3e

Browse files
committed
Address PR Comments
This commit addresses comments from the open PR, including: checking for nulls in injected constructors and other public methods; refactoring Decoder and Encoder to simplify them, avoiding the interface and factory; providing informative exception messages for unsupported properties; and ensuring the Trino code style is followed (no final local variables). I also discovered and fixed a bug where null map values were not being encoded correctly.
1 parent 659eeea commit 491ee3e

File tree

12 files changed

+838
-875
lines changed

12 files changed

+838
-875
lines changed

Diff for: lib/trino-hive-formats/src/main/java/io/trino/hive/formats/ion/IonDecoder.java

+474-6
Large diffs are not rendered by default.

Diff for: lib/trino-hive-formats/src/main/java/io/trino/hive/formats/ion/IonDecoderConfig.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
* a trino column will be correctly typed or coerced for that column.
2727
* @param caseSensitive whether field name matching should be case-sensitive or not.
2828
*/
29-
public record IonDecoderConfig(Map<String, String> pathExtractors, Boolean strictTyping, Boolean caseSensitive)
29+
public record IonDecoderConfig(Map<String, String> pathExtractors, boolean strictTyping, boolean caseSensitive)
3030
{
3131
static IonDecoderConfig defaultConfig()
3232
{

Diff for: lib/trino-hive-formats/src/main/java/io/trino/hive/formats/ion/IonDecoderFactory.java

-487
This file was deleted.

Diff for: lib/trino-hive-formats/src/main/java/io/trino/hive/formats/ion/IonEncoder.java

+273-8
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,285 @@
1313
*/
1414
package io.trino.hive.formats.ion;
1515

16+
import com.amazon.ion.IonType;
1617
import com.amazon.ion.IonWriter;
18+
import com.amazon.ion.Timestamp;
19+
import com.google.common.collect.ImmutableList;
20+
import io.trino.hive.formats.line.Column;
1721
import io.trino.spi.Page;
22+
import io.trino.spi.block.ArrayBlock;
23+
import io.trino.spi.block.Block;
24+
import io.trino.spi.block.RowBlock;
25+
import io.trino.spi.block.SqlMap;
26+
import io.trino.spi.type.ArrayType;
27+
import io.trino.spi.type.BigintType;
28+
import io.trino.spi.type.BooleanType;
29+
import io.trino.spi.type.CharType;
30+
import io.trino.spi.type.DateType;
31+
import io.trino.spi.type.DecimalType;
32+
import io.trino.spi.type.DoubleType;
33+
import io.trino.spi.type.Int128;
34+
import io.trino.spi.type.IntegerType;
35+
import io.trino.spi.type.LongTimestamp;
36+
import io.trino.spi.type.MapType;
37+
import io.trino.spi.type.RealType;
38+
import io.trino.spi.type.RowType;
39+
import io.trino.spi.type.SmallintType;
40+
import io.trino.spi.type.TimestampType;
41+
import io.trino.spi.type.TinyintType;
42+
import io.trino.spi.type.Type;
43+
import io.trino.spi.type.VarbinaryType;
44+
import io.trino.spi.type.VarcharType;
1845

1946
import java.io.IOException;
47+
import java.math.BigDecimal;
48+
import java.math.RoundingMode;
49+
import java.nio.charset.StandardCharsets;
50+
import java.time.LocalDate;
51+
import java.time.ZoneId;
52+
import java.util.Date;
53+
import java.util.List;
54+
import java.util.Optional;
55+
import java.util.function.IntFunction;
2056

21-
public interface IonEncoder
57+
import static com.google.common.base.Preconditions.checkArgument;
58+
59+
/**
60+
* An IonEncoder encodes Pages of trino data to an IonWriter.
61+
*/
62+
public class IonEncoder
2263
{
64+
private final RowEncoder rowEncoder;
65+
66+
/**
 * Creates an encoder for the given columns.
 * <br>
 * Each column is turned into a named row field, in the order given, and the
 * resulting field list is handed to the top-level row encoder.
 */
public IonEncoder(List<Column> columns)
{
    ImmutableList.Builder<RowType.Field> topLevelFields = ImmutableList.builder();
    for (Column column : columns) {
        topLevelFields.add(new RowType.Field(Optional.of(column.name()), column.type()));
    }
    rowEncoder = RowEncoder.forFields(topLevelFields.build());
}
72+
2373
/**
24-
* Encodes the Page into the IonWriter provided.
25-
* <p>
26-
* Will flush() the writer after encoding the page.
27-
* Expects that the calling code is responsible for closing
28-
* the writer after all pages are written.
74+
* Encode the page of data to the IonWriter.
75+
* <br>
76+
* IonWriter.flush() will be called after the page is written.
2977
*/
30-
void encode(IonWriter writer, Page page)
31-
throws IOException;
78+
public void encode(IonWriter writer, Page page)
79+
throws IOException
80+
{
81+
for (int i = 0; i < page.getPositionCount(); i++) {
82+
rowEncoder.encodeStruct(writer, page::getBlock, i);
83+
}
84+
// todo: consider decoupling ion writer flushes from page sizes.
85+
writer.flush();
86+
}
87+
88+
private interface BlockEncoder
89+
{
90+
void encode(IonWriter writer, Block block, int position)
91+
throws IOException;
92+
}
93+
94+
/**
 * Picks the encoder matching the given type and wraps it so that null
 * positions are handled before the type-specific encoder is invoked.
 * <br>
 * Row, map and array types recurse to build encoders for their nested types.
 *
 * @throws IllegalArgumentException if the type has no encoder
 */
private static BlockEncoder encoderForType(Type type)
{
    return wrapEncoder(switch (type) {
        // integral types
        case TinyintType _ -> BYTE_ENCODER;
        case SmallintType _ -> SHORT_ENCODER;
        case IntegerType _ -> INT_ENCODER;
        case BigintType _ -> LONG_ENCODER;
        // floating point and decimal types
        case RealType _ -> REAL_ENCODER;
        case DoubleType _ -> DOUBLE_ENCODER;
        case DecimalType decimalType -> decimalEncoder(decimalType);
        // boolean, text and binary types
        case BooleanType _ -> BOOL_ENCODER;
        case VarcharType _, CharType _ -> STRING_ENCODER;
        case VarbinaryType _ -> BINARY_ENCODER;
        // temporal types
        case DateType _ -> DATE_ENCODER;
        case TimestampType timestampType -> timestampEncoder(timestampType);
        // nested types
        case RowType rowType -> RowEncoder.forFields(rowType.getFields());
        case MapType mapType -> new MapEncoder(mapType, mapType.getKeyType(), encoderForType(mapType.getValueType()));
        case ArrayType arrayType -> new ArrayEncoder(encoderForType(arrayType.getElementType()));
        default -> throw new IllegalArgumentException(String.format("Unsupported type: %s", type));
    });
}
116+
117+
private static BlockEncoder wrapEncoder(BlockEncoder encoder)
118+
{
119+
return (writer, block, position) ->
120+
{
121+
if (block.isNull(position)) {
122+
writer.writeNull();
123+
}
124+
else {
125+
encoder.encode(writer, block, position);
126+
}
127+
};
128+
}
129+
130+
private record RowEncoder(List<String> fieldNames, List<BlockEncoder> fieldEncoders)
131+
implements BlockEncoder
132+
{
133+
private static RowEncoder forFields(List<RowType.Field> fields)
134+
{
135+
ImmutableList.Builder<String> fieldNamesBuilder = ImmutableList.builder();
136+
ImmutableList.Builder<BlockEncoder> fieldEncodersBuilder = ImmutableList.builder();
137+
138+
for (RowType.Field field : fields) {
139+
fieldNamesBuilder.add(field.getName().get());
140+
fieldEncodersBuilder.add(encoderForType(field.getType()));
141+
}
142+
143+
return new RowEncoder(fieldNamesBuilder.build(), fieldEncodersBuilder.build());
144+
}
145+
146+
@Override
147+
public void encode(IonWriter writer, Block block, int position)
148+
throws IOException
149+
{
150+
encodeStruct(writer, ((RowBlock) block)::getFieldBlock, position);
151+
}
152+
153+
// used for encoding 'top-level' rows by the IonEncoder
154+
private void encodeStruct(IonWriter writer, IntFunction<Block> blockSelector, int position)
155+
throws IOException
156+
{
157+
writer.stepIn(IonType.STRUCT);
158+
for (int i = 0; i < fieldEncoders.size(); i++) {
159+
// fields are omitted by default, as was true in the hive serde.
160+
// there is an unimplemented hive legacy property of `ion.serialize_null`
161+
// that could be used to specify typed or untyped ion nulls instead.
162+
Block block = blockSelector.apply(i);
163+
if (block.isNull(position)) {
164+
continue;
165+
}
166+
writer.setFieldName(fieldNames.get(i));
167+
fieldEncoders.get(i)
168+
.encode(writer, block, position);
169+
}
170+
writer.stepOut();
171+
}
172+
}
173+
174+
private record MapEncoder(MapType mapType, Type keyType, BlockEncoder encoder)
175+
implements BlockEncoder
176+
{
177+
public MapEncoder(MapType mapType, Type keyType, BlockEncoder encoder)
178+
{
179+
this.mapType = mapType;
180+
if (!(keyType instanceof VarcharType _ || keyType instanceof CharType _)) {
181+
throw new UnsupportedOperationException("Unsupported map key type: " + keyType);
182+
}
183+
this.keyType = keyType;
184+
this.encoder = encoder;
185+
}
186+
187+
@Override
188+
public void encode(IonWriter writer, Block block, int position)
189+
throws IOException
190+
{
191+
SqlMap sqlMap = mapType.getObject(block, position);
192+
int rawOffset = sqlMap.getRawOffset();
193+
Block rawKeyBlock = sqlMap.getRawKeyBlock();
194+
Block rawValueBlock = sqlMap.getRawValueBlock();
195+
196+
writer.stepIn(IonType.STRUCT);
197+
for (int i = 0; i < sqlMap.getSize(); i++) {
198+
checkArgument(!rawKeyBlock.isNull(rawOffset + i), "map key is null");
199+
writer.setFieldName(VarcharType.VARCHAR.getSlice(rawKeyBlock, rawOffset + i).toString(StandardCharsets.UTF_8));
200+
encoder.encode(writer, rawValueBlock, rawOffset + i);
201+
}
202+
writer.stepOut();
203+
}
204+
}
205+
206+
private record ArrayEncoder(BlockEncoder elementEncoder)
207+
implements BlockEncoder
208+
{
209+
@Override
210+
public void encode(IonWriter writer, Block block, int position)
211+
throws IOException
212+
{
213+
writer.stepIn(IonType.LIST);
214+
Block elementBlock = ((ArrayBlock) block).getArray(position);
215+
for (int i = 0; i < elementBlock.getPositionCount(); i++) {
216+
elementEncoder.encode(writer, elementBlock, i);
217+
}
218+
writer.stepOut();
219+
}
220+
}
221+
222+
private static BlockEncoder timestampEncoder(TimestampType type)
223+
{
224+
if (type.isShort()) {
225+
return (writer, block, position) -> {
226+
long epochMicros = type.getLong(block, position);
227+
BigDecimal decimalMillis = BigDecimal.valueOf(epochMicros)
228+
.movePointLeft(3)
229+
.setScale(type.getPrecision() - 3, RoundingMode.UNNECESSARY);
230+
231+
writer.writeTimestamp(Timestamp.forMillis(decimalMillis, 0));
232+
};
233+
}
234+
else {
235+
return (writer, block, position) -> {
236+
LongTimestamp longTimestamp = (LongTimestamp) type.getObject(block, position);
237+
BigDecimal picosOfMicros = BigDecimal.valueOf(longTimestamp.getPicosOfMicro())
238+
.movePointLeft(9);
239+
BigDecimal decimalMillis = BigDecimal.valueOf(longTimestamp.getEpochMicros())
240+
.movePointLeft(3)
241+
.add(picosOfMicros)
242+
.setScale(type.getPrecision() - 3, RoundingMode.UNNECESSARY);
243+
244+
writer.writeTimestamp(Timestamp.forMillis(decimalMillis, 0));
245+
};
246+
}
247+
}
248+
249+
private static BlockEncoder decimalEncoder(DecimalType type)
250+
{
251+
if (type.isShort()) {
252+
return (writer, block, position) -> {
253+
writer.writeDecimal(BigDecimal.valueOf(type.getLong(block, position), type.getScale()));
254+
};
255+
}
256+
else {
257+
return (writer, block, position) -> {
258+
writer.writeDecimal(new BigDecimal(((Int128) type.getObject(block, position)).toBigInteger(), type.getScale()));
259+
};
260+
}
261+
}
262+
263+
private static final BlockEncoder BYTE_ENCODER = (writer, block, position) ->
264+
writer.writeInt(TinyintType.TINYINT.getLong(block, position));
265+
266+
private static final BlockEncoder SHORT_ENCODER = (writer, block, position) ->
267+
writer.writeInt(SmallintType.SMALLINT.getLong(block, position));
268+
269+
private static final BlockEncoder INT_ENCODER = (writer, block, position) ->
270+
writer.writeInt(IntegerType.INTEGER.getInt(block, position));
271+
272+
private static final BlockEncoder STRING_ENCODER = (writer, block, position) ->
273+
writer.writeString(VarcharType.VARCHAR.getSlice(block, position).toString(StandardCharsets.UTF_8));
274+
275+
private static final BlockEncoder BOOL_ENCODER = (writer, block, position) ->
276+
writer.writeBool(BooleanType.BOOLEAN.getBoolean(block, position));
277+
278+
private static final BlockEncoder BINARY_ENCODER = (writer, block, position) ->
279+
writer.writeBlob(VarbinaryType.VARBINARY.getSlice(block, position).getBytes());
280+
281+
private static final BlockEncoder LONG_ENCODER = (writer, block, position) ->
282+
writer.writeInt(BigintType.BIGINT.getLong(block, position));
283+
284+
private static final BlockEncoder REAL_ENCODER = (writer, block, position) ->
285+
writer.writeFloat(RealType.REAL.getFloat(block, position));
286+
287+
private static final BlockEncoder DOUBLE_ENCODER = (writer, block, position) ->
288+
writer.writeFloat(DoubleType.DOUBLE.getDouble(block, position));
289+
290+
private static final BlockEncoder DATE_ENCODER = (writer, block, position) ->
291+
writer.writeTimestamp(
292+
Timestamp.forDateZ(
293+
Date.from(
294+
LocalDate.ofEpochDay(DateType.DATE.getLong(block, position))
295+
.atStartOfDay(ZoneId.of("UTC"))
296+
.toInstant())));
32297
}

0 commit comments

Comments
 (0)