PLUGIN-1883: SnowFlake Plugin - Fetch schema using Named Table #48

Merged · 1 commit · May 15, 2025
4 changes: 3 additions & 1 deletion docs/Snowflake-batchsource.md
@@ -25,7 +25,9 @@ log in to Snowflake, minus the "snowflakecomputing.com"). E.g. "myaccount.us-cen

**Role:** Role to use (e.g. `ACCOUNTADMIN`).

-**Import Query:** Query for data import.
+**Import Query Type:** Method used to retrieve the schema from the source.
+* **Table Name**: The name of the table from which to retrieve the schema.
+* **Import Query**: Query for data import.

### Credentials

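For reference, the two modes correspond to plugin properties along the following lines. A minimal sketch, with invented table and query values; the key names mirror the PROPERTY_IMPORT_QUERY_TYPE, PROPERTY_TABLE_NAME, and PROPERTY_IMPORT_QUERY constants added to SnowflakeBatchSourceConfig later in this PR:

```java
import com.google.common.collect.ImmutableMap;
import java.util.Map;

// Illustrative plugin property maps for the two schema-retrieval modes.
// The key names match the config constants below; the values are made up.
class ImportQueryModeExamples {
  static final Map<String, String> BY_TABLE_NAME = ImmutableMap.of(
      "importQueryType", "tableName",
      "tableName", "EMPLOYEES");

  static final Map<String, String> BY_IMPORT_QUERY = ImmutableMap.of(
      "importQueryType", "importQuery",
      "importQuery", "SELECT ID, NAME FROM EMPLOYEES");
}
```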
4 changes: 2 additions & 2 deletions pom.xml
@@ -29,8 +29,8 @@
<surefire.redirectTestOutputToFile>true</surefire.redirectTestOutputToFile>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<!-- version properties -->
-<cdap.version>6.11.0-SNAPSHOT</cdap.version>
-<hydrator.version>2.13.0-SNAPSHOT</hydrator.version>
+<cdap.version>6.11.0</cdap.version>
+<hydrator.version>2.13.0</hydrator.version>
<commons.csv.version>1.6</commons.csv.version>
<hadoop.version>3.3.6</hadoop.version>
<spark3.version>3.3.2</spark3.version>
SnowflakeAccessor.java
@@ -37,6 +37,7 @@
import java.io.IOException;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
@@ -74,6 +75,37 @@ public void runSQL(String query) {
}
}

/**
* Returns field descriptors for the specified table.
*
* @param schemaName The name of the schema containing the table
* @param tableName The name of the table whose metadata needs to be retrieved
* @return The list of field descriptors
*/
public List<SnowflakeFieldDescriptor> getFieldDescriptors(String schemaName, String tableName) {
List<SnowflakeFieldDescriptor> fieldDescriptors = new ArrayList<>();
try (Connection connection = dataSource.getConnection()) {
DatabaseMetaData dbMetaData = connection.getMetaData();
try (ResultSet columns = dbMetaData.getColumns(null, schemaName, tableName, null)) {
while (columns.next()) {
String columnName = columns.getString("COLUMN_NAME");
int columnType = columns.getInt("DATA_TYPE");
boolean nullable = columns.getInt("NULLABLE") == DatabaseMetaData.columnNullable;
fieldDescriptors.add(new SnowflakeFieldDescriptor(columnName, columnType, nullable));
}
}
} catch (SQLException e) {
String errorMessage = String.format(
"Failed to retrieve table metadata with SQL State %s and error code %s with message: %s.",
e.getSQLState(), e.getErrorCode(), e.getMessage()
);
String errorReason = String.format("Failed to retrieve table metadata with SQL State %s and error " +
"code %s. For more details, see %s.", e.getSQLState(), e.getErrorCode(),
DocumentUrlUtil.getSupportedDocumentUrl());
throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
}
return fieldDescriptors;
}

/**
* Returns field descriptors for the specified import query.
*
@@ -193,4 +225,13 @@ private static String writeTextToTmpFile(String text) {
throw new RuntimeException("Cannot write key to temporary file", e);
}
}

/**
* Retrieves the schema name from the configuration.
*
* @return The schema name
*/
public String getSchema() {
return config.getSchemaName();
}
}
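The new getFieldDescriptors method is a thin wrapper over standard JDBC metadata. Below is a minimal standalone sketch of the same lookup, assuming a configured javax.sql.DataSource; the plugin wraps each row in a SnowflakeFieldDescriptor rather than printing it:

```java
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import javax.sql.DataSource;

final class ColumnMetadataSketch {

  // Prints name, JDBC type code, and nullability for each column of a table,
  // mirroring the getColumns(...) call in getFieldDescriptors above.
  static void printColumns(DataSource dataSource, String schemaName, String tableName)
      throws SQLException {
    try (Connection connection = dataSource.getConnection()) {
      DatabaseMetaData metaData = connection.getMetaData();
      // Null catalog and null column pattern match every catalog and column.
      try (ResultSet columns = metaData.getColumns(null, schemaName, tableName, null)) {
        while (columns.next()) {
          String name = columns.getString("COLUMN_NAME");
          int jdbcType = columns.getInt("DATA_TYPE"); // a java.sql.Types constant
          boolean nullable = columns.getInt("NULLABLE") == DatabaseMetaData.columnNullable;
          System.out.printf("%s type=%d nullable=%b%n", name, jdbcType, nullable);
        }
      }
    }
  }
}
```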
QueryUtil.java
@@ -16,6 +16,8 @@

package io.cdap.plugin.snowflake.common.util;

import com.google.common.base.Strings;

/**
* Transforms import query.
*/
@@ -29,6 +31,9 @@ private QueryUtil() {
}

public static String removeSemicolon(String importQuery) {
if (Strings.isNullOrEmpty(importQuery)) {
return null;
}
if (importQuery.endsWith(";")) {
importQuery = importQuery.substring(0, importQuery.length() - 1);
}
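With the new guard, removeSemicolon tolerates a missing query instead of throwing a NullPointerException on the endsWith call. Expected behavior, as a quick sketch:

```java
QueryUtil.removeSemicolon("SELECT * FROM EMPLOYEES;"); // -> "SELECT * FROM EMPLOYEES"
QueryUtil.removeSemicolon("SELECT * FROM EMPLOYEES");  // -> unchanged
QueryUtil.removeSemicolon("");                         // -> null (new)
QueryUtil.removeSemicolon(null);                       // -> null (new; previously an NPE)
```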
SchemaHelper.java
@@ -27,11 +27,13 @@
import io.cdap.plugin.snowflake.source.batch.SnowflakeInputFormatProvider;
import io.cdap.plugin.snowflake.source.batch.SnowflakeSourceAccessor;
import java.io.IOException;
import java.sql.SQLException;
import java.sql.Types;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/**
* Resolves schema.
@@ -58,23 +60,46 @@ public class SchemaHelper {
private SchemaHelper() {
}

/**
* Retrieves the schema for the Snowflake batch source based on the given configuration.
*
* @param config The configuration for the Snowflake batch source.
* @param collector The failure collector to capture any schema retrieval errors.
* @return The resolved schema for the Snowflake source.
*/
public static Schema getSchema(SnowflakeBatchSourceConfig config, FailureCollector collector) {
if (!config.canConnect()) {
return getParsedSchema(config.getSchema());
}

SnowflakeSourceAccessor snowflakeSourceAccessor =
new SnowflakeSourceAccessor(config, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);
-return getSchema(snowflakeSourceAccessor, config.getSchema(), collector, config.getImportQuery());
+return getSchema(
+  snowflakeSourceAccessor,
+  config.getSchema(),
+  collector,
+  config.getTableName(),
+  config.getImportQuery()
+);
}

/**
* Retrieves schema for a Snowflake source based on the provided parameters.
*
* @param snowflakeAccessor The {@link SnowflakeSourceAccessor} used to connect to Snowflake.
* @param schema A JSON-format schema string
* @param collector The {@link FailureCollector} to collect errors if schema retrieval fails.
* @param tableName The name of the table in Snowflake.
* @param importQuery The query to fetch data from Snowflake, used when `tableName` is not provided.
* @return The parsed {@link Schema} if successful, or {@code null} if an error occurs.
*/
public static Schema getSchema(SnowflakeSourceAccessor snowflakeAccessor, String schema,
-FailureCollector collector, String importQuery) {
+FailureCollector collector, String tableName, String importQuery) {
try {
if (!Strings.isNullOrEmpty(schema)) {
return getParsedSchema(schema);
}
-return Strings.isNullOrEmpty(importQuery) ? null : getSchema(snowflakeAccessor, importQuery);
+return getSchema(snowflakeAccessor, tableName, importQuery);
} catch (SchemaParseException e) {
collector.addFailure(String.format("Unable to retrieve output schema. Reason: '%s'", e.getMessage()),
null)
@@ -84,7 +109,7 @@ public static Schema getSchema(SnowflakeSourceAccessor snowflakeAccessor, String
}
}

-private static Schema getParsedSchema(String schema) {
+public static Schema getParsedSchema(String schema) {
if (Strings.isNullOrEmpty(schema)) {
return null;
}
@@ -95,6 +120,26 @@ private static Schema getParsedSchema(String schema) {
}
}

private static Schema getSchema(SnowflakeAccessor snowflakeAccessor, @Nullable String tableName,
@Nullable String importQuery) {
try {
List<SnowflakeFieldDescriptor> result;
if (!Strings.isNullOrEmpty(tableName)) {
result = snowflakeAccessor.getFieldDescriptors(snowflakeAccessor.getSchema(), tableName);
} else if (!Strings.isNullOrEmpty(importQuery)) {
result = snowflakeAccessor.describeQuery(importQuery);
} else {
return null;
}
List<Schema.Field> fields = result.stream()
.map(fieldDescriptor -> Schema.Field.of(fieldDescriptor.getName(), getSchema(fieldDescriptor)))
.collect(Collectors.toList());
return Schema.recordOf("data", fields);
} catch (IOException e) {
throw new SchemaParseException(e);
}
}

public static Schema getSchema(SnowflakeAccessor snowflakeAccessor, String importQuery) {
try {
List<SnowflakeFieldDescriptor> result = snowflakeAccessor.describeQuery(importQuery);
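Taken together, schema resolution now follows a fixed precedence: an explicitly configured schema wins, then table metadata, then the import query, with null when none is supplied. A compilable condensation of that order, with the two retrieval paths passed in as stand-in functions (the helper names here are hypothetical, not the plugin's API):

```java
import com.google.common.base.Strings;
import java.util.function.Function;

final class SchemaResolutionSketch {

  // Mirrors the order of checks in SchemaHelper: parsed schema, then table
  // metadata, then import-query description, then null.
  static <S> S resolve(String manualSchema, String tableName, String importQuery,
                       Function<String, S> parseSchema,
                       Function<String, S> fromTableMetadata,
                       Function<String, S> fromQueryDescription) {
    if (!Strings.isNullOrEmpty(manualSchema)) {
      return parseSchema.apply(manualSchema);         // user-supplied JSON schema
    }
    if (!Strings.isNullOrEmpty(tableName)) {
      return fromTableMetadata.apply(tableName);      // DatabaseMetaData lookup
    }
    if (!Strings.isNullOrEmpty(importQuery)) {
      return fromQueryDescription.apply(importQuery); // describeQuery round trip
    }
    return null;                                      // nothing to infer from
  }
}
```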
ImportQueryType.java
@@ -0,0 +1,49 @@
/*
* Copyright © 2020 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package io.cdap.plugin.snowflake.source.batch;


/**
* Enum to specify the import query type used in Snowflake Batch Source.
*/
public enum ImportQueryType {
IMPORT_QUERY("importQuery"),
TABLE_NAME("tableName");

private final String value;

ImportQueryType(String value) {
this.value = value;
}

public String getValue() {
return value;
}

public static ImportQueryType fromString(String value) {
if (value == null) {
return ImportQueryType.IMPORT_QUERY;
}

for (ImportQueryType type : ImportQueryType.values()) {
if (type.value.equalsIgnoreCase(value)) {
return type;
}
}
return ImportQueryType.IMPORT_QUERY;
}
}
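fromString is deliberately forgiving: matching is case-insensitive, and anything unrecognized, including null, falls back to IMPORT_QUERY. For example:

```java
ImportQueryType.fromString("tableName"); // TABLE_NAME
ImportQueryType.fromString("TABLENAME"); // TABLE_NAME (case-insensitive)
ImportQueryType.fromString(null);        // IMPORT_QUERY (default)
ImportQueryType.fromString("unknown");   // IMPORT_QUERY (fallback)
```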
SnowflakeBatchSourceConfig.java
@@ -16,6 +16,7 @@

package io.cdap.plugin.snowflake.source.batch;

import com.google.common.base.Strings;
import io.cdap.cdap.api.annotation.Description;
import io.cdap.cdap.api.annotation.Macro;
import io.cdap.cdap.api.annotation.Name;
@@ -34,6 +35,8 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
public static final String PROPERTY_IMPORT_QUERY = "importQuery";
public static final String PROPERTY_MAX_SPLIT_SIZE = "maxSplitSize";
public static final String PROPERTY_SCHEMA = "schema";
public static final String PROPERTY_TABLE_NAME = "tableName";
public static final String PROPERTY_IMPORT_QUERY_TYPE = "importQueryType";

@Name(PROPERTY_REFERENCE_NAME)
@Description("This will be used to uniquely identify this source/sink for lineage, annotating metadata, etc.")
@@ -42,6 +45,7 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
@Name(PROPERTY_IMPORT_QUERY)
@Description("Query for import data.")
@Macro
@Nullable
private String importQuery;

@Name(PROPERTY_MAX_SPLIT_SIZE)
Expand All @@ -55,19 +59,40 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
@Macro
private String schema;


@Name(PROPERTY_TABLE_NAME)
@Description("The name of the table used to retrieve the schema.")
@Macro
@Nullable
private final String tableName;

@Name(PROPERTY_IMPORT_QUERY_TYPE)
@Description("Whether to select Table Name or Import Query to extract the data.")
@Macro
@Nullable
private final String importQueryType;


public SnowflakeBatchSourceConfig(String referenceName, String accountName, String database,
-String schemaName, String importQuery, String username, String password,
+String schemaName, @Nullable String importQuery, String username, String password,
@Nullable Boolean keyPairEnabled, @Nullable String path,
@Nullable String passphrase, @Nullable Boolean oauth2Enabled,
@Nullable String clientId, @Nullable String clientSecret,
@Nullable String refreshToken, Long maxSplitSize,
-@Nullable String connectionArguments, @Nullable String schema) {
-super(accountName, database, schemaName, username, password,
-keyPairEnabled, path, passphrase, oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments);
+@Nullable String connectionArguments,
+@Nullable String schema,
+@Nullable String tableName,
+@Nullable String importQueryType) {
+super(
+  accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase,
+  oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments
+);
this.referenceName = referenceName;
this.importQuery = importQuery;
this.maxSplitSize = maxSplitSize;
this.schema = schema;
this.tableName = tableName;
this.importQueryType = importQueryType;
}

public String getImportQuery() {
Expand All @@ -87,6 +112,16 @@ public String getSchema() {
return schema;
}

@Nullable
public String getTableName() {
return tableName;
}

@Nullable
public String getImportQueryType() {
return importQueryType == null ? ImportQueryType.IMPORT_QUERY.getValue() : importQueryType;
}

public void validate(FailureCollector collector) {
super.validate(collector);

@@ -95,5 +130,28 @@ public void validate(FailureCollector collector) {
collector.addFailure("Maximum Slit Size cannot be a negative number.", null)
.withConfigProperty(PROPERTY_MAX_SPLIT_SIZE);
}

if (!containsMacro(PROPERTY_IMPORT_QUERY_TYPE)) {
ImportQueryType queryType = ImportQueryType.fromString(importQueryType);
boolean isImportQuerySelected = queryType == ImportQueryType.IMPORT_QUERY;
boolean isTableNameSelected = queryType == ImportQueryType.TABLE_NAME;

if (isImportQuerySelected && !containsMacro(PROPERTY_IMPORT_QUERY) && Strings.isNullOrEmpty(importQuery)) {
collector.addFailure("Import Query cannot be empty.", null)
.withConfigProperty(PROPERTY_IMPORT_QUERY);
} else if (isTableNameSelected && !containsMacro(PROPERTY_TABLE_NAME) && Strings.isNullOrEmpty(tableName)) {
collector.addFailure("Table Name cannot be empty.", null)
.withConfigProperty(PROPERTY_TABLE_NAME);
}
} else {
boolean isImportQueryMissing = !containsMacro(PROPERTY_IMPORT_QUERY) && Strings.isNullOrEmpty(importQuery);
boolean isTableNameMissing = !containsMacro(PROPERTY_TABLE_NAME) && Strings.isNullOrEmpty(tableName);

if (isImportQueryMissing && isTableNameMissing) {
collector.addFailure("Either 'Import Query' or 'Table Name' must be provided.", null)
.withConfigProperty(PROPERTY_IMPORT_QUERY)
.withConfigProperty(PROPERTY_TABLE_NAME);
}
}
}
}
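Assuming no macros are involved, the new validation resolves as follows:

* **Import Query** selected with an empty import query: failure on `importQuery`.
* **Table Name** selected with an empty table name: failure on `tableName`.
* Import Query Type itself a macro: one combined failure on both properties, raised only when `importQuery` and `tableName` are both empty.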