Skip to content

Commit db5bac6

Browse files
authored
Merge pull request #185 from staticlibs/result_set_ignore_case
Make ResultSet getters case insensitive
2 parents c0e947f + e61b89c commit db5bac6

File tree

3 files changed

+119
-36
lines changed

3 files changed

+119
-36
lines changed

src/main/java/org/duckdb/DuckDBArrayResultSet.java

+51-31
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,25 @@ public boolean wasNull() throws SQLException {
4141
return wasNull;
4242
}
4343

44-
private <T> T getValue(String columnLabel, SqlValueGetter<T> getter) throws SQLException {
45-
return getValue(findColumn(columnLabel), getter);
46-
}
47-
4844
private <T> T getValue(int columnIndex, SqlValueGetter<T> getter) throws SQLException {
4945
if (columnIndex == 1) {
50-
throw new IllegalArgumentException(
51-
"The first element of Array-backed ResultSet can only be retrieved with getInt()");
46+
throw new SQLException(
47+
"The first element of Array-backed ResultSet can only be retrieved with numeric getters");
5248
}
5349
if (columnIndex != 2) {
54-
throw new IllegalArgumentException("Array-backed ResultSet can only have two columns");
50+
throw new SQLException("Array-backed ResultSet can only have two columns");
5551
}
5652
T value = getter.getValue(offset + currentValueIndex);
5753

5854
wasNull = value == null;
5955
return value;
6056
}
6157

