diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java index 758ce44483dd..30281e1e2beb 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java @@ -57,7 +57,6 @@ import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.Stream; -import org.apache.commons.lang3.ArrayUtils; import org.apache.kafka.connect.data.Schema; import org.codehaus.commons.compiler.CompileException; import org.codehaus.commons.compiler.CompilerFactoryFactory; @@ -164,8 +163,7 @@ public CompiledExpression buildCodeGenFromParseTree( final Class expressionType = SQL_TO_JAVA_TYPE_CONVERTER.toJavaType(returnType); - final IExpressionEvaluator ee = - cook(javaCode, expressionType, spec.argumentNames(), spec.argumentTypes()); + final IExpressionEvaluator ee = cook(javaCode, expressionType); return new CompiledExpression(ee, spec, returnType, expression); } catch (KsqlException | CompileException e) { @@ -185,17 +183,15 @@ public CompiledExpression buildCodeGenFromParseTree( @VisibleForTesting public static IExpressionEvaluator cook( final String javaCode, - final Class expressionType, - final String[] argNames, - final Class[] argTypes + final Class expressionType ) throws Exception { final IExpressionEvaluator ee = CompilerFactoryFactory.getDefaultCompilerFactory() .newExpressionEvaluator(); ee.setDefaultImports(SqlToJavaVisitor.JAVA_IMPORTS.toArray(new String[0])); ee.setParameters( - ArrayUtils.addAll(argNames, "defaultValue", "logger", "row"), - ArrayUtils.addAll(argTypes, Object.class, ProcessingLogger.class, GenericRow.class) + new String[]{"arguments", "defaultValue", "logger", "row"}, + new Class[]{Map.class, Object.class, ProcessingLogger.class, GenericRow.class} ); ee.setExpressionType(expressionType); ee.cook(javaCode); diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenSpec.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenSpec.java index 598432ed0c9d..bdc768dffa2e 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenSpec.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenSpec.java @@ -56,14 +56,6 @@ private CodeGenSpec( this.structToCodeName = structToCodeName; } - public String[] argumentNames() { - return arguments.stream().map(ArgumentSpec::name).toArray(String[]::new); - } - - public Class[] argumentTypes() { - return arguments.stream().map(ArgumentSpec::type).toArray(Class[]::new); - } - @SuppressFBWarnings(value = "EI_EXPOSE_REP", justification = "arguments is ImmutableList") public List arguments() { return arguments; @@ -81,10 +73,14 @@ public String getUniqueNameForFunction(final FunctionName functionName, final in return names.get(index); } - public void resolve(final GenericRow row, final Object[] parameters) { + public Map resolveArguments(final GenericRow row) { + final Map resolvedArguments = new HashMap<>(arguments.size()); for (int paramIdx = 0; paramIdx < arguments.size(); paramIdx++) { - parameters[paramIdx] = arguments.get(paramIdx).resolve(row); + final String name = arguments.get(paramIdx).name(); + final Object value = arguments.get(paramIdx).resolve(row); + resolvedArguments.put(name, value); } + return resolvedArguments; } public String getStructSchemaName(final CreateStructExpression createStructExpression) { diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenUtil.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenUtil.java index 9213bf60b8fb..72ab581fb991 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenUtil.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenUtil.java @@ -16,6 +16,8 @@ package io.confluent.ksql.execution.codegen; import io.confluent.ksql.name.FunctionName; +import io.confluent.ksql.schema.ksql.SchemaConverters; +import io.confluent.ksql.schema.ksql.types.SqlType; public final class CodeGenUtil { @@ -37,4 +39,15 @@ public static String functionName(final FunctionName fun, final int index) { return fun.text() + "_" + index; } + public static String argumentAccessor(final String name, + final SqlType type) { + final Class javaType = SchemaConverters.sqlToJavaConverter().toJavaType(type); + return argumentAccessor(name, javaType); + } + + public static String argumentAccessor(final String name, + final Class type) { + return String.format("((%s) arguments.get(\"%s\"))", type.getCanonicalName(), name); + } + } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CompiledExpression.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CompiledExpression.java index 467836b5dc5d..e31810ad2403 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CompiledExpression.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CompiledExpression.java @@ -28,7 +28,6 @@ import java.util.List; import java.util.Objects; import java.util.function.Supplier; -import org.apache.commons.lang3.ArrayUtils; import org.codehaus.commons.compiler.IExpressionEvaluator; @Immutable @@ -37,7 +36,6 @@ public class CompiledExpression implements ExpressionEvaluator { @EffectivelyImmutable private final IExpressionEvaluator expressionEvaluator; private final SqlType expressionType; - private final ThreadLocal threadLocalParameters; private final Expression expression; private final CodeGenSpec spec; @@ -51,7 +49,6 @@ public CompiledExpression( this.expressionType = Objects.requireNonNull(expressionType, "expressionType"); this.expression = Objects.requireNonNull(expression, "expression"); this.spec = Objects.requireNonNull(spec, "spec"); - this.threadLocalParameters = ThreadLocal.withInitial(() -> new Object[spec.arguments().size()]); } public List arguments() { @@ -85,8 +82,12 @@ public Object evaluate( final Supplier errorMsg ) { try { - return expressionEvaluator.evaluate( - ArrayUtils.addAll(getParameters(row), defaultValue, logger, row)); + return expressionEvaluator.evaluate(new Object[]{ + spec.resolveArguments(row), + defaultValue, + logger, + row + }); } catch (final Exception e) { final Throwable cause = e instanceof InvocationTargetException ? e.getCause() @@ -96,10 +97,4 @@ public Object evaluate( return defaultValue; } } - - private Object[] getParameters(final GenericRow row) { - final Object[] parameters = threadLocalParameters.get(); - spec.resolve(row, parameters); - return parameters; - } } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index 7bcb1d58f334..a14a6a0c1fe3 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -89,6 +89,7 @@ import io.confluent.ksql.function.types.ArrayType; import io.confluent.ksql.function.types.ParamType; import io.confluent.ksql.function.types.ParamTypes; +import io.confluent.ksql.function.udf.Kudf; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.Operator; @@ -124,6 +125,7 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.StringEscapeUtils; import org.apache.commons.lang3.StringUtils; +import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.SchemaBuilder; import org.apache.kafka.connect.data.Struct; @@ -465,7 +467,9 @@ public Pair visitUnqualifiedColumnReference( .orElseThrow(() -> new KsqlException("Field not found: " + node.getColumnName())); - return new Pair<>(colRefToCodeName.apply(fieldName), schemaColumn.type()); + final String codeName = colRefToCodeName.apply(fieldName); + final String paramAccessor = CodeGenUtil.argumentAccessor(codeName, schemaColumn.type()); + return new Pair<>(paramAccessor, schemaColumn.type()); } @Override @@ -515,6 +519,7 @@ public Pair visitFunctionCall( ) { final FunctionName functionName = node.getName(); final String instanceName = funNameToCodeName.apply(functionName); + final String functionAccessor = CodeGenUtil.argumentAccessor(instanceName, Kudf.class); final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName()); final FunctionTypeInfo argumentsAndContext = FunctionArgumentsUtil .getFunctionTypeInfo( @@ -561,7 +566,7 @@ public Pair visitFunctionCall( } final String argumentsString = joiner.toString(); - final String codeString = "((" + javaReturnType + ") " + instanceName + final String codeString = "((" + javaReturnType + ") " + functionAccessor + ".evaluate(" + argumentsString + "))"; return new Pair<>(codeString, returnType); } @@ -1165,7 +1170,10 @@ public Pair visitStructExpression( final Context context ) { final String schemaName = structToCodeName.apply(node); - final StringBuilder struct = new StringBuilder("new Struct(").append(schemaName).append(")"); + final String schemaAccessor = CodeGenUtil.argumentAccessor(schemaName, Schema.class); + final StringBuilder struct = new StringBuilder("new Struct(") + .append(schemaAccessor) + .append(")"); for (final Field field : node.getFields()) { struct.append(".put(") .append('"') diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/CodeGenTestUtil.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/CodeGenTestUtil.java index 6e7218152606..a44a09738411 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/CodeGenTestUtil.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/CodeGenTestUtil.java @@ -2,13 +2,9 @@ import static java.util.Objects.requireNonNull; -import com.google.common.collect.ImmutableList; -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.logging.processing.ProcessingLogger; import java.lang.reflect.InvocationTargetException; import java.util.Collections; -import java.util.List; -import org.apache.commons.lang3.ArrayUtils; +import java.util.Map; import org.codehaus.commons.compiler.IExpressionEvaluator; public final class CodeGenTestUtil { @@ -22,77 +18,27 @@ public static Object cookAndEval( return cookAndEval( javaCode, resultType, - ImmutableList.of(), - ImmutableList.of(), - ImmutableList.of() + Collections.emptyMap() ); } public static Object cookAndEval( final String javaCode, final Class resultType, - final String argName, - final Class argType, - final Object arg + final Map args ) { - return cookAndEval( - javaCode, - resultType, - ImmutableList.of(argName), - ImmutableList.of(argType), - Collections.singletonList(arg) - ); - } - - public static Object cookAndEval( - final String javaCode, - final Class resultType, - final List argNames, - final List> argTypes, - final List args - ) { - final Evaluator evaluator = CodeGenTestUtil.cookCode(javaCode, resultType, argNames, argTypes); + final Evaluator evaluator = CodeGenTestUtil.cookCode(javaCode, resultType); return evaluator.evaluate(args); } public static Evaluator cookCode( final String javaCode, final Class resultType - ) { - return cookCode( - javaCode, - resultType, - ImmutableList.of(), - ImmutableList.of() - ); - } - - public static Evaluator cookCode( - final String javaCode, - final Class resultType, - final String argName, - final Class argType - ) { - return cookCode( - javaCode, - resultType, - ImmutableList.of(argName), - ImmutableList.of(argType) - ); - } - - public static Evaluator cookCode( - final String javaCode, - final Class resultType, - final List argNames, - final List> argTypes ) { try { final IExpressionEvaluator ee = CodeGenRunner.cook( javaCode, - resultType, - argNames.toArray(new String[0]), - argTypes.toArray(new Class[0]) + resultType ); return new Evaluator(ee, javaCode); @@ -116,11 +62,15 @@ public Evaluator(final IExpressionEvaluator ee, final String javaCode) { this.javaCode = requireNonNull(javaCode, "javaCode"); } - public Object evaluate(final Object arg) { - return evaluate(Collections.singletonList(arg)); + public Object evaluate() { + return evaluate(Collections.emptyMap()); + } + + public Object evaluate(final String argName, final Object argValue) { + return evaluate(Collections.singletonMap(argName, argValue)); } - public Object evaluate(final List args) { + public Object evaluate(final Map args) { try { return rawEvaluate(args); } catch (final Exception e) { @@ -133,13 +83,13 @@ public Object evaluate(final List args) { } } - public Object rawEvaluate(final Object arg) throws Exception { - return rawEvaluate(Collections.singletonList(arg)); + public Object rawEvaluate(final String argName, final Object argValue) throws Exception { + return rawEvaluate(Collections.singletonMap(argName, argValue)); } - public Object rawEvaluate(final List args) throws Exception { + public Object rawEvaluate(final Map args) throws Exception { try { - return ee.evaluate(ArrayUtils.addAll(args == null ? new Object[]{null} : args.toArray(), null, null, null)); + return ee.evaluate(new Object[]{args, null, null, null}); } catch (final InvocationTargetException e) { throw e.getTargetException() instanceof Exception ? (Exception) e.getTargetException() diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/CompiledExpressionTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/CompiledExpressionTest.java index c402ee62cfd7..f2525eadf6c9 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/CompiledExpressionTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/CompiledExpressionTest.java @@ -15,13 +15,14 @@ import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.function.udf.Kudf; import io.confluent.ksql.logging.processing.ProcessingLogger; -import io.confluent.ksql.logging.processing.ProcessingLogger.ErrorMessage; import io.confluent.ksql.logging.processing.RecordProcessingError; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import java.lang.reflect.InvocationTargetException; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; @@ -90,7 +91,12 @@ public void shouldEvaluateExpressionWithValueColumnSpecs() throws Exception { // Then: assertThat(result, equalTo(RETURN_VALUE)); - verify(expressionEvaluator).evaluate(new Object[]{123, 456, DEFAULT_VAL, processingLogger, genericRow(123, 456)}); + + final Map arguments = new HashMap<>(); + arguments.put("var0", 123); + arguments.put("var1", 456); + + verify(expressionEvaluator).evaluate(new Object[]{arguments, DEFAULT_VAL, processingLogger, genericRow(123, 456)}); } @Test @@ -119,7 +125,12 @@ public void shouldEvaluateExpressionWithUdfsSpecs() throws Exception { // Then: assertThat(result, equalTo(RETURN_VALUE)); - verify(expressionEvaluator).evaluate(new Object[]{udf, 123, DEFAULT_VAL, processingLogger, genericRow(123)}); + + final Map arguments = new HashMap<>(); + arguments.put("var1", 123); + arguments.put("foo_0", udf); + + verify(expressionEvaluator).evaluate(new Object[]{arguments, DEFAULT_VAL, processingLogger, genericRow(123)}); } @Test @@ -139,7 +150,17 @@ public void shouldPerformThreadSafeParameterEvaluation() throws Exception { final CountDownLatch threadLatch = new CountDownLatch(1); final CountDownLatch mainLatch = new CountDownLatch(1); - when(expressionEvaluator.evaluate(new Object[]{123, 456, DEFAULT_VAL, processingLogger, genericRow(123, 456)})) + final Map arguments1 = new HashMap() {{ + put("var0", 123); + put("var1", 456); + }}; + + final Map arguments2 = new HashMap() {{ + put("var0", 100); + put("var1", 200); + }}; + + when(expressionEvaluator.evaluate(new Object[]{arguments1, DEFAULT_VAL, processingLogger, genericRow(123, 456)})) .thenAnswer( invocation -> { threadLatch.countDown(); @@ -173,9 +194,9 @@ public void shouldPerformThreadSafeParameterEvaluation() throws Exception { // Then: thread.join(); verify(expressionEvaluator, times(1)) - .evaluate(new Object[]{123, 456, DEFAULT_VAL, processingLogger, genericRow(123, 456)}); + .evaluate(new Object[]{arguments1, DEFAULT_VAL, processingLogger, genericRow(123, 456)}); verify(expressionEvaluator, times(1)) - .evaluate(new Object[]{100, 200, DEFAULT_VAL, processingLogger, genericRow(100, 200)}); + .evaluate(new Object[]{arguments2, DEFAULT_VAL, processingLogger, genericRow(100, 200)}); } @Test diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java index 117fea238f27..f3b5d03fa1e6 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java @@ -84,6 +84,7 @@ import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.Operator; +import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.util.KsqlConfig; @@ -91,10 +92,15 @@ import java.math.BigDecimal; import java.sql.Date; import java.sql.Time; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -137,7 +143,7 @@ public void shouldProcessBasicJavaMath() { final String javaExpression = sqlToJavaVisitor.process(expression); // Then: - assertThat(javaExpression, equalTo("(COL0 + COL3)")); + assertThat(javaExpression, equalTo("(((java.lang.Long) arguments.get(\"COL0\")) + ((java.lang.Double) arguments.get(\"COL3\")))")); } @Test @@ -151,7 +157,7 @@ public void shouldProcessArrayExpressionCorrectly() { // Then: assertThat( javaExpression, - equalTo("((Double) (COL4 == null ? null : (ArrayAccess.arrayAccess((java.util.List) COL4, ((int) 0)))))") + equalTo("((Double) (((java.util.List) arguments.get(\"COL4\")) == null ? null : (ArrayAccess.arrayAccess((java.util.List) ((java.util.List) arguments.get(\"COL4\")), ((int) 0)))))") ); } @@ -164,7 +170,7 @@ public void shouldProcessMapExpressionCorrectly() { final String javaExpression = sqlToJavaVisitor.process(expression); // Then: - assertThat(javaExpression, equalTo("((Double) (COL5 == null ? null : ((java.util.Map)COL5).get(\"key1\")))")); + assertThat(javaExpression, equalTo("((Double) (((java.util.Map) arguments.get(\"COL5\")) == null ? null : ((java.util.Map)((java.util.Map) arguments.get(\"COL5\"))).get(\"key1\")))")); } @Test @@ -183,9 +189,9 @@ public void shouldProcessCreateArrayExpressionCorrectly() { // Then: assertThat( java, - equalTo("((List)new ArrayBuilder(2)" - + ".add( (new Supplier() {@Override public Object get() { try { return ((Double) (COL5 == null ? null : ((java.util.Map)COL5).get(\"key1\"))); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing array item\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get())" - + ".add( (new Supplier() {@Override public Object get() { try { return 1E0; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing array item\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()).build())")); + equalTo("((List)new ArrayBuilder(2)" + + ".add( (new Supplier() {@Override public Object get() { try { return ((Double) (((java.util.Map) arguments.get(\"COL5\")) == null ? null : ((java.util.Map)((java.util.Map) arguments.get(\"COL5\"))).get(\"key1\"))); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing array item\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get())" + + ".add( (new Supplier() {@Override public Object get() { try { return 1E0; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing array item\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()).build())")); } @Test @@ -204,9 +210,7 @@ public void shouldProcessCreateMapExpressionCorrectly() { String java = sqlToJavaVisitor.process(expression); // Then: - assertThat(java, equalTo("((Map)new MapBuilder(2)" - + ".put( (new Supplier() {@Override public Object get() { try { return \"foo\"; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing map key\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get(), (new Supplier() {@Override public Object get() { try { return ((Double) (COL5 == null ? null : ((java.util.Map)COL5).get(\"key1\"))); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing map value\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get())" - + ".put( (new Supplier() {@Override public Object get() { try { return \"bar\"; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing map key\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get(), (new Supplier() {@Override public Object get() { try { return 1E0; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing map value\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()).build())")); + assertThat(java, equalTo("((Map)new MapBuilder(2).put( (new Supplier() {@Override public Object get() { try { return \"foo\"; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing map key\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get(), (new Supplier() {@Override public Object get() { try { return ((Double) (((java.util.Map) arguments.get(\"COL5\")) == null ? null : ((java.util.Map)((java.util.Map) arguments.get(\"COL5\"))).get(\"key1\"))); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing map value\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()).put( (new Supplier() {@Override public Object get() { try { return \"bar\"; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing map key\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get(), (new Supplier() {@Override public Object get() { try { return 1E0; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing map value\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()).build())")); } @Test @@ -225,9 +229,7 @@ public void shouldProcessStructExpressionCorrectly() { // Then: assertThat( javaExpression, - equalTo("((Struct)new Struct(schema0)" - + ".put(\"col1\", (new Supplier() {@Override public Object get() { try { return \"foo\"; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get())" - + ".put(\"col2\", (new Supplier() {@Override public Object get() { try { return ((Double) (COL5 == null ? null : ((java.util.Map)COL5).get(\"key1\"))); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()))")); + equalTo("((Struct)new Struct(((org.apache.kafka.connect.data.Schema) arguments.get(\"schema0\"))).put(\"col1\", (new Supplier() {@Override public Object get() { try { return \"foo\"; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()).put(\"col2\", (new Supplier() {@Override public Object get() { try { return ((Double) (((java.util.Map) arguments.get(\"COL5\")) == null ? null : ((java.util.Map)((java.util.Map) arguments.get(\"COL5\"))).get(\"key1\"))); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()))")); } @Test @@ -251,9 +253,7 @@ public void shouldProcessStructExpressionWithDereferencesCorrectly() { // Then: assertThat( javaExpression, - equalTo("((Double)(((Struct)new Struct(schema0).put(\"col1\", (new Supplier() " - + "{@Override public Object get() { try { return \"foo\"; } catch (Exception e) { " - + "logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()).put(\"col2\", (new Supplier() {@Override public Object get() { try { return ((Double) (COL5 == null ? null : ((java.util.Map)COL5).get(\"key1\"))); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get())) == null ? null : ((Struct)new Struct(schema0).put(\"col1\", (new Supplier() {@Override public Object get() { try { return \"foo\"; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()).put(\"col2\", (new Supplier() {@Override public Object get() { try { return ((Double) (COL5 == null ? null : ((java.util.Map)COL5).get(\"key1\"))); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get())).get(\"col2\")))")); + equalTo("((Double)(((Struct)new Struct(((org.apache.kafka.connect.data.Schema) arguments.get(\"schema0\"))).put(\"col1\", (new Supplier() {@Override public Object get() { try { return \"foo\"; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()).put(\"col2\", (new Supplier() {@Override public Object get() { try { return ((Double) (((java.util.Map) arguments.get(\"COL5\")) == null ? null : ((java.util.Map)((java.util.Map) arguments.get(\"COL5\"))).get(\"key1\"))); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get())) == null ? null : ((Struct)new Struct(((org.apache.kafka.connect.data.Schema) arguments.get(\"schema0\"))).put(\"col1\", (new Supplier() {@Override public Object get() { try { return \"foo\"; } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()).put(\"col2\", (new Supplier() {@Override public Object get() { try { return ((Double) (((java.util.Map) arguments.get(\"COL5\")) == null ? null : ((java.util.Map)((java.util.Map) arguments.get(\"COL5\"))).get(\"key1\"))); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing struct field\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get())).get(\"col2\")))")); } @Test @@ -292,21 +292,7 @@ public void shouldPostfixFunctionInstancesWithUniqueId() { final String javaExpression = sqlToJavaVisitor.process(expression); // Then: - final String expected = "((String) CONCAT_0.evaluate( (new Supplier() " - + "{@Override public Object get() { try { return ((String) " - + "SUBSTRING_1.evaluate(COL1, 1, 3)); } catch (Exception e) { " - + "logger.error(RecordProcessingError.recordProcessingError( " - + "\"Error processing SUBSTRING\", e instanceof InvocationTargetException? " - + "e.getCause() : e, row)); return defaultValue; }}}).get(), (new Supplier() " - + "{@Override public Object get() { try { return ((String) CONCAT_2.evaluate(\"-\", " - + "(new Supplier() {@Override public Object get() { try { return ((String) " - + "SUBSTRING_3.evaluate(COL1, 4, 5)); } catch (Exception e) { " - + "logger.error(RecordProcessingError.recordProcessingError( " - + "\"Error processing SUBSTRING\", e instanceof InvocationTargetException? " - + "e.getCause() : e, row)); return defaultValue; }}}).get())); } catch (Exception e) " - + "{ logger.error(RecordProcessingError.recordProcessingError( " - + "\"Error processing CONCAT\", e instanceof InvocationTargetException? " - + "e.getCause() : e, row)); return defaultValue; }}}).get()))"; + final String expected = "((String) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"CONCAT_0\")).evaluate( (new Supplier() {@Override public Object get() { try { return ((String) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"SUBSTRING_1\")).evaluate(((java.lang.String) arguments.get(\"COL1\")), 1, 3)); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing SUBSTRING\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get(), (new Supplier() {@Override public Object get() { try { return ((String) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"CONCAT_2\")).evaluate(\"-\", (new Supplier() {@Override public Object get() { try { return ((String) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"SUBSTRING_3\")).evaluate(((java.lang.String) arguments.get(\"COL1\")), 4, 5)); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing SUBSTRING\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get())); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing CONCAT\", e instanceof InvocationTargetException? e.getCause() : e, row)); return defaultValue; }}}).get()))"; assertThat(javaExpression, is(expected)); } @@ -337,7 +323,7 @@ public void shouldImplicitlyCastFunctionCallParameters() { new io.confluent.ksql.execution.expression.tree.Type(SqlTypes.BIGINT))); assertThat(javaExpression, is( - "((String) FOO_0.evaluate(" +doubleCast + ", " + longCast + "))" + "((String) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"FOO_0\")).evaluate(" +doubleCast + ", " + longCast + "))" )); } @@ -371,7 +357,7 @@ public void shouldImplicitlyCastFunctionCallParametersVariadic() { new io.confluent.ksql.execution.expression.tree.Type(SqlTypes.BIGINT))); assertThat(javaExpression, is( - "((String) FOO_0.evaluate(" +doubleCast + ", " + longCast + ", " + longCast + "))" + "((String) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"FOO_0\")).evaluate(" +doubleCast + ", " + longCast + ", " + longCast + "))" )); } @@ -394,7 +380,7 @@ public void shouldHandleFunctionCallsWithGenerics() { ); // Then: - assertThat(javaExpression, is("((String) FOO_0.evaluate(1, 1))")); + assertThat(javaExpression, is("((String) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"FOO_0\")).evaluate(1, 1))")); } @Test @@ -436,7 +422,7 @@ public void shouldGenerateCorrectCodeForComparisonWithNegativeNumbers() { // Then: assertThat( javaExpression, equalTo( - "((((Object)(COL3)) == null || ((Object)(-1E1)) == null) ? false : (COL3 > -1E1))")); + "((((Object)(((java.lang.Double) arguments.get(\"COL3\")))) == null || ((Object)(-1E1)) == null) ? false : (((java.lang.Double) arguments.get(\"COL3\")) > -1E1))")); } @Test @@ -448,7 +434,7 @@ public void shouldGenerateCorrectCodeForLikePattern() { final String javaExpression = sqlToJavaVisitor.process(expression); // Then: - assertThat(javaExpression, equalTo("LikeEvaluator.matches(COL1, \"%foo\")")); + assertThat(javaExpression, equalTo("LikeEvaluator.matches(((java.lang.String) arguments.get(\"COL1\")), \"%foo\")")); } @Test @@ -460,7 +446,7 @@ public void shouldGenerateCorrectCodeForLikePatternWithEscape() { final String javaExpression = sqlToJavaVisitor.process(expression); // Then: - assertThat(javaExpression, equalTo("LikeEvaluator.matches(COL1, \"%foo\", '!')")); + assertThat(javaExpression, equalTo("LikeEvaluator.matches(((java.lang.String) arguments.get(\"COL1\")), \"%foo\", '!')")); } @Test @@ -472,7 +458,7 @@ public void shouldGenerateCorrectCodeForLikePatternWithColRef() { final String javaExpression = sqlToJavaVisitor.process(expression); // Then: - assertThat(javaExpression, equalTo("LikeEvaluator.matches(COL1, COL1)")); + assertThat(javaExpression, equalTo("LikeEvaluator.matches(((java.lang.String) arguments.get(\"COL1\")), ((java.lang.String) arguments.get(\"COL1\")))")); } @Test @@ -500,7 +486,7 @@ ComparisonExpression.Type.LESS_THAN, COL7, new IntegerLiteral(100)), // ThenL assertThat( javaExpression, equalTo( - "((java.lang.String)SearchedCaseFunction.searchedCaseFunction(ImmutableList.copyOf(Arrays.asList( SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(10)) == null) ? false : (COL7 < 10)); }}, new Supplier() { @Override public java.lang.String get() { return \"small\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(100)) == null) ? false : (COL7 < 100)); }}, new Supplier() { @Override public java.lang.String get() { return \"medium\"; }}))), new Supplier() { @Override public java.lang.String get() { return \"large\"; }}))")); + "((java.lang.String)SearchedCaseFunction.searchedCaseFunction(ImmutableList.copyOf(Arrays.asList( SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(10)) == null) ? false : (((java.lang.Integer) arguments.get(\"COL7\")) < 10)); }}, new Supplier() { @Override public java.lang.String get() { return \"small\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(100)) == null) ? false : (((java.lang.Integer) arguments.get(\"COL7\")) < 100)); }}, new Supplier() { @Override public java.lang.String get() { return \"medium\"; }}))), new Supplier() { @Override public java.lang.String get() { return \"large\"; }}))")); } @Test @@ -530,7 +516,7 @@ ComparisonExpression.Type.EQUAL, COL7, new IntegerLiteral(n)), // ThenL assertThat( javaExpression, equalTo( - "((java.lang.String)SearchedCaseFunction.searchedCaseFunction(ImmutableList.copyOf(Arrays.asList( SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(0)) == null) ? false : ((COL7 <= 0) && (COL7 >= 0))); }}, new Supplier() { @Override public java.lang.String get() { return \"zero\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(1)) == null) ? false : ((COL7 <= 1) && (COL7 >= 1))); }}, new Supplier() { @Override public java.lang.String get() { return \"one\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(2)) == null) ? false : ((COL7 <= 2) && (COL7 >= 2))); }}, new Supplier() { @Override public java.lang.String get() { return \"two\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(3)) == null) ? false : ((COL7 <= 3) && (COL7 >= 3))); }}, new Supplier() { @Override public java.lang.String get() { return \"three\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(4)) == null) ? false : ((COL7 <= 4) && (COL7 >= 4))); }}, new Supplier() { @Override public java.lang.String get() { return \"four\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(5)) == null) ? false : ((COL7 <= 5) && (COL7 >= 5))); }}, new Supplier() { @Override public java.lang.String get() { return \"five\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(6)) == null) ? false : ((COL7 <= 6) && (COL7 >= 6))); }}, new Supplier() { @Override public java.lang.String get() { return \"six\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(7)) == null) ? false : ((COL7 <= 7) && (COL7 >= 7))); }}, new Supplier() { @Override public java.lang.String get() { return \"seven\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(8)) == null) ? false : ((COL7 <= 8) && (COL7 >= 8))); }}, new Supplier() { @Override public java.lang.String get() { return \"eight\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(9)) == null) ? false : ((COL7 <= 9) && (COL7 >= 9))); }}, new Supplier() { @Override public java.lang.String get() { return \"nine\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(10)) == null) ? false : ((COL7 <= 10) && (COL7 >= 10))); }}, new Supplier() { @Override public java.lang.String get() { return \"ten\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(11)) == null) ? false : ((COL7 <= 11) && (COL7 >= 11))); }}, new Supplier() { @Override public java.lang.String get() { return \"eleven\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(12)) == null) ? false : ((COL7 <= 12) && (COL7 >= 12))); }}, new Supplier() { @Override public java.lang.String get() { return \"twelve\"; }}))), new Supplier() { @Override public java.lang.String get() { return null; }}))")); + "((java.lang.String)SearchedCaseFunction.searchedCaseFunction(ImmutableList.copyOf(Arrays.asList( SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(0)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 0) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 0))); }}, new Supplier() { @Override public java.lang.String get() { return \"zero\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(1)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 1) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 1))); }}, new Supplier() { @Override public java.lang.String get() { return \"one\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(2)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 2) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 2))); }}, new Supplier() { @Override public java.lang.String get() { return \"two\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(3)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 3) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 3))); }}, new Supplier() { @Override public java.lang.String get() { return \"three\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(4)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 4) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 4))); }}, new Supplier() { @Override public java.lang.String get() { return \"four\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(5)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 5) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 5))); }}, new Supplier() { @Override public java.lang.String get() { return \"five\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(6)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 6) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 6))); }}, new Supplier() { @Override public java.lang.String get() { return \"six\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(7)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 7) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 7))); }}, new Supplier() { @Override public java.lang.String get() { return \"seven\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(8)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 8) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 8))); }}, new Supplier() { @Override public java.lang.String get() { return \"eight\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(9)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 9) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 9))); }}, new Supplier() { @Override public java.lang.String get() { return \"nine\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(10)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 10) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 10))); }}, new Supplier() { @Override public java.lang.String get() { return \"ten\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(11)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 11) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 11))); }}, new Supplier() { @Override public java.lang.String get() { return \"eleven\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(12)) == null) ? false : ((((java.lang.Integer) arguments.get(\"COL7\")) <= 12) && (((java.lang.Integer) arguments.get(\"COL7\")) >= 12))); }}, new Supplier() { @Override public java.lang.String get() { return \"twelve\"; }}))), new Supplier() { @Override public java.lang.String get() { return null; }}))")); } @Test @@ -558,7 +544,7 @@ ComparisonExpression.Type.LESS_THAN, COL7, new IntegerLiteral(100)), // ThenL assertThat( javaExpression, equalTo( - "((java.lang.String)SearchedCaseFunction.searchedCaseFunction(ImmutableList.copyOf(Arrays.asList( SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(10)) == null) ? false : (COL7 < 10)); }}, new Supplier() { @Override public java.lang.String get() { return \"small\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(COL7)) == null || ((Object)(100)) == null) ? false : (COL7 < 100)); }}, new Supplier() { @Override public java.lang.String get() { return \"medium\"; }}))), new Supplier() { @Override public java.lang.String get() { return null; }}))")); + "((java.lang.String)SearchedCaseFunction.searchedCaseFunction(ImmutableList.copyOf(Arrays.asList( SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(10)) == null) ? false : (((java.lang.Integer) arguments.get(\"COL7\")) < 10)); }}, new Supplier() { @Override public java.lang.String get() { return \"small\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(((java.lang.Integer) arguments.get(\"COL7\")))) == null || ((Object)(100)) == null) ? false : (((java.lang.Integer) arguments.get(\"COL7\")) < 100)); }}, new Supplier() { @Override public java.lang.String get() { return \"medium\"; }}))), new Supplier() { @Override public java.lang.String get() { return null; }}))")); } @Test @@ -576,7 +562,7 @@ public void shouldGenerateCorrectCodeForDecimalAdd() { // Then: assertThat( java, - is("(COL8.add(COL8, new MathContext(3, RoundingMode.UNNECESSARY)).setScale(1))") + is("(((java.math.BigDecimal) arguments.get(\"COL8\")).add(((java.math.BigDecimal) arguments.get(\"COL8\")), new MathContext(3, RoundingMode.UNNECESSARY)).setScale(1))") ); } @@ -593,7 +579,7 @@ public void shouldGenerateCastLongToDecimalInBinaryExpression() { final String java = sqlToJavaVisitor.process(binExp); // Then: - assertThat(java, containsString("DecimalUtil.cast(COL0, 19, 0)")); + assertThat(java, containsString("DecimalUtil.cast(((java.lang.Long) arguments.get(\"COL0\")), 19, 0)")); } @Test @@ -610,7 +596,7 @@ public void shouldGenerateCastDecimalToDoubleInBinaryExpression() { // Then: final String doubleCast = CastEvaluator.generateCode( - "COL8", SqlTypes.decimal(2, 1), SqlTypes.DOUBLE, ksqlConfig); + "((java.math.BigDecimal) arguments.get(\"COL8\"))", SqlTypes.decimal(2, 1), SqlTypes.DOUBLE, ksqlConfig); assertThat(java, containsString(doubleCast)); } @@ -625,7 +611,7 @@ public void shouldGenerateCastExpressionsWhichAreComparable() { // Then: final Evaluator evaluator = CodeGenTestUtil.cookCode(java, Boolean.class); - evaluator.evaluate(Collections.emptyList()); + evaluator.evaluate(); } @Test @@ -643,7 +629,7 @@ public void shouldGenerateCorrectCodeForDecimalSubtract() { // Then: assertThat( java, - is("(COL8.subtract(COL8, new MathContext(3, RoundingMode.UNNECESSARY)).setScale(1))") + is("(((java.math.BigDecimal) arguments.get(\"COL8\")).subtract(((java.math.BigDecimal) arguments.get(\"COL8\")), new MathContext(3, RoundingMode.UNNECESSARY)).setScale(1))") ); } @@ -662,7 +648,7 @@ public void shouldGenerateCorrectCodeForDecimalMultiply() { // Then: assertThat( java, - is("(COL8.multiply(COL8, new MathContext(5, RoundingMode.UNNECESSARY)).setScale(2))") + is("(((java.math.BigDecimal) arguments.get(\"COL8\")).multiply(((java.math.BigDecimal) arguments.get(\"COL8\")), new MathContext(5, RoundingMode.UNNECESSARY)).setScale(2))") ); } @@ -681,7 +667,7 @@ public void shouldGenerateCorrectCodeForDecimalDivide() { // Then: assertThat( java, - is("(COL8.divide(COL8, new MathContext(8, RoundingMode.UNNECESSARY)).setScale(6))") + is("(((java.math.BigDecimal) arguments.get(\"COL8\")).divide(((java.math.BigDecimal) arguments.get(\"COL8\")), new MathContext(8, RoundingMode.UNNECESSARY)).setScale(6))") ); } @@ -700,7 +686,7 @@ public void shouldGenerateCorrectCodeForDecimalMod() { // Then: assertThat( java, - is("(COL8.remainder(COL8, new MathContext(2, RoundingMode.UNNECESSARY)).setScale(1))") + is("(((java.math.BigDecimal) arguments.get(\"COL8\")).remainder(((java.math.BigDecimal) arguments.get(\"COL8\")), new MathContext(2, RoundingMode.UNNECESSARY)).setScale(1))") ); } @@ -717,7 +703,7 @@ public void shouldGenerateCorrectCodeForDecimalDecimalEQ() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL8.compareTo(COL9) == 0))")); + assertThat(java, containsString("(((java.math.BigDecimal) arguments.get(\"COL8\")).compareTo(((java.math.BigDecimal) arguments.get(\"COL9\"))) == 0)")); } @Test @@ -733,7 +719,7 @@ public void shouldGenerateCorrectCodeForDecimalDecimalGT() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL8.compareTo(COL9) > 0))")); + assertThat(java, containsString("(((java.math.BigDecimal) arguments.get(\"COL8\")).compareTo(((java.math.BigDecimal) arguments.get(\"COL9\"))) > 0)")); } @Test @@ -749,7 +735,7 @@ public void shouldGenerateCorrectCodeForDecimalDecimalGEQ() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL8.compareTo(COL9) >= 0))")); + assertThat(java, containsString("(((java.math.BigDecimal) arguments.get(\"COL8\")).compareTo(((java.math.BigDecimal) arguments.get(\"COL9\"))) >= 0)")); } @Test @@ -765,7 +751,7 @@ public void shouldGenerateCorrectCodeForDecimalDecimalLT() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL8.compareTo(COL9) < 0))")); + assertThat(java, containsString("(((java.math.BigDecimal) arguments.get(\"COL8\")).compareTo(((java.math.BigDecimal) arguments.get(\"COL9\"))) < 0)")); } @Test @@ -781,7 +767,7 @@ public void shouldGenerateCorrectCodeForDecimalDecimalLEQ() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL8.compareTo(COL9) <= 0))")); + assertThat(java, containsString("(((java.math.BigDecimal) arguments.get(\"COL8\")).compareTo(((java.math.BigDecimal) arguments.get(\"COL9\"))) <= 0)")); } @Test @@ -797,7 +783,7 @@ public void shouldGenerateCorrectCodeForDecimalDecimalIsDistinct() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL8.compareTo(COL9) != 0))")); + assertThat(java, containsString("(((java.math.BigDecimal) arguments.get(\"COL8\")).compareTo(((java.math.BigDecimal) arguments.get(\"COL9\"))) != 0))")); } @Test @@ -813,7 +799,7 @@ public void shouldGenerateCorrectCodeForDecimalDoubleEQ() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL8.compareTo(BigDecimal.valueOf(COL3)) == 0))")); + assertThat(java, containsString("(((java.math.BigDecimal) arguments.get(\"COL8\")).compareTo(BigDecimal.valueOf(((java.lang.Double) arguments.get(\"COL3\")))) == 0)")); } @Test @@ -829,7 +815,7 @@ public void shouldGenerateCorrectCodeForDoubleDecimalEQ() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(BigDecimal.valueOf(COL3).compareTo(COL8) == 0))")); + assertThat(java, containsString("(BigDecimal.valueOf(((java.lang.Double) arguments.get(\"COL3\"))).compareTo(((java.math.BigDecimal) arguments.get(\"COL8\"))) == 0)")); } @Test @@ -845,7 +831,7 @@ public void shouldGenerateCorrectCodeForDecimalNegation() { final String java = sqlToJavaVisitor.process(binExp); // Then: - assertThat(java, is("(COL8.negate(new MathContext(2, RoundingMode.UNNECESSARY)))")); + assertThat(java, is("(((java.math.BigDecimal) arguments.get(\"COL8\")).negate(new MathContext(2, RoundingMode.UNNECESSARY)))")); } @Test @@ -861,7 +847,7 @@ public void shouldGenerateCorrectCodeForDecimalUnaryPlus() { final String java = sqlToJavaVisitor.process(binExp); // Then: - assertThat(java, is("(COL8.plus(new MathContext(2, RoundingMode.UNNECESSARY)))")); + assertThat(java, is("(((java.math.BigDecimal) arguments.get(\"COL8\")).plus(new MathContext(2, RoundingMode.UNNECESSARY)))")); } @Test @@ -877,7 +863,7 @@ public void shouldGenerateCorrectCodeForTimeTimeLT() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL12.compareTo(COL12) < 0)")); + assertThat(java, containsString("(((java.sql.Time) arguments.get(\"COL12\")).compareTo(((java.sql.Time) arguments.get(\"COL12\"))) < 0)")); } @Test @@ -893,7 +879,7 @@ public void shouldGenerateCorrectCodeForTimeStringEQ() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL12.compareTo(SqlTimeTypes.parseTime(\"01:23:45\")) == 0)")); + assertThat(java, containsString("(((java.sql.Time) arguments.get(\"COL12\")).compareTo(SqlTimeTypes.parseTime(\"01:23:45\")) == 0)")); } @Test @@ -937,7 +923,7 @@ public void shouldGenerateCorrectCodeForDateDateLT() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL13.compareTo(COL13) < 0)")); + assertThat(java, containsString("(((java.sql.Date) arguments.get(\"COL13\")).compareTo(((java.sql.Date) arguments.get(\"COL13\"))) < 0)")); } @Test @@ -953,7 +939,7 @@ public void shouldGenerateCorrectCodeForDateStringEQ() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL13.compareTo(SqlTimeTypes.parseDate(\"2021-06-23\")) == 0)")); + assertThat(java, containsString("(((java.sql.Date) arguments.get(\"COL13\")).compareTo(SqlTimeTypes.parseDate(\"2021-06-23\")) == 0)")); } @Test @@ -969,7 +955,7 @@ public void shouldGenerateCorrectCodeForTimestampTimestampLT() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL10.compareTo(COL10) < 0)")); + assertThat(java, containsString("(((java.sql.Timestamp) arguments.get(\"COL10\")).compareTo(((java.sql.Timestamp) arguments.get(\"COL10\"))) < 0)")); } @Test @@ -985,7 +971,7 @@ public void shouldGenerateCorrectCodeForTimestampStringEQ() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL10.compareTo(SqlTimeTypes.parseTimestamp(\"2020-01-01T00:00:00\")) == 0)")); + assertThat(java, containsString("(((java.sql.Timestamp) arguments.get(\"COL10\")).compareTo(SqlTimeTypes.parseTimestamp(\"2020-01-01T00:00:00\")) == 0)")); } @Test @@ -1001,7 +987,7 @@ public void shouldGenerateCorrectCodeForTimestampStringGEQ() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(SqlTimeTypes.parseTimestamp(\"2020-01-01T00:00:00\").compareTo(COL10) >= 0)")); + assertThat(java, containsString("(SqlTimeTypes.parseTimestamp(\"2020-01-01T00:00:00\").compareTo(((java.sql.Timestamp) arguments.get(\"COL10\"))) >= 0)")); } @Test @@ -1017,7 +1003,7 @@ public void shouldGenerateCorrectCodeForTimestampDateGT() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL10.compareTo(COL13) > 0)")); + assertThat(java, containsString("(((java.sql.Timestamp) arguments.get(\"COL10\")).compareTo(((java.sql.Date) arguments.get(\"COL13\"))) > 0)")); } @Test @@ -1033,7 +1019,7 @@ public void shouldGenerateCorrectCodeForBytesBytesGT() { final String java = sqlToJavaVisitor.process(compExp); // Then: - assertThat(java, containsString("(COL14.compareTo(COL14) > 0)")); + assertThat(java, containsString("(((java.nio.ByteBuffer) arguments.get(\"COL14\")).compareTo(((java.nio.ByteBuffer) arguments.get(\"COL14\"))) > 0)")); } @Test @@ -1099,7 +1085,7 @@ public void shouldGenerateCorrectCodeForInPredicate() { final String java = sqlToJavaVisitor.process(expression); // Then: - assertThat(java, is("InListEvaluator.matches(COL0,1L,2L)")); + assertThat(java, is("InListEvaluator.matches(((java.lang.Long) arguments.get(\"COL0\")),1L,2L)")); } @Test @@ -1131,7 +1117,13 @@ public void shouldGenerateCorrectCodeForLambdaExpression() { // Then assertThat( javaExpression, equalTo( - "((String) TRANSFORM_0.evaluate(COL4, new Function() {\n @Override\n public Object apply(Object arg1) {\n final Double x = (Double) arg1;\n return ((String) ABS_1.evaluate(X));\n }\n}))")); + "((String) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"TRANSFORM_0\")).evaluate(((java.util.List) arguments.get(\"COL4\")), new Function() {\n" + + " @Override\n" + + " public Object apply(Object arg1) {\n" + + " final Double x = (Double) arg1;\n" + + " return ((String) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"ABS_1\")).evaluate(X));\n" + + " }\n" + + "}))")); } @Test @@ -1168,7 +1160,7 @@ public void shouldGenerateCorrectCodeForLambdaExpressionWithTwoArguments() { // Then assertThat( javaExpression, equalTo( - "((String) REDUCE_0.evaluate(COL4, COL3, new BiFunction() {\n" + + "((String) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"REDUCE_0\")).evaluate(((java.util.List) arguments.get(\"COL4\")), ((java.lang.Double) arguments.get(\"COL3\")), new BiFunction() {\n" + " @Override\n" + " public Object apply(Object arg1, Object arg2) {\n" + " final Double X = (Double) arg1;\n" + @@ -1232,7 +1224,7 @@ ComparisonExpression.Type.LESS_THAN, new LambdaVariable("X"), new IntegerLiteral // Then assertThat( - javaExpression, equalTo("((String) function_0.evaluate(COL4, COL1, new BiFunction() {\n" + javaExpression, equalTo("((String) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"function_0\")).evaluate(((java.util.List) arguments.get(\"COL4\")), ((java.lang.String) arguments.get(\"COL1\")), new BiFunction() {\n" + " @Override\n" + " public Object apply(Object arg1, Object arg2) {\n" + " final Double X = (Double) arg1;\n" @@ -1298,41 +1290,33 @@ public void shouldGenerateCorrectCodeForNestedLambdas() { // Then assertThat( javaExpression, equalTo( - "(((Double) nested_0.evaluate(COL4, (new Supplier() " + - "{@Override public java.lang.Double get() { " + - "try { return ((Double)NullSafe.apply(0,new Function() {\n" - + " @Override\n" - + " public Object apply(Object arg1) {\n" - + " final Integer val = (Integer) arg1;\n" - + " return val.doubleValue();\n" - + " }\n" - + "})); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( " + - "\"Error processing DOUBLE\", e instanceof InvocationTargetException? e.getCause() : e, " + - "row)); return (java.lang.Double) defaultValue; }}}).get(), new BiFunction() {\n" - + " @Override\n" - + " public Object apply(Object arg1, Object arg2) {\n" - + " final Double A = (Double) arg1;\n" - + " final Integer B = (Integer) arg2;\n" - + " return (((Double) nested_1.evaluate(COL4, (new Supplier() " + - "{@Override public java.lang.Double get() { try { " + - "return ((Double)NullSafe.apply(0,new Function() {\n" - + " @Override\n" - + " public Object apply(Object arg1) {\n" - + " final Integer val = (Integer) arg1;\n" - + " return val.doubleValue();\n" - + " }\n" - + "})); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( " + - "\"Error processing DOUBLE\", e instanceof InvocationTargetException? e.getCause() : e, " + - "row)); return (java.lang.Double) defaultValue; }}}).get(), new BiFunction() {\n" - + " @Override\n" - + " public Object apply(Object arg1, Object arg2) {\n" - + " final Double Q = (Double) arg1;\n" - + " final Integer V = (Integer) arg2;\n" - + " return (Q + V);\n" - + " }\n" - + "})) + B);\n" - + " }\n" - + "})) + 5)")); + "(((Double) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"nested_0\")).evaluate(((java.util.List) arguments.get(\"COL4\")), (new Supplier() {@Override public java.lang.Double get() { try { return ((Double)NullSafe.apply(0,new Function() {\n" + + " @Override\n" + + " public Object apply(Object arg1) {\n" + + " final Integer val = (Integer) arg1;\n" + + " return val.doubleValue();\n" + + " }\n" + + "})); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing DOUBLE\", e instanceof InvocationTargetException? e.getCause() : e, row)); return (java.lang.Double) defaultValue; }}}).get(), new BiFunction() {\n" + + " @Override\n" + + " public Object apply(Object arg1, Object arg2) {\n" + + " final Double A = (Double) arg1;\n" + + " final Integer B = (Integer) arg2;\n" + + " return (((Double) ((io.confluent.ksql.function.udf.Kudf) arguments.get(\"nested_1\")).evaluate(((java.util.List) arguments.get(\"COL4\")), (new Supplier() {@Override public java.lang.Double get() { try { return ((Double)NullSafe.apply(0,new Function() {\n" + + " @Override\n" + + " public Object apply(Object arg1) {\n" + + " final Integer val = (Integer) arg1;\n" + + " return val.doubleValue();\n" + + " }\n" + + "})); } catch (Exception e) { logger.error(RecordProcessingError.recordProcessingError( \"Error processing DOUBLE\", e instanceof InvocationTargetException? e.getCause() : e, row)); return (java.lang.Double) defaultValue; }}}).get(), new BiFunction() {\n" + + " @Override\n" + + " public Object apply(Object arg1, Object arg2) {\n" + + " final Double Q = (Double) arg1;\n" + + " final Integer V = (Integer) arg2;\n" + + " return (Q + V);\n" + + " }\n" + + "})) + B);\n" + + " }\n" + + "})) + 5)")); } @Test @@ -1356,6 +1340,41 @@ public void shouldProcessTimeLiteral() { assertThat(sqlToJavaVisitor.process(new TimeLiteral(new Time(1000))), is("00:00:01")); } + @Test + public void shouldHandleManyArguments() { + // Given: + final LogicalSchema.Builder schemaBuilder = LogicalSchema.builder(); + for (int i = 0; i < 500; i++) { + schemaBuilder.valueColumn(ColumnName.of("COL" + i), SqlTypes.STRING); + } + final LogicalSchema schema = schemaBuilder.build(); + + final AtomicInteger funCounter = new AtomicInteger(); + final AtomicInteger structCounter = new AtomicInteger(); + final SqlToJavaVisitor sqlToJavaVisitor = new SqlToJavaVisitor( + schema, + functionRegistry, + ref -> ref.text().replace(".", "_"), + name -> name.text() + "_" + funCounter.getAndIncrement(), + struct -> "schema" + structCounter.getAndIncrement(), + ksqlConfig + ); + + final List expressions = new ArrayList<>(); + for (int i = 0; i < 500 ; i++) { + expressions.add(new UnqualifiedColumnReferenceExp(ColumnName.of("COL" + i))); + } + final Expression expression = new CreateArrayExpression(expressions); + + // When: + final String java = sqlToJavaVisitor.process(expression); + + // Then: + final Map arguments = IntStream.range(0, 500).boxed().collect(Collectors.toMap(i -> "COL" + i, String::valueOf)); + final Evaluator evaluator = CodeGenTestUtil.cookCode(java, List.class); + evaluator.evaluate(arguments); + } + private void givenUdf( final String name, final UdfFactory factory, diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/CastEvaluatorTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/CastEvaluatorTest.java index a409d700ad49..237c5278fba9 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/CastEvaluatorTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/CastEvaluatorTest.java @@ -38,6 +38,7 @@ import com.google.common.collect.ImmutableSet; import io.confluent.ksql.execution.codegen.CodeGenTestUtil; import io.confluent.ksql.execution.codegen.CodeGenTestUtil.Evaluator; +import io.confluent.ksql.execution.codegen.CodeGenUtil; import io.confluent.ksql.schema.ksql.DefaultSqlValueCoercerTest; import io.confluent.ksql.schema.ksql.DefaultSqlValueCoercerTest.LaxValueCoercionTest.LaxOnly; import io.confluent.ksql.schema.ksql.SchemaConverters; @@ -189,7 +190,7 @@ public void shouldNotCastPositiveInfinite() { // When: final Exception exception = assertThrows( NumberFormatException.class, - () -> evaluator.rawEvaluate("Infinity") + () -> evaluator.rawEvaluate(INNER_CODE, "Infinity") ); // Then: @@ -203,7 +204,7 @@ public void shouldNotCastNegativeInfinite() { // When: final Exception exception = assertThrows( NumberFormatException.class, - () -> evaluator.rawEvaluate("-Infinity") + () -> evaluator.rawEvaluate(INNER_CODE, "-Infinity") ); // Then: @@ -217,7 +218,7 @@ public void shouldNotCastNaN() { // When: final Exception exception = assertThrows( NumberFormatException.class, - () -> evaluator.rawEvaluate("NaN") + () -> evaluator.rawEvaluate(INNER_CODE, "NaN") ); // Then: @@ -238,7 +239,7 @@ public void shouldNotCastIncorrectlyFormattedString() { // When: final Exception exception = assertThrows( KsqlException.class, - () -> evaluator.rawEvaluate("woof") + () -> evaluator.rawEvaluate(INNER_CODE, "woof") ); // Then: @@ -259,7 +260,7 @@ public void shouldNotCastIncorrectlyFormattedString() { // When: final Exception exception = assertThrows( KsqlException.class, - () -> evaluator.rawEvaluate("woof") + () -> evaluator.rawEvaluate(INNER_CODE, "woof") ); // Then: @@ -272,7 +273,7 @@ public void shouldCastDateToTimestamp() throws Exception { final Evaluator evaluator = cookCode(SqlTypes.DATE, SqlTypes.TIMESTAMP, config); // Then: - assertThat(evaluator.rawEvaluate(new Date(864000000)), is(new Timestamp(864000000))); + assertThat(evaluator.rawEvaluate(INNER_CODE, new Date(864000000)), is(new Timestamp(864000000))); } } @@ -289,7 +290,7 @@ public void shouldNotCastIncorrectlyFormattedString() { // When: final Exception exception = assertThrows( KsqlException.class, - () -> evaluator.rawEvaluate("woof") + () -> evaluator.rawEvaluate(INNER_CODE, "woof") ); // Then: @@ -302,7 +303,7 @@ public void shouldCastTimestampToDate() throws Exception { final Evaluator evaluator = cookCode(SqlTypes.TIMESTAMP, SqlTypes.DATE, config); // Then: - assertThat(evaluator.rawEvaluate(new Timestamp(864033000)), is(new Date(864000000))); + assertThat(evaluator.rawEvaluate(INNER_CODE, new Timestamp(864033000)), is(new Date(864000000))); } @Test @@ -311,7 +312,7 @@ public void shouldCastTimestampToTime() throws Exception { final Evaluator evaluator = cookCode(SqlTypes.TIMESTAMP, SqlTypes.TIME, config); // Then: - assertThat(evaluator.rawEvaluate(new Timestamp(864033000)), is(new Time(33000))); + assertThat(evaluator.rawEvaluate(INNER_CODE, new Timestamp(864033000)), is(new Time(33000))); } } @@ -688,15 +689,13 @@ private static Evaluator cookCode( final SqlType to, final KsqlConfig config ) { - final String javaCode = CastEvaluator.generateCode(INNER_CODE, from, to, config); - - final Class fromJavaType = SchemaConverters.sqlToJavaConverter() - .toJavaType(from); + final String paramAccessor = CodeGenUtil.argumentAccessor(INNER_CODE, from); + final String javaCode = CastEvaluator.generateCode(paramAccessor, from, to, config); final Class toJavaType = SchemaConverters.sqlToJavaConverter() .toJavaType(to); - return CodeGenTestUtil.cookCode(javaCode, toJavaType, INNER_CODE, fromJavaType); + return CodeGenTestUtil.cookCode(javaCode, toJavaType); } private static Object eval( @@ -705,7 +704,9 @@ private static Object eval( final KsqlConfig config, final Object argument ) { - return cookCode(from, to, config).evaluate(argument); + final Map arguments = new HashMap<>(); + arguments.put("val0", argument); + return cookCode(from, to, config).evaluate(arguments); } private static void assertUnsupported( @@ -716,7 +717,10 @@ private static void assertUnsupported( // When: final Exception e = assertThrows( KsqlException.class, - () -> CastEvaluator.generateCode(INNER_CODE, from, to, config) + () -> { + final String paramAccessor = CodeGenUtil.argumentAccessor(INNER_CODE, from); + CastEvaluator.generateCode(paramAccessor, from, to, config); + } ); // Then: diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/NullSafeTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/NullSafeTest.java index 0bff0f78c164..9c79519c3716 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/NullSafeTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/NullSafeTest.java @@ -18,12 +18,12 @@ public void shouldGenerateApply() { // When: final String javaCode = NullSafe - .generateApply("input", mapperCode, Long.class); + .generateApply("arguments.get(\"input\")", mapperCode, Long.class); // Then: - final Evaluator evaluator = CodeGenTestUtil.cookCode(javaCode, Long.class, "input", Long.class); - assertThat(evaluator.evaluate(10L), is(11L)); - assertThat(evaluator.evaluate(null), is(nullValue())); + final Evaluator evaluator = CodeGenTestUtil.cookCode(javaCode, Long.class); + assertThat(evaluator.evaluate("input", 10L), is(11L)); + assertThat(evaluator.evaluate("input", null), is(nullValue())); } @Test @@ -34,11 +34,11 @@ public void shouldGenerateApplyOrDefault() { // When: final String javaCode = NullSafe - .generateApplyOrDefault("input", mapperCode, "99L", Long.class); + .generateApplyOrDefault("arguments.get(\"input\")", mapperCode, "99L", Long.class); // Then: - final Evaluator evaluator = CodeGenTestUtil.cookCode(javaCode, Long.class, "input", Long.class); - assertThat(evaluator.evaluate(10L), is(11L)); - assertThat(evaluator.evaluate(null), is(99L)); + final Evaluator evaluator = CodeGenTestUtil.cookCode(javaCode, Long.class); + assertThat(evaluator.evaluate("input", 10L), is(11L)); + assertThat(evaluator.evaluate("input", null), is(99L)); } } \ No newline at end of file diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java index 42512b1ea689..82d6caa0dfa6 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java @@ -551,10 +551,9 @@ public void shouldEvaluateLambdaArgsToType() { ); // Then: - assertThat(e.getUnloggedMessage(), Matchers.containsString( - "Error processing expression: (A + B). " + - "Unsupported arithmetic types. DOUBLE STRING\n" + - "Statement: (A + B)")); + assertThat(e.getUnloggedMessage(), Matchers.containsString("Error processing expression: (A + B). Unsupported arithmetic types. DOUBLE STRING" + + System.lineSeparator() + + "Statement: (A + B)")); assertThat(e.getMessage(), Matchers.is( "Error processing expression.")); }