Skip to content

Commit 1394b84

Browse files
sfc-gh-ext-simba-jf, sfc-gh-pbulawa, sfc-gh-ext-simba-hx, sfc-gh-ext-simba-vb
authored
Snow-1936378 add support for vector type for loader (#2161)
Co-authored-by: Piotr Bulawa <piotr.bulawa@snowflake.com> Co-authored-by: Harry Xi <harry.xi@insightsoftware.com> Co-authored-by: vikram barbade <vikramsinh.barbade@insightsoftware.com>
1 parent 882d1e5 commit 1394b84

File tree

3 files changed

+200
-1
lines changed

3 files changed

+200
-1
lines changed

src/main/java/net/snowflake/client/loader/ProcessQueue.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,9 @@ public void run() {
265265
+ "("
266266
+ _loader.getColumnsAsString()
267267
+ ")"
268-
+ " SELECT * FROM \""
268+
+ " SELECT "
269+
+ _loader.getStageColumnsAsString()
270+
+ " FROM \""
269271
+ stage.getId()
270272
+ "\"";
271273
break;

src/main/java/net/snowflake/client/loader/StreamLoader.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import java.io.IOException;
66
import java.sql.Connection;
77
import java.sql.DatabaseMetaData;
8+
import java.sql.ResultSet;
89
import java.sql.SQLException;
910
import java.text.DateFormat;
1011
import java.text.SimpleDateFormat;
1112
import java.util.ArrayList;
1213
import java.util.Calendar;
1314
import java.util.GregorianCalendar;
15+
import java.util.HashMap;
1416
import java.util.List;
1517
import java.util.Map;
1618
import java.util.Properties;
@@ -83,6 +85,11 @@ public class StreamLoader implements Loader, Runnable {
8385

8486
private List<String> _columns;
8587

88+
private Map<String, Integer> _vectorColumnsNameAndSize = new HashMap<String, Integer>();
89+
90+
// Vector type can be FLOAT or INT
91+
private String _vectorType;
92+
8693
private List<String> _keys;
8794

8895
private long _batchRowSize = DEFAULT_BATCH_ROW_SIZE;
@@ -178,6 +185,7 @@ public void setProperty(LoaderProperty property, Object value) {
178185
typeCheckedColumns.add((String) e);
179186
}
180187
_columns = typeCheckedColumns;
188+
setVectorColumns();
181189
}
182190
break;
183191
case keys:
@@ -598,6 +606,31 @@ private void truncateTargetTable() {
598606
}
599607
}
600608

609+
public void setVectorColumnType(String vectorType) {
610+
this._vectorType = vectorType;
611+
}
612+
613+
public void setVectorColumns() {
614+
try {
615+
DatabaseMetaData dbmd = _processConn.getMetaData();
616+
for (String col : _columns) {
617+
try (ResultSet rs = dbmd.getColumns(_database, _schema, _table, col)) {
618+
rs.next();
619+
if (isColumnTypeVector(rs.getString(6))) {
620+
_vectorColumnsNameAndSize.put(col, rs.getInt(7));
621+
}
622+
}
623+
}
624+
} catch (SQLException e) {
625+
logger.error(e.getMessage(), e);
626+
abort(new Loader.ConnectionError(Utils.getCause(e)));
627+
}
628+
}
629+
630+
private boolean isColumnTypeVector(String col) {
631+
return col != null && col.equalsIgnoreCase("vector");
632+
}
633+
601634
@Override
602635
public void run() {
603636
try {
@@ -750,6 +783,10 @@ List<String> getColumns() {
750783
return this._columns;
751784
}
752785

786+
Map<String, Integer> getVectorColumns() {
787+
return this._vectorColumnsNameAndSize;
788+
}
789+
753790
String getColumnsAsString() {
754791
// comma separate list of column names
755792
StringBuilder sb = new StringBuilder("\"");
@@ -904,4 +941,38 @@ public int getSubmittedRowCount() {
904941
void setTestMode(boolean mode) {
905942
this._testMode = mode;
906943
}
944+
945+
public String getStageColumnsAsString() {
946+
// if there are no vector columns in the target table just select * is needed from the staging
947+
// table.
948+
if (_vectorColumnsNameAndSize.isEmpty()) {
949+
return "*";
950+
}
951+
if (_vectorType == null) {
952+
throw new IllegalArgumentException(
953+
"Target table with vector columns must use setVectorColumnType with \"INT\" or \"FLOAT\"");
954+
}
955+
956+
StringBuilder sb = new StringBuilder();
957+
for (int i = 0; i < _columns.size(); i++) {
958+
String colName = _columns.get(i);
959+
if (_vectorColumnsNameAndSize.containsKey(colName)) {
960+
sb.append(
961+
colName
962+
+ "::VECTOR("
963+
+ _vectorType
964+
+ ", "
965+
+ _vectorColumnsNameAndSize.get(colName)
966+
+ ")");
967+
} else {
968+
sb.append("\"");
969+
sb.append(colName);
970+
sb.append("\"");
971+
}
972+
if (i != _columns.size() - 1) {
973+
sb.append(", ");
974+
}
975+
}
976+
return sb.toString();
977+
}
907978
}

src/test/java/net/snowflake/client/loader/LoaderLatestIT.java

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,4 +201,130 @@ public void testKeyClusteringTable() throws Exception {
201201
}
202202
}
203203
}
204+
205+
@Test
206+
private void testVectorColumnInTable() throws Exception {
207+
String tableName = "VECTOR_TABLE";
208+
try {
209+
testConnection
210+
.createStatement()
211+
.execute(
212+
String.format("CREATE OR REPLACE TABLE %s (vector_col VECTOR(FLOAT, 3))", tableName));
213+
214+
TestDataConfigBuilder tdcb = new TestDataConfigBuilder(testConnection, putConnection);
215+
tdcb.setOperation(Operation.INSERT)
216+
.setStartTransaction(true)
217+
.setTruncateTable(true)
218+
.setTableName(tableName)
219+
.setColumns(Arrays.asList("vector_col"));
220+
StreamLoader loader = tdcb.getStreamLoader();
221+
TestDataConfigBuilder.ResultListener listener = tdcb.getListener();
222+
loader.start();
223+
224+
loader.submitRow(new Object[] {"[12, 14.0, 100]"});
225+
loader.setVectorColumnType("FLOAT");
226+
loader.finish();
227+
int submitted = listener.getSubmittedRowCount();
228+
assertThat("submitted rows", submitted, equalTo(1));
229+
230+
} finally {
231+
testConnection.createStatement().execute(String.format("DROP TABLE IF EXISTS %s", tableName));
232+
}
233+
}
234+
235+
@Test
236+
private void testMultipleFloatVectorColumnsInTable() throws Exception {
237+
String tableName = "VECTOR_TABLE";
238+
try {
239+
testConnection
240+
.createStatement()
241+
.execute(
242+
String.format(
243+
"CREATE OR REPLACE TABLE %s (vec1 VECTOR(FLOAT, 3), vec2 VECTOR(FLOAT, 3))",
244+
tableName));
245+
246+
TestDataConfigBuilder tdcb = new TestDataConfigBuilder(testConnection, putConnection);
247+
tdcb.setOperation(Operation.INSERT)
248+
.setStartTransaction(true)
249+
.setTruncateTable(true)
250+
.setTableName(tableName)
251+
.setColumns(Arrays.asList("vector_col"));
252+
StreamLoader loader = tdcb.getStreamLoader();
253+
TestDataConfigBuilder.ResultListener listener = tdcb.getListener();
254+
loader.start();
255+
256+
loader.submitRow(new Object[] {"[12, 14.0, 100]", "[12, 14.0, 100]"});
257+
loader.setVectorColumnType("FLOAT");
258+
loader.finish();
259+
int submitted = listener.getSubmittedRowCount();
260+
assertThat("submitted rows", submitted, equalTo(1));
261+
262+
} finally {
263+
testConnection.createStatement().execute(String.format("DROP TABLE IF EXISTS %s", tableName));
264+
}
265+
}
266+
267+
@Test
268+
private void testMultipleIntVectorColumnsInTable() throws Exception {
269+
String tableName = "VECTOR_TABLE";
270+
try {
271+
testConnection
272+
.createStatement()
273+
.execute(
274+
String.format(
275+
"CREATE OR REPLACE TABLE %s (vec1 VECTOR(INT, 3), vec2 VECTOR(INT, 3))",
276+
tableName));
277+
278+
TestDataConfigBuilder tdcb = new TestDataConfigBuilder(testConnection, putConnection);
279+
tdcb.setOperation(Operation.INSERT)
280+
.setStartTransaction(true)
281+
.setTruncateTable(true)
282+
.setTableName(tableName)
283+
.setColumns(Arrays.asList("vector_col"));
284+
StreamLoader loader = tdcb.getStreamLoader();
285+
TestDataConfigBuilder.ResultListener listener = tdcb.getListener();
286+
loader.start();
287+
288+
loader.submitRow(new Object[] {"[12, 14, 100]", "[12, 14, 100]"});
289+
loader.setVectorColumnType("INT");
290+
loader.finish();
291+
int submitted = listener.getSubmittedRowCount();
292+
assertThat("submitted rows", submitted, equalTo(1));
293+
294+
} finally {
295+
testConnection.createStatement().execute(String.format("DROP TABLE IF EXISTS %s", tableName));
296+
}
297+
}
298+
299+
@Test
300+
private void testMultipleTypesWithVectorColumnsInTable() throws Exception {
301+
String tableName = "VECTOR_TABLE";
302+
try {
303+
testConnection
304+
.createStatement()
305+
.execute(
306+
String.format(
307+
"CREATE OR REPLACE TABLE %s (vec1 VECTOR(FLOAT, 3), vec2 VECTOR(FLOAT, 3), ID int, colA varchar(255), colB date)",
308+
tableName));
309+
310+
TestDataConfigBuilder tdcb = new TestDataConfigBuilder(testConnection, putConnection);
311+
tdcb.setOperation(Operation.INSERT)
312+
.setStartTransaction(true)
313+
.setTruncateTable(true)
314+
.setTableName(tableName)
315+
.setColumns(Arrays.asList("vector_col"));
316+
StreamLoader loader = tdcb.getStreamLoader();
317+
TestDataConfigBuilder.ResultListener listener = tdcb.getListener();
318+
loader.start();
319+
320+
loader.setVectorColumnType("FLOAT");
321+
loader.submitRow(new Object[] {"[12, 14.0, 100]", "[12, 14.0, 100]", 10, "abc", new Date()});
322+
loader.finish();
323+
int submitted = listener.getSubmittedRowCount();
324+
assertThat("submitted rows", submitted, equalTo(1));
325+
326+
} finally {
327+
testConnection.createStatement().execute(String.format("DROP TABLE IF EXISTS %s", tableName));
328+
}
329+
}
204330
}

0 commit comments

Comments
 (0)