58+
private int getIndexColumnValue() {
59+
wasNull = false;
60+
return currentValueIndex + 1;
61+
}
62+
6263
@Override
6364
public String getString(int columnIndex) throws SQLException {
6465
return getValue(columnIndex, vector::getLazyString);
@@ -71,40 +72,57 @@ public boolean getBoolean(int columnIndex) throws SQLException {
7172

7273
@Override
7374
public byte getByte(int columnIndex) throws SQLException {
75+
if (columnIndex == 1) {
76+
return (byte) getIndexColumnValue();
77+
}
7478
return getValue(columnIndex, vector::getByte);
7579
}
7680

7781
@Override
7882
public short getShort(int columnIndex) throws SQLException {
83+
if (columnIndex == 1) {
84+
return (short) getIndexColumnValue();
85+
}
7986
return getValue(columnIndex, vector::getShort);
8087
}
8188

8289
@Override
8390
public int getInt(int columnIndex) throws SQLException {
8491
if (columnIndex == 1) {
85-
wasNull = false;
86-
return currentValueIndex + 1;
92+
return getIndexColumnValue();
8793
}
8894
return getValue(columnIndex, vector::getInt);
8995
}
9096

9197
@Override
9298
public long getLong(int columnIndex) throws SQLException {
93-
return getInt(columnIndex);
99+
if (columnIndex == 1) {
100+
return getIndexColumnValue();
101+
}
102+
return getValue(columnIndex, vector::getLong);
94103
}
95104

96105
@Override
97106
public float getFloat(int columnIndex) throws SQLException {
107+
if (columnIndex == 1) {
108+
return getIndexColumnValue();
109+
}
98110
return getValue(columnIndex, vector::getFloat);
99111
}
100112

101113
@Override
102114
public double getDouble(int columnIndex) throws SQLException {
115+
if (columnIndex == 1) {
116+
return getIndexColumnValue();
117+
}
103118
return getValue(columnIndex, vector::getDouble);
104119
}
105120

106121
@Override
107122
public BigDecimal getBigDecimal(int columnIndex, int scale) throws SQLException {
123+
if (columnIndex == 1) {
124+
return BigDecimal.valueOf(getIndexColumnValue());
125+
}
108126
return getValue(columnIndex, vector::getBigDecimal);
109127
}
110128

@@ -145,51 +163,47 @@ public InputStream getBinaryStream(int columnIndex) throws SQLException {
145163

146164
@Override
147165
public String getString(String columnLabel) throws SQLException {
148-
return getValue(columnLabel, vector::getLazyString);
166+
return getString(findColumn(columnLabel));
149167
}
150168

151169
@Override
152170
public boolean getBoolean(String columnLabel) throws SQLException {
153-
return getValue(columnLabel, vector::getBoolean);
171+
return getBoolean(findColumn(columnLabel));
154172
}
155173

156174
@Override
157175
public byte getByte(String columnLabel) throws SQLException {
158-
return getValue(columnLabel, vector::getByte);
176+
return getByte(findColumn(columnLabel));
159177
}
160178

161179
@Override
162180
public short getShort(String columnLabel) throws SQLException {
163-
return getValue(columnLabel, vector::getShort);
181+
return getShort(findColumn(columnLabel));
164182
}
165183

166184
@Override
167185
public int getInt(String columnLabel) throws SQLException {
168-
int columnIndex = findColumn(columnLabel);
169-
if (columnIndex == 1) {
170-
return currentValueIndex;
171-
}
172-
return getValue(columnIndex, vector::getInt);
186+
return getInt(findColumn(columnLabel));
173187
}
174188

175189
@Override
176190
public long getLong(String columnLabel) throws SQLException {
177-
return getInt(columnLabel);
191+
return getLong(findColumn(columnLabel));
178192
}
179193

180194
@Override
181195
public float getFloat(String columnLabel) throws SQLException {
182-
return getValue(columnLabel, vector::getFloat);
196+
return getFloat(findColumn(columnLabel));
183197
}
184198

185199
@Override
186200
public double getDouble(String columnLabel) throws SQLException {
187-
return getValue(columnLabel, vector::getDouble);
201+
return getDouble(findColumn(columnLabel));
188202
}
189203

190204
@Override
191205
public BigDecimal getBigDecimal(String columnLabel, int scale) throws SQLException {
192-
return getValue(columnLabel, vector::getBigDecimal);
206+
return getBigDecimal(findColumn(columnLabel), scale);
193207
}
194208

195209
@Override
@@ -199,17 +213,17 @@ public byte[] getBytes(String columnLabel) throws SQLException {
199213

200214
@Override
201215
public Date getDate(String columnLabel) throws SQLException {
202-
return getValue(columnLabel, vector::getDate);
216+
return getDate(findColumn(columnLabel));
203217
}
204218

205219
@Override
206220
public Time getTime(String columnLabel) throws SQLException {
207-
return getValue(columnLabel, vector::getTime);
221+
return getTime(findColumn(columnLabel));
208222
}
209223

210224
@Override
211225
public Timestamp getTimestamp(String columnLabel) throws SQLException {
212-
return getValue(columnLabel, vector::getTimestamp);
226+
return getTimestamp(findColumn(columnLabel));
213227
}
214228

215229
@Override
@@ -254,12 +268,18 @@ public Object getObject(int columnIndex) throws SQLException {
254268

255269
@Override
256270
public Object getObject(String columnLabel) throws SQLException {
257-
return getValue(columnLabel, vector::getTimestamp);
271+
return getObject(findColumn(columnLabel));
258272
}
259273

260274
@Override
261275
public int findColumn(String columnLabel) throws SQLException {
262-
return Integer.parseInt(columnLabel);
276+
if ("INDEX".equalsIgnoreCase(columnLabel)) {
277+
return 1;
278+
}
279+
if ("VALUE".equalsIgnoreCase(columnLabel)) {
280+
return 2;
281+
}
282+
throw new SQLException("Could not find column with label " + columnLabel);
263283
}
264284

265285
@Override
@@ -279,7 +299,7 @@ public BigDecimal getBigDecimal(int columnIndex) throws SQLException {
279299

280300
@Override
281301
public BigDecimal getBigDecimal(String columnLabel) throws SQLException {
282-
return getValue(columnLabel, vector::getBigDecimal);
302+
return getBigDecimal(findColumn(columnLabel));
283303
}
284304

285305
@Override
@@ -671,7 +691,7 @@ public Array getArray(int columnIndex) throws SQLException {
671691

672692
@Override
673693
public Object getObject(String columnLabel, Map<String, Class<?>> map) throws SQLException {
674-
return getValue(columnLabel, vector::getObject);
694+
return getObject(findColumn(columnLabel));
675695
}
676696

677697
@Override
@@ -691,7 +711,7 @@ public Clob getClob(String columnLabel) throws SQLException {
691711

692712
@Override
693713
public Array getArray(String columnLabel) throws SQLException {
694-
return getValue(columnLabel, vector::getArray);
714+
return getArray(findColumn(columnLabel));
695715
}
696716

697717
@Override

src/main/java/org/duckdb/DuckDBResultSet.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ public int findColumn(String columnLabel) throws SQLException {
308308
throw new SQLException("ResultSet was closed");
309309
}
310310
for (int col_idx = 0; col_idx < meta.column_count; col_idx++) {
311-
if (meta.column_names[col_idx].contentEquals(columnLabel)) {
311+
if (meta.column_names[col_idx].equalsIgnoreCase(columnLabel)) {
312312
return col_idx + 1;
313313
}
314314
}

src/test/java/org/duckdb/TestDuckDBJDBC.java

+67-4
Original file line numberDiff line numberDiff line change
@@ -3748,7 +3748,17 @@ public static void test_array_resultset() throws Exception {
37483748
ResultSet arrayResultSet = rs.getArray(1).getResultSet();
37493749
assertTrue(arrayResultSet.next());
37503750
assertEquals(arrayResultSet.getInt(1), 1);
3751+
assertEquals(arrayResultSet.getInt("index"), 1);
3752+
assertEquals(arrayResultSet.getInt("Index"), 1);
3753+
assertEquals(arrayResultSet.getInt("INDEX"), 1);
3754+
assertEquals(arrayResultSet.getByte(2), (byte) 42);
3755+
assertEquals(arrayResultSet.getShort(2), (short) 42);
37513756
assertEquals(arrayResultSet.getInt(2), 42);
3757+
assertEquals(arrayResultSet.getLong(2), (long) 42);
3758+
assertEquals(arrayResultSet.getFloat(2), (float) 42);
3759+
assertEquals(arrayResultSet.getDouble(2), (double) 42);
3760+
assertEquals(arrayResultSet.getBigDecimal(2), BigDecimal.valueOf(42));
3761+
assertEquals(arrayResultSet.getInt("value"), 42);
37523762
assertTrue(arrayResultSet.next());
37533763
assertEquals(arrayResultSet.getInt(1), 2);
37543764
assertEquals(arrayResultSet.getInt(2), 69);
@@ -3760,10 +3770,18 @@ public static void test_array_resultset() throws Exception {
37603770
ResultSet arrayResultSet = rs.getArray(1).getResultSet();
37613771
assertTrue(arrayResultSet.next());
37623772
assertEquals(arrayResultSet.getInt(1), 1);
3763-
Array subArray = arrayResultSet.getArray(2);
3764-
assertNotNull(subArray);
3765-
ResultSet subArrayResultSet = subArray.getResultSet();
3766-
assertFalse(subArrayResultSet.next()); // empty array
3773+
{
3774+
Array subArray = arrayResultSet.getArray(2);
3775+
assertNotNull(subArray);
3776+
ResultSet subArrayResultSet = subArray.getResultSet();
3777+
assertFalse(subArrayResultSet.next()); // empty array
3778+
}
3779+
{
3780+
Array subArray = arrayResultSet.getArray("value");
3781+
assertNotNull(subArray);
3782+
ResultSet subArrayResultSet = subArray.getResultSet();
3783+
assertFalse(subArrayResultSet.next()); // empty array
3784+
}
37673785

37683786
assertTrue(arrayResultSet.next());
37693787
assertEquals(arrayResultSet.getInt(1), 2);
@@ -3845,6 +3863,13 @@ public static void test_array_resultset() throws Exception {
38453863
assertEquals(arrayResultSet2.getInt(2), 69);
38463864
assertFalse(arrayResultSet2.next());
38473865
}
3866+
3867+
try (ResultSet rs = statement.executeQuery("select [" + Integer.MAX_VALUE + "::BIGINT + 1]")) {
3868+
assertTrue(rs.next());
3869+
ResultSet arrayResultSet = rs.getArray(1).getResultSet();
3870+
assertTrue(arrayResultSet.next());
3871+
assertEquals(arrayResultSet.getLong(2), ((long) Integer.MAX_VALUE) + 1);
3872+
}
38483873
}
38493874
}
38503875

@@ -4577,6 +4602,44 @@ public static void test_get_bytes() throws Exception {
45774602
}
45784603
}
45794604

4605+
public static void test_case_insensitivity() throws Exception {
4606+
try (Connection connection = DriverManager.getConnection("jdbc:duckdb:")) {
4607+
try (Statement s = connection.createStatement()) {
4608+
s.execute("CREATE TABLE someTable (lowercase INT, mixedCASE INT, UPPERCASE INT)");
4609+
s.execute("INSERT INTO someTable VALUES (0, 1, 2)");
4610+
}
4611+
4612+
String[] tableNameVariations = new String[] {"sometable", "someTable", "SOMETABLE"};
4613+
String[][] columnNameVariations = new String[][] {{"lowercase", "mixedcase", "uppercase"},
4614+
{"lowerCASE", "mixedCASE", "upperCASE"},
4615+
{"LOWERCASE", "MIXEDCASE", "UPPERCASE"}};
4616+
4617+
int totalTestsRun = 0;
4618+
4619+
// Test every combination of upper, lower and mixedcase column and table names.
4620+
for (String tableName : tableNameVariations) {
4621+
for (int columnVariation = 0; columnVariation < columnNameVariations.length; columnVariation++) {
4622+
try (Statement s = connection.createStatement()) {
4623+
String query = String.format("SELECT %s, %s, %s from %s;", columnNameVariations[0][0],
4624+
columnNameVariations[0][1], columnNameVariations[0][2], tableName);
4625+
4626+
ResultSet resultSet = s.executeQuery(query);
4627+
assertTrue(resultSet.next());
4628+
for (int i = 0; i < columnNameVariations[0].length; i++) {
4629+
assertEquals(resultSet.getInt(columnNameVariations[columnVariation][i]), i,
4630+
"Query " + query + " did not get correct result back for column number " + i);
4631+
totalTestsRun++;
4632+
}
4633+
}
4634+
}
4635+
}
4636+
4637+
assertEquals(totalTestsRun,
4638+
tableNameVariations.length * columnNameVariations.length * columnNameVariations[0].length,
4639+
"Number of test cases actually run did not match number expected to be run.");
4640+
}
4641+
}
4642+
45804643
public static void test_fractional_time() throws Exception {
45814644
try (Connection conn = DriverManager.getConnection(JDBC_URL);
45824645
PreparedStatement stmt = conn.prepareStatement("SELECT '01:02:03.123'::TIME");

0 commit comments

Comments
 (0)