diff --git a/docs/Snowflake-batchsource.md b/docs/Snowflake-batchsource.md
index 516bcd0..9b810fe 100644
--- a/docs/Snowflake-batchsource.md
+++ b/docs/Snowflake-batchsource.md
@@ -25,7 +25,9 @@ log in to Snowflake, minus the "snowflakecomputing.com"). E.g. "myaccount.us-cen
**Role:** Role to use (e.g. `ACCOUNTADMIN`).
-**Import Query:** Query for data import.
+**Import Query Type** - Method used to retrieve schema from the source.
+* **Table Name**: The name of the table to retrieve the schema.
+* **Import Query**: Query for data import.
### Credentials
diff --git a/pom.xml b/pom.xml
index 8b43a12..199b4a6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -29,8 +29,8 @@
true
UTF-8
- 6.11.0-SNAPSHOT
- 2.13.0-SNAPSHOT
+ 6.11.0
+ 2.13.0
1.6
3.3.6
3.3.2
diff --git a/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java b/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java
index 911f696..8185ae8 100644
--- a/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java
+++ b/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java
@@ -37,6 +37,7 @@
import java.io.IOException;
import java.lang.reflect.Field;
import java.sql.Connection;
+import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
@@ -74,6 +75,37 @@ public void runSQL(String query) {
}
}
+ /**
+ * Returns field descriptors for specified table name
+ * @param schemaName The name of schema containing the table
+ * @param tableName The name of table whose metadata needs to be retrieved
+ * @return list of field descriptors
+ */
+ public List getFieldDescriptors(String schemaName, String tableName) {
+ List fieldDescriptors = new ArrayList<>();
+ try (Connection connection = dataSource.getConnection()) {
+ DatabaseMetaData dbMetaData = connection.getMetaData();
+ try (ResultSet columns = dbMetaData.getColumns(null, schemaName, tableName, null)) {
+ while (columns.next()) {
+ String columnName = columns.getString("COLUMN_NAME");
+ int columnType = columns.getInt("DATA_TYPE");
+ boolean nullable = columns.getInt("NULLABLE") == DatabaseMetaData.columnNullable;
+ fieldDescriptors.add(new SnowflakeFieldDescriptor(columnName, columnType, nullable));
+ }
+ }
+ } catch (SQLException e) {
+ String errorMessage = String.format(
+ "Failed to retrieve table metadata with SQL State %s and error code %s with message: %s.",
+ e.getSQLState(), e.getErrorCode(), e.getMessage()
+ );
+ String errorReason = String.format("Failed to retrieve table metadata with SQL State %s and error " +
+ "code %s. For more details %s", e.getSQLState(), e.getErrorCode(),
+ DocumentUrlUtil.getSupportedDocumentUrl());
+ throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
+ }
+ return fieldDescriptors;
+ }
+
/**
* Returns field descriptors for specified import query.
*
@@ -193,4 +225,13 @@ private static String writeTextToTmpFile(String text) {
throw new RuntimeException("Cannot write key to temporary file", e);
}
}
+
+ /**
+ * Retrieves schema name from the configuration
+ *
+ * @return The schema name
+ */
+ public String getSchema() {
+ return config.getSchemaName();
+ }
}
diff --git a/src/main/java/io/cdap/plugin/snowflake/common/util/QueryUtil.java b/src/main/java/io/cdap/plugin/snowflake/common/util/QueryUtil.java
index 289ddc1..64f3e22 100644
--- a/src/main/java/io/cdap/plugin/snowflake/common/util/QueryUtil.java
+++ b/src/main/java/io/cdap/plugin/snowflake/common/util/QueryUtil.java
@@ -16,6 +16,8 @@
package io.cdap.plugin.snowflake.common.util;
+import com.google.common.base.Strings;
+
/**
* Transforms import query.
*/
@@ -29,6 +31,9 @@ private QueryUtil() {
}
public static String removeSemicolon(String importQuery) {
+ if (Strings.isNullOrEmpty(importQuery)) {
+ return null;
+ }
if (importQuery.endsWith(";")) {
importQuery = importQuery.substring(0, importQuery.length() - 1);
}
diff --git a/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java b/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java
index 9302eae..161c132 100644
--- a/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java
+++ b/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java
@@ -27,11 +27,13 @@
import io.cdap.plugin.snowflake.source.batch.SnowflakeInputFormatProvider;
import io.cdap.plugin.snowflake.source.batch.SnowflakeSourceAccessor;
import java.io.IOException;
+import java.sql.SQLException;
import java.sql.Types;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
+import javax.annotation.Nullable;
/**
* Resolves schema.
@@ -58,6 +60,13 @@ public class SchemaHelper {
private SchemaHelper() {
}
+ /**
+ * Retrieves schema for the Snowflake batch source based on the given configuration.
+ *
+ * @param config The configuration for Snowflake batch source
+ * @param collector The failure collector to capture any schema retrieval errors.
+ * @return The resolved schema for Snowflake source
+ */
public static Schema getSchema(SnowflakeBatchSourceConfig config, FailureCollector collector) {
if (!config.canConnect()) {
return getParsedSchema(config.getSchema());
@@ -65,16 +74,32 @@ public static Schema getSchema(SnowflakeBatchSourceConfig config, FailureCollect
SnowflakeSourceAccessor snowflakeSourceAccessor =
new SnowflakeSourceAccessor(config, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);
- return getSchema(snowflakeSourceAccessor, config.getSchema(), collector, config.getImportQuery());
+ return getSchema(
+ snowflakeSourceAccessor,
+ config.getSchema(),
+ collector,
+ config.getTableName(),
+ config.getImportQuery()
+ );
}
+ /**
+ * Retrieves schema for a Snowflake source based on the provided parameters.
+ *
+ * @param snowflakeAccessor The {@link SnowflakeSourceAccessor} used to connect to Snowflake.
+ * @param schema A JSON-format schema string
+ * @param collector The {@link FailureCollector} to collect errors if schema retrieval fails.
+ * @param tableName The name of the table in Snowflake.
+ * @param importQuery The query to fetch data from Snowflake, used when `tableName` is not provided.
+ * @return The parsed {@link Schema} if successful, or {@code null} if an error occurs.
+ */
public static Schema getSchema(SnowflakeSourceAccessor snowflakeAccessor, String schema,
- FailureCollector collector, String importQuery) {
+ FailureCollector collector, String tableName, String importQuery) {
try {
if (!Strings.isNullOrEmpty(schema)) {
return getParsedSchema(schema);
}
- return Strings.isNullOrEmpty(importQuery) ? null : getSchema(snowflakeAccessor, importQuery);
+ return getSchema(snowflakeAccessor, tableName, importQuery);
} catch (SchemaParseException e) {
collector.addFailure(String.format("Unable to retrieve output schema. Reason: '%s'", e.getMessage()),
null)
@@ -84,7 +109,7 @@ public static Schema getSchema(SnowflakeSourceAccessor snowflakeAccessor, String
}
}
- private static Schema getParsedSchema(String schema) {
+ public static Schema getParsedSchema(String schema) {
if (Strings.isNullOrEmpty(schema)) {
return null;
}
@@ -95,6 +120,26 @@ private static Schema getParsedSchema(String schema) {
}
}
+ private static Schema getSchema(SnowflakeAccessor snowflakeAccessor, @Nullable String tableName,
+ @Nullable String importQuery) {
+ try {
+ List result;
+ if (!Strings.isNullOrEmpty(tableName)) {
+ result = snowflakeAccessor.getFieldDescriptors(snowflakeAccessor.getSchema(), tableName);
+ } else if (!Strings.isNullOrEmpty(importQuery)) {
+ result = snowflakeAccessor.describeQuery(importQuery);
+ } else {
+ return null;
+ }
+ List fields = result.stream()
+ .map(fieldDescriptor -> Schema.Field.of(fieldDescriptor.getName(), getSchema(fieldDescriptor)))
+ .collect(Collectors.toList());
+ return Schema.recordOf("data", fields);
+ } catch (IOException e) {
+ throw new SchemaParseException(e);
+ }
+ }
+
public static Schema getSchema(SnowflakeAccessor snowflakeAccessor, String importQuery) {
try {
List result = snowflakeAccessor.describeQuery(importQuery);
diff --git a/src/main/java/io/cdap/plugin/snowflake/source/batch/ImportQueryType.java b/src/main/java/io/cdap/plugin/snowflake/source/batch/ImportQueryType.java
new file mode 100644
index 0000000..e994ca7
--- /dev/null
+++ b/src/main/java/io/cdap/plugin/snowflake/source/batch/ImportQueryType.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright © 2020 Cask Data, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package io.cdap.plugin.snowflake.source.batch;
+
+
+/**
+ * Enum to specify the import query type used in Snowflake Batch Source.
+ */
+public enum ImportQueryType {
+ IMPORT_QUERY ("importQuery"),
+ TABLE_NAME ("tableName");
+
+ private String value;
+
+ ImportQueryType(String value) {
+ this.value = value;
+ }
+
+ public String getValue() {
+ return value;
+ }
+
+ public static ImportQueryType fromString(String value) {
+ if (value == null) {
+ return ImportQueryType.IMPORT_QUERY;
+ }
+
+ for (ImportQueryType type : ImportQueryType.values()) {
+ if (type.value.equalsIgnoreCase(value)) {
+ return type;
+ }
+ }
+ return ImportQueryType.IMPORT_QUERY;
+ }
+}
diff --git a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java
index 561a10f..66e23cc 100644
--- a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java
+++ b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java
@@ -16,6 +16,7 @@
package io.cdap.plugin.snowflake.source.batch;
+import com.google.common.base.Strings;
import io.cdap.cdap.api.annotation.Description;
import io.cdap.cdap.api.annotation.Macro;
import io.cdap.cdap.api.annotation.Name;
@@ -34,6 +35,8 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
public static final String PROPERTY_IMPORT_QUERY = "importQuery";
public static final String PROPERTY_MAX_SPLIT_SIZE = "maxSplitSize";
public static final String PROPERTY_SCHEMA = "schema";
+ public static final String PROPERTY_TABLE_NAME = "tableName";
+ public static final String PROPERTY_IMPORT_QUERY_TYPE = "importQueryType";
@Name(PROPERTY_REFERENCE_NAME)
@Description("This will be used to uniquely identify this source/sink for lineage, annotating metadata, etc.")
@@ -42,6 +45,7 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
@Name(PROPERTY_IMPORT_QUERY)
@Description("Query for import data.")
@Macro
+ @Nullable
private String importQuery;
@Name(PROPERTY_MAX_SPLIT_SIZE)
@@ -55,19 +59,40 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
@Macro
private String schema;
+
+ @Name(PROPERTY_TABLE_NAME)
+ @Description("The name of the table used to retrieve the schema.")
+ @Macro
+ @Nullable
+ private final String tableName;
+
+ @Name(PROPERTY_IMPORT_QUERY_TYPE)
+ @Description("Whether to select Table Name or Import Query to extract the data.")
+ @Macro
+ @Nullable
+ private final String importQueryType;
+
+
public SnowflakeBatchSourceConfig(String referenceName, String accountName, String database,
- String schemaName, String importQuery, String username, String password,
+ String schemaName, @Nullable String importQuery, String username, String password,
@Nullable Boolean keyPairEnabled, @Nullable String path,
@Nullable String passphrase, @Nullable Boolean oauth2Enabled,
@Nullable String clientId, @Nullable String clientSecret,
@Nullable String refreshToken, Long maxSplitSize,
- @Nullable String connectionArguments, @Nullable String schema) {
- super(accountName, database, schemaName, username, password,
- keyPairEnabled, path, passphrase, oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments);
+ @Nullable String connectionArguments,
+ @Nullable String schema,
+ @Nullable String tableName,
+ @Nullable String importQueryType) {
+ super(
+ accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase,
+ oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments
+ );
this.referenceName = referenceName;
this.importQuery = importQuery;
this.maxSplitSize = maxSplitSize;
this.schema = schema;
+ this.tableName = tableName;
+ this.importQueryType = getImportQueryType();
}
public String getImportQuery() {
@@ -87,6 +112,16 @@ public String getSchema() {
return schema;
}
+ @Nullable
+ public String getTableName() {
+ return tableName;
+ }
+
+ @Nullable
+ public String getImportQueryType() {
+ return importQueryType == null ? ImportQueryType.IMPORT_QUERY.name() : importQueryType;
+ }
+
public void validate(FailureCollector collector) {
super.validate(collector);
@@ -95,5 +130,28 @@ public void validate(FailureCollector collector) {
collector.addFailure("Maximum Slit Size cannot be a negative number.", null)
.withConfigProperty(PROPERTY_MAX_SPLIT_SIZE);
}
+
+ if (!containsMacro(PROPERTY_IMPORT_QUERY_TYPE)) {
+ boolean isImportQuerySelected = ImportQueryType.IMPORT_QUERY.getValue().equals(importQueryType);
+ boolean isTableNameSelected = ImportQueryType.TABLE_NAME.getValue().equals(importQueryType);
+
+ if (isImportQuerySelected && !containsMacro(PROPERTY_IMPORT_QUERY) && Strings.isNullOrEmpty(importQuery)) {
+ collector.addFailure("Import Query cannot be empty", null)
+ .withConfigProperty(PROPERTY_IMPORT_QUERY);
+
+ } else if (isTableNameSelected && !containsMacro(PROPERTY_TABLE_NAME) && Strings.isNullOrEmpty(tableName)) {
+ collector.addFailure("Table Name cannot be empty", null)
+ .withConfigProperty(PROPERTY_TABLE_NAME);
+ }
+ } else {
+ boolean isImportQueryMissing = !containsMacro(PROPERTY_IMPORT_QUERY) && Strings.isNullOrEmpty(importQuery);
+ boolean isTableNameMissing = !containsMacro(PROPERTY_TABLE_NAME) && Strings.isNullOrEmpty(tableName);
+
+ if (isImportQueryMissing && isTableNameMissing) {
+ collector.addFailure("Either 'Import Query' or 'Table Name' must be provided.", null)
+ .withConfigProperty(PROPERTY_IMPORT_QUERY)
+ .withConfigProperty(PROPERTY_TABLE_NAME);
+ }
+ }
}
}
diff --git a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java
index 202be53..9896ccd 100644
--- a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java
+++ b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java
@@ -17,10 +17,13 @@
package io.cdap.plugin.snowflake.source.batch;
import au.com.bytecode.opencsv.CSVReader;
+import com.google.common.base.Strings;
+import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.plugin.snowflake.common.SnowflakeErrorType;
import io.cdap.plugin.snowflake.common.client.SnowflakeAccessor;
import io.cdap.plugin.snowflake.common.util.DocumentUrlUtil;
import io.cdap.plugin.snowflake.common.util.QueryUtil;
+import io.cdap.plugin.snowflake.common.util.SchemaHelper;
import net.snowflake.client.jdbc.SnowflakeConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -34,6 +37,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
+import java.util.stream.Collectors;
/**
* A class which accesses Snowflake API to do actions used by batch source.
@@ -77,7 +81,20 @@ public SnowflakeSourceAccessor(SnowflakeBatchSourceConfig config, String escapeC
*/
public List prepareStageSplits() {
LOG.info("Loading data into stage: '{}'", STAGE_PATH);
- String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(config.getImportQuery()));
+ String importQuery = config.getImportQuery();
+ if (Strings.isNullOrEmpty(importQuery)) {
+ String tableName = config.getTableName();
+ Schema schema = SchemaHelper.getParsedSchema(config.getSchema());
+ if (schema != null && schema.getFields() != null && !schema.getFields().isEmpty()) {
+ String columns = schema.getFields().stream()
+ .map(Schema.Field::getName)
+ .collect(Collectors.joining(","));
+ importQuery = String.format("SELECT %s FROM %s", columns, tableName);
+ } else {
+ importQuery = String.format("SELECT * FROM %s", tableName);
+ }
+ }
+ String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(importQuery));
if (config.getMaxSplitSize() > 0) {
copy = copy + String.format(COMMAND_MAX_FILE_SIZE, config.getMaxSplitSize());
}
@@ -94,10 +111,13 @@ public List prepareStageSplits() {
}
} catch (SQLException e) {
String errorReason = String.format("Failed to load data into stage '%s' with sqlState %s and errorCode %s. " +
- "For more details, see %s.", STAGE_PATH, e.getErrorCode(), e.getSQLState(),
- DocumentUrlUtil.getSupportedDocumentUrl());
- String errorMessage = String.format("Failed to load data into stage '%s' with sqlState %s and errorCode %s. " +
- "Failed to execute query with message: %s.", STAGE_PATH, e.getSQLState(), e.getErrorCode(), e.getMessage());
+ "For more details, see %s.", STAGE_PATH, e.getErrorCode(), e.getSQLState(),
+ DocumentUrlUtil.getSupportedDocumentUrl());
+ String errorMessage = String.format(
+ "Failed to load data into stage '%s' with sqlState %s and errorCode %s. "
+ + "Failed to execute query with message: %s.",
+ STAGE_PATH, e.getSQLState(), e.getErrorCode(), e.getMessage()
+ );
throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
}
return stageSplits;
diff --git a/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java b/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java
index 27b0154..2394d9b 100644
--- a/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java
+++ b/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java
@@ -90,6 +90,43 @@ public void testDescribeQuery() throws Exception {
Assert.assertEquals(expected, actual);
}
+ @Test
+ public void testDescribeTable() throws Exception {
+ String schemaName = "TEST_SCHEMA";
+ String tableName = "TEST_TABLE";
+ List expected = Arrays.asList(
+ new SnowflakeFieldDescriptor("COLUMN_NUMBER", -5, true),
+ new SnowflakeFieldDescriptor("COLUMN_DECIMAL", -5, true),
+ new SnowflakeFieldDescriptor("COLUMN_NUMERIC", -5, true),
+ new SnowflakeFieldDescriptor("COLUMN_INT", -5, true),
+ new SnowflakeFieldDescriptor("COLUMN_INTEGER", -5, true),
+ new SnowflakeFieldDescriptor("COLUMN_BIGINT", -5, true),
+ new SnowflakeFieldDescriptor("COLUMN_SMALLINT", -5, true),
+ new SnowflakeFieldDescriptor("COLUMN_FLOAT", 8, true),
+ new SnowflakeFieldDescriptor("COLUMN_DOUBLE", 8, true),
+ new SnowflakeFieldDescriptor("COLUMN_REAL", 8, true),
+ new SnowflakeFieldDescriptor("COLUMN_VARCHAR", 12, true),
+ new SnowflakeFieldDescriptor("COLUMN_CHAR", 12, true),
+ new SnowflakeFieldDescriptor("COLUMN_TEXT", 12, true),
+ new SnowflakeFieldDescriptor("COLUMN_BINARY", -2, true),
+ new SnowflakeFieldDescriptor("COLUMN_BOOLEAN", 16, true),
+ new SnowflakeFieldDescriptor("COLUMN_DATE", 91, true),
+ new SnowflakeFieldDescriptor("COLUMN_TIMESTAMP", 93, true),
+ new SnowflakeFieldDescriptor("COLUMN_VARIANT", 12, true),
+ new SnowflakeFieldDescriptor("COLUMN_OBJECT", 12, true),
+ new SnowflakeFieldDescriptor("COLUMN_ARRAY", 12, true)
+ );
+
+ List actual = snowflakeAccessor.getFieldDescriptors(
+ String.valueOf(Constants.TEST_TABLE_SCHEMA),
+ Constants.TEST_TABLE);
+
+ Assert.assertNotNull(actual);
+ Assert.assertFalse(actual.isEmpty());
+ Assert.assertEquals(expected, actual);
+ }
+
+
@Test
public void testPrepareStageSplits() throws Exception {
Pattern expected = Pattern.compile("cdap_stage/result.*data__0_0_0\\.csv\\.gz");
diff --git a/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java b/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java
index 63bf0af..c6b03b0 100644
--- a/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java
+++ b/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java
@@ -27,6 +27,7 @@
import org.mockito.Mockito;
import java.io.IOException;
+import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
@@ -48,7 +49,7 @@ public void testGetSchema() {
);
MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
- Schema actual = SchemaHelper.getSchema(null, expected.toString(), collector, null);
+ Schema actual = SchemaHelper.getSchema(null, expected.toString(), collector, null, null);
Assert.assertTrue(collector.getValidationFailures().isEmpty());
Assert.assertEquals(expected, actual);
@@ -57,32 +58,34 @@ public void testGetSchema() {
@Test
public void testGetSchemaInvalidJson() {
MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
- SchemaHelper.getSchema(null, "{}", collector, null);
+ SchemaHelper.getSchema(null, "{}", collector, null, null);
ValidationAssertions.assertValidationFailed(
collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_SCHEMA));
}
@Test
- public void testGetSchemaFromSnowflakeUnknownType() throws IOException {
+ public void testGetSchemaFromSnowflakeUnknownType() throws IOException, SQLException {
String importQuery = "SELECT * FROM someTable";
+ String tableName = "USER";
MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
SnowflakeSourceAccessor snowflakeAccessor = Mockito.mock(SnowflakeSourceAccessor.class);
List sample = new ArrayList<>();
sample.add(new SnowflakeFieldDescriptor("field1", -1000, false));
- Mockito.when(snowflakeAccessor.describeQuery(importQuery)).thenReturn(sample);
+ Mockito.when(snowflakeAccessor.getFieldDescriptors(null, tableName)).thenReturn(sample);
- SchemaHelper.getSchema(snowflakeAccessor, null, collector, importQuery);
+ SchemaHelper.getSchema(snowflakeAccessor, null, collector, tableName, importQuery);
ValidationAssertions.assertValidationFailed(
collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_SCHEMA));
}
@Test
- public void testGetSchemaFromSnowflake() throws IOException {
+ public void testGetSchemaFromSnowflake() throws IOException, SQLException {
String importQuery = "SELECT * FROM someTable";
+ String tableName = "USER";
MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
SnowflakeSourceAccessor snowflakeAccessor = Mockito.mock(SnowflakeSourceAccessor.class);
@@ -145,8 +148,9 @@ public void testGetSchemaFromSnowflake() throws IOException {
);
Mockito.when(snowflakeAccessor.describeQuery(importQuery)).thenReturn(sample);
+ Mockito.when(snowflakeAccessor.getFieldDescriptors(Mockito.any(), Mockito.eq(tableName))).thenReturn(sample);
- Schema actual = SchemaHelper.getSchema(snowflakeAccessor, null, collector, importQuery);
+ Schema actual = SchemaHelper.getSchema(snowflakeAccessor, null, collector, tableName, importQuery);
Assert.assertTrue(collector.getValidationFailures().isEmpty());
Assert.assertEquals(expected, actual);
diff --git a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java
index 7f2f035..febcd3f 100644
--- a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java
+++ b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java
@@ -38,7 +38,9 @@ public class SnowflakeBatchSourceConfigBuilder {
"",
0L,
"",
- "");
+ "",
+ "",
+ ImportQueryType.IMPORT_QUERY.name());
private String referenceName;
private String accountName;
@@ -57,6 +59,8 @@ public class SnowflakeBatchSourceConfigBuilder {
private Long maxSplitSize;
private String connectionArguments;
private String schema;
+ private String tableName;
+ private String importQueryType;
public SnowflakeBatchSourceConfigBuilder() {
}
@@ -79,6 +83,8 @@ public SnowflakeBatchSourceConfigBuilder(SnowflakeBatchSourceConfig config) {
this.maxSplitSize = config.getMaxSplitSize();
this.connectionArguments = config.getConnectionArguments();
this.schema = config.getSchema();
+ this.tableName = config.getTableName();
+ this.importQueryType = config.getImportQueryType();
}
public SnowflakeBatchSourceConfigBuilder setReferenceName(String referenceName) {
@@ -166,6 +172,16 @@ public SnowflakeBatchSourceConfigBuilder setSchema(String schema) {
return this;
}
+ public SnowflakeBatchSourceConfigBuilder setTableName(String tableName) {
+ this.tableName = tableName;
+ return this;
+ }
+
+ public SnowflakeBatchSourceConfigBuilder setImportQueryType(String importQueryType) {
+ this.importQueryType = importQueryType;
+ return this;
+ }
+
public SnowflakeBatchSourceConfig build() {
return new SnowflakeBatchSourceConfig(referenceName,
accountName,
@@ -183,6 +199,8 @@ public SnowflakeBatchSourceConfig build() {
refreshToken,
maxSplitSize,
connectionArguments,
- schema);
+ schema,
+ tableName,
+ importQueryType);
}
}
diff --git a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java
index a071bea..82c1806 100644
--- a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java
+++ b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java
@@ -20,6 +20,7 @@
import io.cdap.plugin.snowflake.ValidationAssertions;
import org.junit.Assert;
import org.junit.Test;
+
import java.util.Collections;
/**
@@ -68,4 +69,19 @@ public void validatePassword() {
ValidationAssertions.assertValidationFailed(
collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_PASSWORD));
}
+
+ @Test
+ public void validateTableNameAndImportQuery() {
+ SnowflakeBatchSourceConfig config =
+ new SnowflakeBatchSourceConfigBuilder(SnowflakeBatchSourceConfigBuilder.CONFIG)
+ .setTableName(null)
+ .setImportQuery(null)
+ .build();
+
+ MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
+ config.validate(collector);
+
+ Assert.assertFalse("Both table name and import query cannot be null",
+ collector.getValidationFailures().isEmpty());
+ }
}
diff --git a/widgets/Snowflake-batchsource.json b/widgets/Snowflake-batchsource.json
index 18a670e..df3096a 100644
--- a/widgets/Snowflake-batchsource.json
+++ b/widgets/Snowflake-batchsource.json
@@ -36,6 +36,30 @@
"label": "Role",
"name": "role"
},
+ {
+ "name": "importQueryType",
+ "label": "Import Query Type",
+ "widget-type": "radio-group",
+ "widget-attributes": {
+ "layout": "inline",
+ "default": "importQuery",
+ "options": [
+ {
+ "id": "importQuery",
+ "label": "Native Query"
+ },
+ {
+ "id": "tableName",
+ "label": "Named Table"
+ }
+ ]
+ }
+ },
+ {
+ "widget-type": "textbox",
+ "label": "Table Name",
+ "name": "tableName"
+ },
{
"widget-type": "textarea",
"label": "Import Query",
@@ -200,6 +224,30 @@
}
]
},
+ {
+ "name": "ImportQuery",
+ "condition": {
+ "expression": "importQueryType != 'tableName'"
+ },
+ "show": [
+ {
+ "type": "property",
+ "name": "importQuery"
+ }
+ ]
+ },
+ {
+ "name": "NativeTableName",
+ "condition": {
+ "expression": "importQueryType == 'tableName'"
+ },
+ "show": [
+ {
+ "type": "property",
+ "name": "tableName"
+ }
+ ]
+ },
{
"name": "OAuth2EnabledFilter",
"condition": {