1616 */
1717package com .google .edwmigration .dumper .application .dumper .connector .snowflake ;
1818
19+ import static org .junit .Assert .assertEquals ;
20+ import static org .junit .Assert .assertThrows ;
21+ import static org .junit .Assert .assertTrue ;
22+
1923import com .google .common .collect .ImmutableMap ;
2024import com .google .common .io .Resources ;
2125import com .google .edwmigration .dumper .application .dumper .ConnectorArguments ;
3438import java .util .List ;
3539import java .util .Map ;
3640import javax .annotation .Nonnull ;
37- import org .junit .Assert ;
41+ import org .apache .commons .lang3 .ArrayUtils ;
42+ import org .apache .commons .lang3 .StringUtils ;
3843import org .junit .Assume ;
3944import org .junit .Test ;
4045import org .junit .runner .RunWith ;
@@ -122,19 +127,19 @@ public void testDatabaseNameFailure() {
122127 Assume .assumeTrue (isDumperTest ());
123128
124129 MetadataDumperUsageException exception =
125- Assert . assertThrows (
130+ assertThrows (
126131 MetadataDumperUsageException .class ,
127132 () -> {
128133 File outputFile =
129134 TestUtils .newOutputFile ("compilerworks-snowflake-metadata-fail.zip" );
130135 String [] args = ARGS (connector , outputFile );
131136
132- Assert . assertEquals ("--database" , args [6 ]);
137+ assertEquals ("--database" , args [6 ]);
133138 args [7 ] = args [7 ] + "_NOT_EXISTS" ;
134139 run (args );
135140 });
136141
137- Assert . assertTrue (exception .getMessage ().startsWith ("Database name not found" ));
142+ assertTrue (exception .getMessage ().startsWith ("Database name not found" ));
138143 }
139144
140145 @ Test
@@ -147,17 +152,42 @@ public void connector_generatesExpectedSql() throws IOException {
147152 StandardCharsets .UTF_8 ),
148153 TaskSqlMap .class );
149154
150- Assert . assertEquals (expectedSqls .size (), actualSqls .size ());
151- Assert . assertEquals (expectedSqls .keySet (), actualSqls .keySet ());
155+ assertEquals (expectedSqls .size (), actualSqls .size ());
156+ assertEquals (expectedSqls .keySet (), actualSqls .keySet ());
152157 for (String name : expectedSqls .keySet ()) {
153- Assert . assertEquals (expectedSqls .get (name ), actualSqls .get (name ));
158+ assertEquals (expectedSqls .get (name ), actualSqls .get (name ));
154159 }
155160 }
156161
157- private static Map <String , String > collectSqlStatements () throws IOException {
162+ @ Test
163+ public void connector_generatesExpectedSql_withQueryOverrides () throws IOException {
164+ Map <String , String > actualSqls =
165+ collectSqlStatements ("-Dsnowflake.metadata.columns.query=SQL_OVERRIDE" );
166+
167+ assertEquals ("SQL_OVERRIDE" , actualSqls .get ("columns-au.csv" ));
168+ assertEquals ("SQL_OVERRIDE" , actualSqls .get ("columns.csv" ));
169+ }
170+
171+ @ Test
172+ public void connector_generatesExpectedSql_withWhereOverrides () throws IOException {
173+ Map <String , String > actualSqls =
174+ collectSqlStatements ("-Dsnowflake.metadata.columns.where=SQL_OVERRIDE" );
175+
176+ // TODO: should be endsWith("WHERE SQL_OVERRIDE")
177+ assertTrue (
178+ actualSqls .get ("columns-au.csv" ).endsWith ("WHERE DELETED IS NULL WHERE SQL_OVERRIDE" ));
179+ // TODO: should be 1
180+ assertEquals (2 , StringUtils .countMatches (actualSqls .get ("columns-au.csv" ), " WHERE " ));
181+
182+ assertTrue (actualSqls .get ("columns.csv" ).endsWith ("WHERE SQL_OVERRIDE" ));
183+ assertEquals (1 , StringUtils .countMatches (actualSqls .get ("columns.csv" ), " WHERE " ));
184+ }
185+
186+ private static Map <String , String > collectSqlStatements (String ... extraArgs ) throws IOException {
158187 List <Task <?>> tasks = new ArrayList <>();
159188 SnowflakeMetadataConnector connector = new SnowflakeMetadataConnector ();
160- connector .addTasksTo (tasks , new ConnectorArguments ("--connector" , connector .getName ()));
189+ String [] args = ArrayUtils .addAll (new String [] {"--connector" , connector .getName ()}, extraArgs );
190+ connector .addTasksTo (tasks , new ConnectorArguments (args ));
161191 return tasks .stream ()
162192 .filter (t -> t instanceof JdbcSelectTask )
163193 .map (t -> (JdbcSelectTask ) t )
0 commit comments