1818
1919import static com .google .common .base .CaseFormat .UPPER_CAMEL ;
2020import static com .google .common .base .CaseFormat .UPPER_UNDERSCORE ;
21+ import static org .apache .hadoop .util .Preconditions .checkNotNull ;
2122
2223import com .google .common .base .CharMatcher ;
2324import com .google .common .base .Joiner ;
4142import java .util .ArrayList ;
4243import java .util .List ;
4344import java .util .Properties ;
45+ import java .util .function .Supplier ;
4446import javax .annotation .Nonnull ;
4547import javax .sql .DataSource ;
46- import org .apache .commons .lang3 .StringUtils ;
4748import org .springframework .jdbc .core .JdbcTemplate ;
4849import org .springframework .jdbc .datasource .SimpleDriverDataSource ;
4950
@@ -72,31 +73,32 @@ public AbstractSnowflakeConnector(@Nonnull String name) {
7273 super (name );
7374 }
7475
75- private static final int MAX_DATABASE_CHAR_LENGTH = 255 ;
76- private static final String DEFAULT_DATABASE = "SNOWFLAKE" ;
77-
7876 @ Nonnull
7977 @ Override
8078 public abstract String getDescription ();
8179
8280 @ Nonnull
8381 @ Override
84- public Handle open (@ Nonnull ConnectorArguments arguments )
82+ public final Handle open (@ Nonnull ConnectorArguments arguments )
8583 throws MetadataDumperUsageException , SQLException {
86- String url = arguments .getUri () != null ? arguments .getUri () : getUrlFromArguments (arguments );
87- String databaseName =
88- arguments .getDatabases ().isEmpty ()
89- ? DEFAULT_DATABASE
90- : sanitizeDatabaseName (arguments .getDatabases ().get (0 ));
91-
92- DataSource dataSource =
93- arguments .isPrivateKeyFileProvided ()
94- ? createPrivateKeyDataSource (arguments , url )
95- : createUserPasswordDataSource (arguments , url );
96- JdbcHandle jdbcHandle = new JdbcHandle (dataSource );
97-
98- setCurrentDatabase (databaseName , jdbcHandle .getJdbcTemplate ());
99- return jdbcHandle ;
84+ Properties properties = dataSourceProperties (arguments );
85+ String url = getUrlFromArguments (arguments );
86+ DataSource dataSource = new SimpleDriverDataSource (newDriver (arguments ), url , properties );
87+ if (arguments .isAssessment ()) {
88+ JdbcHandle handle = new JdbcHandle (dataSource );
89+ JdbcTemplate template = handle .getJdbcTemplate ();
90+ String actualDatabase = template .queryForObject ("USE DATABASE SNOWFLAKE;" , String .class );
91+ checkNotNull (actualDatabase );
92+ return handle ;
93+ } else {
94+ String databaseName =
95+ arguments .getDatabases ().isEmpty ()
96+ ? "SNOWFLAKE"
97+ : sanitizeDatabaseName (arguments .getDatabases ().get (0 ));
98+ JdbcHandle handle = new JdbcHandle (dataSource );
99+ setCurrentDatabase (databaseName , handle .getJdbcTemplate ());
100+ return handle ;
101+ }
100102 }
101103
102104 @ Override
@@ -121,40 +123,56 @@ public final void validate(@Nonnull ConnectorArguments arguments) {
121123 */
122124 protected abstract void validateForConnector (@ Nonnull ConnectorArguments arguments );
123125
124- private DataSource createUserPasswordDataSource (@ Nonnull ConnectorArguments arguments , String url )
126+ @ Nonnull
127+ private Driver newDriver (@ Nonnull ConnectorArguments arguments ) throws SQLException {
128+ return newDriver (arguments .getDriverPaths (), "net.snowflake.client.jdbc.SnowflakeDriver" );
129+ }
130+
131+ @ Nonnull
132+ private static Properties dataSourceProperties (@ Nonnull ConnectorArguments arguments )
125133 throws SQLException {
126- Driver driver =
127- newDriver (arguments .getDriverPaths (), "net.snowflake.client.jdbc.SnowflakeDriver" );
128- Properties prop = new Properties ();
134+ String user = arguments .getUserOrFail ();
135+ if (arguments .isPrivateKeyFileProvided ()) {
136+ return createPrivateKeyProperties (arguments , user );
137+ } else {
138+ return createUserPasswordProperties (arguments , user );
139+ }
140+ }
141+
142+ private static Properties createUserPasswordProperties (
143+ @ Nonnull ConnectorArguments arguments , @ Nonnull String user ) {
144+ Properties properties = new Properties ();
129145
130- prop .put ("user" , arguments . getUser () );
146+ properties .put ("user" , user );
131147 if (arguments .isPasswordFlagProvided ()) {
132- prop .put ("password" , arguments .getPasswordOrPrompt ());
148+ properties .put ("password" , arguments .getPasswordOrPrompt ());
133149 }
134150 // Set default authenticator only if url is not provided to allow user overriding it
135151 if (arguments .getUri () == null ) {
136- prop .put ("authenticator" , "username_password_mfa" );
152+ properties .put ("authenticator" , "username_password_mfa" );
137153 }
138- return new SimpleDriverDataSource ( driver , url , prop ) ;
154+ return properties ;
139155 }
140156
141- private DataSource createPrivateKeyDataSource (@ Nonnull ConnectorArguments arguments , String url )
142- throws SQLException {
143- Driver driver =
144- newDriver (arguments .getDriverPaths (), "net.snowflake.client.jdbc.SnowflakeDriver" );
145- Properties prop = new Properties ();
157+ private static Properties createPrivateKeyProperties (
158+ @ Nonnull ConnectorArguments arguments , @ Nonnull String user ) {
159+ Properties properties = new Properties ();
160+ properties .put ("user" , user );
146161
147- prop .put ("private_key_file" , arguments .getPrivateKeyFile ());
148- prop .put ("user" , arguments .getUser ());
162+ properties .put ("private_key_file" , arguments .getPrivateKeyFile ());
149163 if (arguments .getPrivateKeyPassword () != null ) {
150- prop .put ("private_key_pwd" , arguments .getPrivateKeyPassword ());
164+ properties .put ("private_key_pwd" , arguments .getPrivateKeyPassword ());
151165 }
152-
153- return new SimpleDriverDataSource (driver , url , prop );
166+ return properties ;
154167 }
155168
156169 @ Nonnull
157170 private String getUrlFromArguments (@ Nonnull ConnectorArguments arguments ) {
171+ String url = arguments .getUri ();
172+ if (url != null ) {
173+ return url ;
174+ }
175+
158176 StringBuilder buf = new StringBuilder ("jdbc:snowflake://" );
159177 String host = arguments .getHost ("host.snowflakecomputing.com" );
160178 buf .append (host ).append ("/" );
@@ -178,26 +196,34 @@ private void setCurrentDatabase(@Nonnull String databaseName, @Nonnull JdbcTempl
178196 String currentDatabase =
179197 jdbcTemplate .queryForObject (String .format ("USE DATABASE %s;" , databaseName ), String .class );
180198 if (currentDatabase == null ) {
181- List <String > dbNames =
182- jdbcTemplate .query ("SHOW DATABASES" , (rs , rowNum ) -> rs .getString ("name" ));
183- throw new MetadataDumperUsageException (
184- "Database name not found "
185- + databaseName
186- + ", use one of: "
187- + StringUtils .join (dbNames , ", " ));
199+ Supplier <List <String >> showQuery =
200+ () -> jdbcTemplate .query ("SHOW DATABASES" , (rs , rowNum ) -> rs .getString ("name" ));
201+ throw unrecognizedDatabase (databaseName , showQuery );
188202 }
189203 }
190204
205+ @ Nonnull
206+ static MetadataDumperUsageException unrecognizedDatabase (
207+ @ Nonnull String database , @ Nonnull Supplier <List <String >> availableDatabases ) {
208+ List <String > names = availableDatabases .get ();
209+ String joinedNames = String .join (", " , names );
210+ String message =
211+ String .format ("Database name not found %s, use one of: %s" , database , joinedNames );
212+
213+ return new MetadataDumperUsageException (message );
214+ }
215+
191216 String sanitizeDatabaseName (@ Nonnull String databaseName ) throws MetadataDumperUsageException {
192- CharMatcher doubleQuoteMatcher = CharMatcher .is ('"' );
193- String trimmedName = doubleQuoteMatcher .trimFrom (databaseName );
194- int charLengthWithQuotes = databaseName .length () + 2 ;
195- if (charLengthWithQuotes > 255 ) {
217+ int lengthWithQuotes = databaseName .length () + 2 ;
218+ int maxLength = 255 ;
219+ if (lengthWithQuotes > maxLength ) {
196220 throw new MetadataDumperUsageException (
197221 String .format (
198222 "The provided database name has %d characters, which is longer than the maximum allowed number %d for Snowflake identifiers." ,
199- charLengthWithQuotes , MAX_DATABASE_CHAR_LENGTH ));
223+ lengthWithQuotes , maxLength ));
200224 }
225+ CharMatcher doubleQuoteMatcher = CharMatcher .is ('"' );
226+ String trimmedName = doubleQuoteMatcher .trimFrom (databaseName );
201227 if (doubleQuoteMatcher .matchesAnyOf (trimmedName )) {
202228 throw new MetadataDumperUsageException (
203229 "Database name has incorrectly placed double quote(s). Aborting query." );
0 commit comments