diff --git a/docs/Snowflake-batchsource.md b/docs/Snowflake-batchsource.md
index 516bcd0..9b810fe 100644
--- a/docs/Snowflake-batchsource.md
+++ b/docs/Snowflake-batchsource.md
@@ -25,7 +25,9 @@ log in to Snowflake, minus the "snowflakecomputing.com"). E.g. "myaccount.us-cen
 
 **Role:** Role to use (e.g. `ACCOUNTADMIN`).
 
-**Import Query:** Query for data import.
+**Import Query Type:** Method used to retrieve the schema from the source.
+* **Table Name:** The name of the table from which to retrieve the schema.
+* **Import Query:** Query for data import.
 
 ### Credentials
diff --git a/pom.xml b/pom.xml
index 8b43a12..199b4a6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -29,8 +29,8 @@
     true
     UTF-8
-    6.11.0-SNAPSHOT
-    2.13.0-SNAPSHOT
+    6.11.0
+    2.13.0
     1.6
     3.3.6
     3.3.2
diff --git a/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java b/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java
index 911f696..8185ae8 100644
--- a/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java
+++ b/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java
@@ -37,6 +37,7 @@
 import java.io.IOException;
 import java.lang.reflect.Field;
 import java.sql.Connection;
+import java.sql.DatabaseMetaData;
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.ResultSetMetaData;
@@ -74,6 +75,37 @@ public void runSQL(String query) {
     }
   }
 
+  /**
+   * Returns field descriptors for the specified table name.
+   *
+   * @param schemaName the name of the schema containing the table
+   * @param tableName the name of the table whose metadata needs to be retrieved
+   * @return list of field descriptors
+   */
+  public List<SnowflakeFieldDescriptor> getFieldDescriptors(String schemaName, String tableName) {
+    List<SnowflakeFieldDescriptor> fieldDescriptors = new ArrayList<>();
+    try (Connection connection = dataSource.getConnection()) {
+      DatabaseMetaData dbMetaData = connection.getMetaData();
+      try (ResultSet columns = dbMetaData.getColumns(null, schemaName, tableName, null)) {
+        while (columns.next()) {
+          String columnName = columns.getString("COLUMN_NAME");
+          int columnType = columns.getInt("DATA_TYPE");
+          boolean nullable = columns.getInt("NULLABLE") == DatabaseMetaData.columnNullable;
+          fieldDescriptors.add(new SnowflakeFieldDescriptor(columnName, columnType, nullable));
+        }
+      }
+    } catch (SQLException e) {
+      String errorMessage = String.format(
+        "Failed to retrieve table metadata with SQL State %s and error code %s with message: %s.",
+        e.getSQLState(), e.getErrorCode(), e.getMessage()
+      );
+      String errorReason = String.format("Failed to retrieve table metadata with SQL State %s and error " +
+        "code %s. For more details, see %s.", e.getSQLState(), e.getErrorCode(),
+        DocumentUrlUtil.getSupportedDocumentUrl());
+      throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
+    }
+    return fieldDescriptors;
+  }
+
   /**
    * Returns field descriptors for specified import query.
    *
@@ -193,4 +225,13 @@ private static String writeTextToTmpFile(String text) {
       throw new RuntimeException("Cannot write key to temporary file", e);
     }
   }
+
+  /**
+   * Retrieves the schema name from the configuration.
+   *
+   * @return the schema name
+   */
+  public String getSchema() {
+    return config.getSchemaName();
+  }
 }
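For context on the new `getFieldDescriptors` method: `DatabaseMetaData.getColumns` treats its schema and table arguments as patterns, with `null` meaning "match any", and Snowflake stores unquoted identifiers in upper case, so a lower-case `tableName` will usually match no rows. A minimal standalone sketch of the same lookup, assuming a pre-configured `DataSource` (the class and method names here are illustrative, not part of the plugin):

```java
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import javax.sql.DataSource;

final class ColumnMetadataSketch {
  // Prints the same (name, JDBC type, nullability) triples that
  // getFieldDescriptors() maps to SnowflakeFieldDescriptor objects.
  static void printColumns(DataSource dataSource, String schemaName, String tableName) throws SQLException {
    try (Connection connection = dataSource.getConnection();
         ResultSet columns = connection.getMetaData().getColumns(null, schemaName, tableName, null)) {
      while (columns.next()) {
        System.out.printf("%s type=%d nullable=%b%n",
          columns.getString("COLUMN_NAME"),
          columns.getInt("DATA_TYPE"),
          columns.getInt("NULLABLE") == DatabaseMetaData.columnNullable);
      }
    }
  }
}
```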
diff --git a/src/main/java/io/cdap/plugin/snowflake/common/util/QueryUtil.java b/src/main/java/io/cdap/plugin/snowflake/common/util/QueryUtil.java
index 289ddc1..64f3e22 100644
--- a/src/main/java/io/cdap/plugin/snowflake/common/util/QueryUtil.java
+++ b/src/main/java/io/cdap/plugin/snowflake/common/util/QueryUtil.java
@@ -16,6 +16,8 @@
 
 package io.cdap.plugin.snowflake.common.util;
 
+import com.google.common.base.Strings;
+
 /**
  * Transforms import query.
  */
@@ -29,6 +31,9 @@ private QueryUtil() {
   }
 
   public static String removeSemicolon(String importQuery) {
+    if (Strings.isNullOrEmpty(importQuery)) {
+      return null;
+    }
     if (importQuery.endsWith(";")) {
       importQuery = importQuery.substring(0, importQuery.length() - 1);
     }
diff --git a/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java b/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java
index 9302eae..161c132 100644
--- a/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java
+++ b/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java
@@ -27,11 +27,13 @@
 import io.cdap.plugin.snowflake.source.batch.SnowflakeInputFormatProvider;
 import io.cdap.plugin.snowflake.source.batch.SnowflakeSourceAccessor;
 import java.io.IOException;
+import java.sql.SQLException;
 import java.sql.Types;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.stream.Collectors;
+import javax.annotation.Nullable;
 
 /**
  * Resolves schema.
@@ -58,6 +60,13 @@ public class SchemaHelper {
   private SchemaHelper() {
   }
 
+  /**
+   * Retrieves schema for the Snowflake batch source based on the given configuration.
+   *
+   * @param config the configuration for the Snowflake batch source
+   * @param collector the failure collector to capture any schema retrieval errors
+   * @return the resolved schema for the Snowflake source
+   */
   public static Schema getSchema(SnowflakeBatchSourceConfig config, FailureCollector collector) {
     if (!config.canConnect()) {
       return getParsedSchema(config.getSchema());
@@ -65,16 +74,32 @@ public static Schema getSchema(SnowflakeBatchSourceConfig config, FailureCollect
 
     SnowflakeSourceAccessor snowflakeSourceAccessor =
       new SnowflakeSourceAccessor(config, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);
-    return getSchema(snowflakeSourceAccessor, config.getSchema(), collector, config.getImportQuery());
+    return getSchema(
+      snowflakeSourceAccessor,
+      config.getSchema(),
+      collector,
+      config.getTableName(),
+      config.getImportQuery()
+    );
   }
 
+  /**
+   * Retrieves schema for a Snowflake source based on the provided parameters.
+   *
+   * @param snowflakeAccessor the {@link SnowflakeSourceAccessor} used to connect to Snowflake
+   * @param schema a JSON-format schema string
+   * @param collector the {@link FailureCollector} to collect errors if schema retrieval fails
+   * @param tableName the name of the table in Snowflake
+   * @param importQuery the query to fetch data from Snowflake, used when {@code tableName} is not provided
+   * @return the parsed {@link Schema} if successful, or {@code null} if an error occurs
+   */
   public static Schema getSchema(SnowflakeSourceAccessor snowflakeAccessor, String schema,
-                                 FailureCollector collector, String importQuery) {
+                                 FailureCollector collector, String tableName, String importQuery) {
     try {
       if (!Strings.isNullOrEmpty(schema)) {
         return getParsedSchema(schema);
       }
-      return Strings.isNullOrEmpty(importQuery) ? null : getSchema(snowflakeAccessor, importQuery);
+      return getSchema(snowflakeAccessor, tableName, importQuery);
     } catch (SchemaParseException e) {
       collector.addFailure(String.format("Unable to retrieve output schema. Reason: '%s'", e.getMessage()), null)
@@ -84,7 +109,7 @@ public static Schema getSchema(SnowflakeSourceAccessor snowflakeAccessor, String
     }
   }
 
-  private static Schema getParsedSchema(String schema) {
+  public static Schema getParsedSchema(String schema) {
     if (Strings.isNullOrEmpty(schema)) {
       return null;
     }
@@ -95,6 +120,26 @@ private static Schema getParsedSchema(String schema) {
     }
   }
 
+  private static Schema getSchema(SnowflakeAccessor snowflakeAccessor, @Nullable String tableName,
+                                  @Nullable String importQuery) {
+    try {
+      List<SnowflakeFieldDescriptor> result;
+      if (!Strings.isNullOrEmpty(tableName)) {
+        result = snowflakeAccessor.getFieldDescriptors(snowflakeAccessor.getSchema(), tableName);
+      } else if (!Strings.isNullOrEmpty(importQuery)) {
+        result = snowflakeAccessor.describeQuery(importQuery);
+      } else {
+        return null;
+      }
+      List<Schema.Field> fields = result.stream()
+        .map(fieldDescriptor -> Schema.Field.of(fieldDescriptor.getName(), getSchema(fieldDescriptor)))
+        .collect(Collectors.toList());
+      return Schema.recordOf("data", fields);
+    } catch (IOException e) {
+      throw new SchemaParseException(e);
+    }
+  }
+
   public static Schema getSchema(SnowflakeAccessor snowflakeAccessor, String importQuery) {
     try {
       List<SnowflakeFieldDescriptor> result = snowflakeAccessor.describeQuery(importQuery);
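The net effect of the `SchemaHelper` changes is a fixed precedence: an explicit JSON schema wins, then table metadata, then the described import query, and `null` when none is set. Illustrative calls, assuming an `accessor` and `collector` wired up as in the unit tests:

```java
// Assumed to exist: a SnowflakeSourceAccessor `accessor`, a FailureCollector
// `collector`, a JSON schema string `schemaJson`, and a query string `query`.
Schema parsed  = SchemaHelper.getSchema(accessor, schemaJson, collector, "ORDERS", query); // JSON parsed; Snowflake never contacted
Schema byTable = SchemaHelper.getSchema(accessor, null, collector, "ORDERS", query);       // getFieldDescriptors(schema, "ORDERS")
Schema byQuery = SchemaHelper.getSchema(accessor, null, collector, null, query);           // describeQuery(query)
Schema none    = SchemaHelper.getSchema(accessor, null, collector, null, null);            // null
```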
diff --git a/src/main/java/io/cdap/plugin/snowflake/source/batch/ImportQueryType.java b/src/main/java/io/cdap/plugin/snowflake/source/batch/ImportQueryType.java
new file mode 100644
index 0000000..e994ca7
--- /dev/null
+++ b/src/main/java/io/cdap/plugin/snowflake/source/batch/ImportQueryType.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright © 2020 Cask Data, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package io.cdap.plugin.snowflake.source.batch;
+
+/**
+ * Enum to specify the import query type used in Snowflake Batch Source.
+ */
+public enum ImportQueryType {
+  IMPORT_QUERY("importQuery"),
+  TABLE_NAME("tableName");
+
+  private final String value;
+
+  ImportQueryType(String value) {
+    this.value = value;
+  }
+
+  public String getValue() {
+    return value;
+  }
+
+  public static ImportQueryType fromString(String value) {
+    if (value == null) {
+      return ImportQueryType.IMPORT_QUERY;
+    }
+
+    for (ImportQueryType type : ImportQueryType.values()) {
+      if (type.value.equalsIgnoreCase(value)) {
+        return type;
+      }
+    }
+    return ImportQueryType.IMPORT_QUERY;
+  }
+}
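The lookup in `fromString` is case-insensitive on the widget value and never throws; `null` and unrecognized strings both fall back to `IMPORT_QUERY`:

```java
ImportQueryType.fromString("tableName");    // TABLE_NAME
ImportQueryType.fromString("TABLENAME");    // TABLE_NAME (case-insensitive match)
ImportQueryType.fromString("importQuery");  // IMPORT_QUERY
ImportQueryType.fromString(null);           // IMPORT_QUERY (default)
ImportQueryType.fromString("bogus");        // IMPORT_QUERY (silent fallback)
```

The silent fallback keeps configs created before this change working, at the cost of masking a typo in `importQueryType`.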
diff --git a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java
index 561a10f..66e23cc 100644
--- a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java
+++ b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java
@@ -16,6 +16,7 @@
 
 package io.cdap.plugin.snowflake.source.batch;
 
+import com.google.common.base.Strings;
 import io.cdap.cdap.api.annotation.Description;
 import io.cdap.cdap.api.annotation.Macro;
 import io.cdap.cdap.api.annotation.Name;
@@ -34,6 +35,8 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
   public static final String PROPERTY_IMPORT_QUERY = "importQuery";
   public static final String PROPERTY_MAX_SPLIT_SIZE = "maxSplitSize";
   public static final String PROPERTY_SCHEMA = "schema";
+  public static final String PROPERTY_TABLE_NAME = "tableName";
+  public static final String PROPERTY_IMPORT_QUERY_TYPE = "importQueryType";
 
   @Name(PROPERTY_REFERENCE_NAME)
   @Description("This will be used to uniquely identify this source/sink for lineage, annotating metadata, etc.")
@@ -42,6 +45,7 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
   @Name(PROPERTY_IMPORT_QUERY)
   @Description("Query for import data.")
   @Macro
+  @Nullable
   private String importQuery;
 
   @Name(PROPERTY_MAX_SPLIT_SIZE)
@@ -55,19 +59,40 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
   @Macro
   private String schema;
 
+  @Name(PROPERTY_TABLE_NAME)
+  @Description("The name of the table used to retrieve the schema.")
+  @Macro
+  @Nullable
+  private final String tableName;
+
+  @Name(PROPERTY_IMPORT_QUERY_TYPE)
+  @Description("Whether to use Table Name or Import Query to extract the data.")
+  @Macro
+  @Nullable
+  private final String importQueryType;
+
   public SnowflakeBatchSourceConfig(String referenceName, String accountName, String database,
-                                    String schemaName, String importQuery, String username, String password,
+                                    String schemaName, @Nullable String importQuery, String username, String password,
                                     @Nullable Boolean keyPairEnabled, @Nullable String path, @Nullable String passphrase,
                                     @Nullable Boolean oauth2Enabled, @Nullable String clientId, @Nullable String clientSecret,
                                     @Nullable String refreshToken, Long maxSplitSize,
-                                    @Nullable String connectionArguments, @Nullable String schema) {
-    super(accountName, database, schemaName, username, password,
-          keyPairEnabled, path, passphrase, oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments);
+                                    @Nullable String connectionArguments,
+                                    @Nullable String schema,
+                                    @Nullable String tableName,
+                                    @Nullable String importQueryType) {
+    super(
+      accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase,
+      oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments
+    );
     this.referenceName = referenceName;
     this.importQuery = importQuery;
     this.maxSplitSize = maxSplitSize;
     this.schema = schema;
+    this.tableName = tableName;
+    this.importQueryType = importQueryType;
   }
 
   public String getImportQuery() {
@@ -87,6 +112,16 @@ public String getSchema() {
     return schema;
   }
 
+  @Nullable
+  public String getTableName() {
+    return tableName;
+  }
+
+  @Nullable
+  public String getImportQueryType() {
+    return importQueryType == null ? ImportQueryType.IMPORT_QUERY.name() : importQueryType;
+  }
+
   public void validate(FailureCollector collector) {
     super.validate(collector);
 
@@ -95,5 +130,28 @@ public void validate(FailureCollector collector) {
       collector.addFailure("Maximum Split Size cannot be a negative number.", null)
         .withConfigProperty(PROPERTY_MAX_SPLIT_SIZE);
     }
+
+    if (!containsMacro(PROPERTY_IMPORT_QUERY_TYPE)) {
+      ImportQueryType queryType = ImportQueryType.fromString(importQueryType);
+      boolean isImportQuerySelected = queryType == ImportQueryType.IMPORT_QUERY;
+      boolean isTableNameSelected = queryType == ImportQueryType.TABLE_NAME;
+
+      if (isImportQuerySelected && !containsMacro(PROPERTY_IMPORT_QUERY) && Strings.isNullOrEmpty(importQuery)) {
+        collector.addFailure("Import Query cannot be empty.", null)
+          .withConfigProperty(PROPERTY_IMPORT_QUERY);
+      } else if (isTableNameSelected && !containsMacro(PROPERTY_TABLE_NAME) && Strings.isNullOrEmpty(tableName)) {
+        collector.addFailure("Table Name cannot be empty.", null)
+          .withConfigProperty(PROPERTY_TABLE_NAME);
+      }
+    } else {
+      boolean isImportQueryMissing = !containsMacro(PROPERTY_IMPORT_QUERY) && Strings.isNullOrEmpty(importQuery);
+      boolean isTableNameMissing = !containsMacro(PROPERTY_TABLE_NAME) && Strings.isNullOrEmpty(tableName);
+
+      if (isImportQueryMissing && isTableNameMissing) {
+        collector.addFailure("Either 'Import Query' or 'Table Name' must be provided.", null)
+          .withConfigProperty(PROPERTY_IMPORT_QUERY)
+          .withConfigProperty(PROPERTY_TABLE_NAME);
+      }
+    }
   }
 }
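A sketch of how the new validation reads in practice, using the test builder and mock collector that appear later in this diff (the failure text is quoted from `validate` above; the stage name is arbitrary):

```java
SnowflakeBatchSourceConfig config =
  new SnowflakeBatchSourceConfigBuilder(SnowflakeBatchSourceConfigBuilder.CONFIG)
    .setImportQueryType(ImportQueryType.TABLE_NAME.getValue()) // "tableName" is selected...
    .setTableName("")                                          // ...but left empty
    .build();

MockFailureCollector collector = new MockFailureCollector("stage");
config.validate(collector);
// One failure is collected, attached to PROPERTY_TABLE_NAME:
// "Table Name cannot be empty."
```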
diff --git a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java
index 202be53..9896ccd 100644
--- a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java
+++ b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java
@@ -17,10 +17,13 @@
 package io.cdap.plugin.snowflake.source.batch;
 
 import au.com.bytecode.opencsv.CSVReader;
+import com.google.common.base.Strings;
+import io.cdap.cdap.api.data.schema.Schema;
 import io.cdap.plugin.snowflake.common.SnowflakeErrorType;
 import io.cdap.plugin.snowflake.common.client.SnowflakeAccessor;
 import io.cdap.plugin.snowflake.common.util.DocumentUrlUtil;
 import io.cdap.plugin.snowflake.common.util.QueryUtil;
+import io.cdap.plugin.snowflake.common.util.SchemaHelper;
 import net.snowflake.client.jdbc.SnowflakeConnection;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -34,6 +37,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.UUID;
+import java.util.stream.Collectors;
 
 /**
  * A class which accesses Snowflake API to do actions used by batch source.
@@ -77,7 +81,20 @@ public SnowflakeSourceAccessor(SnowflakeBatchSourceConfig config, String escapeC
    */
   public List<String> prepareStageSplits() {
     LOG.info("Loading data into stage: '{}'", STAGE_PATH);
-    String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(config.getImportQuery()));
+    String importQuery = config.getImportQuery();
+    if (Strings.isNullOrEmpty(importQuery)) {
+      String tableName = config.getTableName();
+      Schema schema = SchemaHelper.getParsedSchema(config.getSchema());
+      if (schema != null && schema.getFields() != null && !schema.getFields().isEmpty()) {
+        String columns = schema.getFields().stream()
+          .map(Schema.Field::getName)
+          .collect(Collectors.joining(","));
+        importQuery = String.format("SELECT %s FROM %s", columns, tableName);
+      } else {
+        importQuery = String.format("SELECT * FROM %s", tableName);
+      }
+    }
+    String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(importQuery));
     if (config.getMaxSplitSize() > 0) {
       copy = copy + String.format(COMMAND_MAX_FILE_SIZE, config.getMaxSplitSize());
     }
@@ -94,10 +111,13 @@ public List<String> prepareStageSplits() {
       }
     } catch (SQLException e) {
       String errorReason = String.format("Failed to load data into stage '%s' with sqlState %s and errorCode %s. " +
-        "For more details, see %s.", STAGE_PATH, e.getErrorCode(), e.getSQLState(),
-        DocumentUrlUtil.getSupportedDocumentUrl());
-      String errorMessage = String.format("Failed to load data into stage '%s' with sqlState %s and errorCode %s. " +
-        "Failed to execute query with message: %s.", STAGE_PATH, e.getSQLState(), e.getErrorCode(), e.getMessage());
+          "For more details, see %s.", STAGE_PATH, e.getSQLState(), e.getErrorCode(),
+        DocumentUrlUtil.getSupportedDocumentUrl());
+      String errorMessage = String.format(
+        "Failed to load data into stage '%s' with sqlState %s and errorCode %s. " +
+          "Failed to execute query with message: %s.",
+        STAGE_PATH, e.getSQLState(), e.getErrorCode(), e.getMessage()
+      );
       throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
     }
     return stageSplits;
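A worked example of the query derivation above, with illustrative values (`ORDERS`, `ID`, `NAME` are not from the test suite). Note the generated query relies on Snowflake's default upper-case identifier resolution; a table created with a quoted lower-case name would need explicit quoting here:

```java
import java.util.Arrays;
import java.util.List;

class StageQuerySketch {
  public static void main(String[] args) {
    String tableName = "ORDERS";
    List<String> fieldNames = Arrays.asList("ID", "NAME"); // from the configured output schema

    // Mirrors prepareStageSplits(): project the schema's columns when present,
    // otherwise fall back to SELECT *.
    String importQuery = fieldNames.isEmpty()
      ? String.format("SELECT * FROM %s", tableName)
      : String.format("SELECT %s FROM %s", String.join(",", fieldNames), tableName);

    System.out.println(importQuery); // SELECT ID,NAME FROM ORDERS
  }
}
```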
" + + "Failed to execute query with message: %s.", + STAGE_PATH, e.getSQLState(), e.getErrorCode(), e.getMessage() + ); throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage); } return stageSplits; diff --git a/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java b/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java index 27b0154..2394d9b 100644 --- a/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java +++ b/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java @@ -90,6 +90,43 @@ public void testDescribeQuery() throws Exception { Assert.assertEquals(expected, actual); } + @Test + public void testDescribeTable() throws Exception { + String schemaName = "TEST_SCHEMA"; + String tableName = "TEST_TABLE"; + List expected = Arrays.asList( + new SnowflakeFieldDescriptor("COLUMN_NUMBER", -5, true), + new SnowflakeFieldDescriptor("COLUMN_DECIMAL", -5, true), + new SnowflakeFieldDescriptor("COLUMN_NUMERIC", -5, true), + new SnowflakeFieldDescriptor("COLUMN_INT", -5, true), + new SnowflakeFieldDescriptor("COLUMN_INTEGER", -5, true), + new SnowflakeFieldDescriptor("COLUMN_BIGINT", -5, true), + new SnowflakeFieldDescriptor("COLUMN_SMALLINT", -5, true), + new SnowflakeFieldDescriptor("COLUMN_FLOAT", 8, true), + new SnowflakeFieldDescriptor("COLUMN_DOUBLE", 8, true), + new SnowflakeFieldDescriptor("COLUMN_REAL", 8, true), + new SnowflakeFieldDescriptor("COLUMN_VARCHAR", 12, true), + new SnowflakeFieldDescriptor("COLUMN_CHAR", 12, true), + new SnowflakeFieldDescriptor("COLUMN_TEXT", 12, true), + new SnowflakeFieldDescriptor("COLUMN_BINARY", -2, true), + new SnowflakeFieldDescriptor("COLUMN_BOOLEAN", 16, true), + new SnowflakeFieldDescriptor("COLUMN_DATE", 91, true), + new SnowflakeFieldDescriptor("COLUMN_TIMESTAMP", 93, true), + new SnowflakeFieldDescriptor("COLUMN_VARIANT", 12, true), + new SnowflakeFieldDescriptor("COLUMN_OBJECT", 12, true), + new SnowflakeFieldDescriptor("COLUMN_ARRAY", 12, true) + ); + + List actual = snowflakeAccessor.getFieldDescriptors( + String.valueOf(Constants.TEST_TABLE_SCHEMA), + Constants.TEST_TABLE); + + Assert.assertNotNull(actual); + Assert.assertFalse(actual.isEmpty()); + Assert.assertEquals(expected, actual); + } + + @Test public void testPrepareStageSplits() throws Exception { Pattern expected = Pattern.compile("cdap_stage/result.*data__0_0_0\\.csv\\.gz"); diff --git a/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java b/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java index 63bf0af..c6b03b0 100644 --- a/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java +++ b/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java @@ -27,6 +27,7 @@ import org.mockito.Mockito; import java.io.IOException; +import java.sql.SQLException; import java.sql.Types; import java.util.ArrayList; import java.util.Arrays; @@ -48,7 +49,7 @@ public void testGetSchema() { ); MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE); - Schema actual = SchemaHelper.getSchema(null, expected.toString(), collector, null); + Schema actual = SchemaHelper.getSchema(null, expected.toString(), collector, null, null); Assert.assertTrue(collector.getValidationFailures().isEmpty()); Assert.assertEquals(expected, actual); @@ -57,32 +58,34 @@ public void testGetSchema() { @Test public void testGetSchemaInvalidJson() { MockFailureCollector collector = new 
diff --git a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java
index 7f2f035..febcd3f 100644
--- a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java
+++ b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java
@@ -38,7 +38,9 @@ public class SnowflakeBatchSourceConfigBuilder {
     "",
     0L,
     "",
-    "");
+    "",
+    "",
+    ImportQueryType.IMPORT_QUERY.name());
 
   private String referenceName;
   private String accountName;
@@ -57,6 +59,8 @@ public class SnowflakeBatchSourceConfigBuilder {
   private Long maxSplitSize;
   private String connectionArguments;
   private String schema;
+  private String tableName;
+  private String importQueryType;
 
   public SnowflakeBatchSourceConfigBuilder() {
   }
@@ -79,6 +83,8 @@ public SnowflakeBatchSourceConfigBuilder(SnowflakeBatchSourceConfig config) {
     this.maxSplitSize = config.getMaxSplitSize();
     this.connectionArguments = config.getConnectionArguments();
     this.schema = config.getSchema();
+    this.tableName = config.getTableName();
+    this.importQueryType = config.getImportQueryType();
   }
 
   public SnowflakeBatchSourceConfigBuilder setReferenceName(String referenceName) {
@@ -166,6 +172,16 @@ public SnowflakeBatchSourceConfigBuilder setSchema(String schema) {
     return this;
   }
 
+  public SnowflakeBatchSourceConfigBuilder setTableName(String tableName) {
+    this.tableName = tableName;
+    return this;
+  }
+
+  public SnowflakeBatchSourceConfigBuilder setImportQueryType(String importQueryType) {
+    this.importQueryType = importQueryType;
+    return this;
+  }
+
   public SnowflakeBatchSourceConfig build() {
     return new SnowflakeBatchSourceConfig(referenceName,
                                           accountName,
@@ -183,6 +199,8 @@ public SnowflakeBatchSourceConfig build() {
                                           refreshToken,
                                           maxSplitSize,
                                           connectionArguments,
-                                          schema);
+                                          schema,
+                                          tableName,
+                                          importQueryType);
   }
 }
diff --git a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java
index a071bea..82c1806 100644
--- a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java
+++ b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java
@@ -20,6 +20,7 @@
 import io.cdap.plugin.snowflake.ValidationAssertions;
 import org.junit.Assert;
 import org.junit.Test;
+
 import java.util.Collections;
 
 /**
@@ -68,4 +69,19 @@ public void validatePassword() {
     ValidationAssertions.assertValidationFailed(
       collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_PASSWORD));
   }
+
+  @Test
+  public void validateTableNameAndImportQuery() {
+    SnowflakeBatchSourceConfig config =
+      new SnowflakeBatchSourceConfigBuilder(SnowflakeBatchSourceConfigBuilder.CONFIG)
+        .setTableName(null)
+        .setImportQuery(null)
+        .build();
+
+    MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
+    config.validate(collector);
+
+    Assert.assertFalse("Validation should fail when both Table Name and Import Query are empty",
+      collector.getValidationFailures().isEmpty());
+  }
 }
diff --git a/widgets/Snowflake-batchsource.json b/widgets/Snowflake-batchsource.json
index 18a670e..df3096a 100644
--- a/widgets/Snowflake-batchsource.json
+++ b/widgets/Snowflake-batchsource.json
@@ -36,6 +36,30 @@
           "label": "Role",
           "name": "role"
         },
+        {
+          "name": "importQueryType",
+          "label": "Import Query Type",
+          "widget-type": "radio-group",
+          "widget-attributes": {
+            "layout": "inline",
+            "default": "importQuery",
+            "options": [
+              {
+                "id": "importQuery",
+                "label": "Native Query"
+              },
+              {
+                "id": "tableName",
+                "label": "Named Table"
+              }
+            ]
+          }
+        },
+        {
+          "widget-type": "textbox",
+          "label": "Table Name",
+          "name": "tableName"
+        },
         {
           "widget-type": "textarea",
           "label": "Import Query",
@@ -200,6 +224,30 @@
       ]
     },
+    {
+      "name": "ImportQuery",
+      "condition": {
+        "expression": "importQueryType != 'tableName'"
+      },
+      "show": [
+        {
+          "type": "property",
+          "name": "importQuery"
+        }
+      ]
+    },
+    {
+      "name": "NativeTableName",
+      "condition": {
+        "expression": "importQueryType == 'tableName'"
+      },
+      "show": [
+        {
+          "type": "property",
+          "name": "tableName"
+        }
+      ]
+    },
     {
       "name": "OAuth2EnabledFilter",
       "condition": {