diff --git a/src/main/java/net/snowflake/client/loader/ProcessQueue.java b/src/main/java/net/snowflake/client/loader/ProcessQueue.java
index d6472771ad..fc4e422564 100644
--- a/src/main/java/net/snowflake/client/loader/ProcessQueue.java
+++ b/src/main/java/net/snowflake/client/loader/ProcessQueue.java
@@ -265,7 +265,9 @@ public void run() {
                       + "("
                       + _loader.getColumnsAsString()
                       + ")"
-                      + " SELECT * FROM \""
+                      + " SELECT "
+                      + _loader.getStageColumnsAsString()
+                      + " FROM \""
                       + stage.getId()
                       + "\"";
               break;
diff --git a/src/main/java/net/snowflake/client/loader/StreamLoader.java b/src/main/java/net/snowflake/client/loader/StreamLoader.java
index 59ac8c2ddc..6df560e6f5 100644
--- a/src/main/java/net/snowflake/client/loader/StreamLoader.java
+++ b/src/main/java/net/snowflake/client/loader/StreamLoader.java
@@ -5,12 +5,14 @@
 import java.io.IOException;
 import java.sql.Connection;
 import java.sql.DatabaseMetaData;
+import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.text.DateFormat;
 import java.text.SimpleDateFormat;
 import java.util.ArrayList;
 import java.util.Calendar;
 import java.util.GregorianCalendar;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
@@ -83,6 +85,12 @@ public class StreamLoader implements Loader, Runnable {
 
   private List _columns;
 
+  // Vector columns of the target table: column name -> size reported by the column metadata.
+  private final Map<String, Integer> _vectorColumnsNameAndSize = new HashMap<>();
+
+  // Vector type can be FLOAT or INT
+  private String _vectorType;
+
   private List _keys;
 
   private long _batchRowSize = DEFAULT_BATCH_ROW_SIZE;
@@ -178,6 +186,7 @@ public void setProperty(LoaderProperty property, Object value) {
             typeCheckedColumns.add((String) e);
           }
           _columns = typeCheckedColumns;
+          setVectorColumns();
         }
         break;
       case keys:
@@ -598,6 +607,50 @@ private void truncateTargetTable() {
     }
   }
 
+  /**
+   * Sets the element type used to cast staged values into the target table's VECTOR columns.
+   *
+   * @param vectorType element type of the vector columns, "INT" or "FLOAT" (case-insensitive)
+   * @throws IllegalArgumentException if a non-null value other than INT or FLOAT is given
+   */
+  public void setVectorColumnType(String vectorType) {
+    if (vectorType != null
+        && !"INT".equalsIgnoreCase(vectorType)
+        && !"FLOAT".equalsIgnoreCase(vectorType)) {
+      throw new IllegalArgumentException(
+          "Vector column type must be \"INT\" or \"FLOAT\" but was: " + vectorType);
+    }
+    this._vectorType = vectorType;
+  }
+
+  /**
+   * Looks up each configured column in the table metadata and records those of type VECTOR
+   * together with their reported size. A failed lookup is logged and passed to abort(...).
+   */
+  public void setVectorColumns() {
+    // Drop entries left over from a previously configured column list.
+    _vectorColumnsNameAndSize.clear();
+    try {
+      DatabaseMetaData dbmd = _processConn.getMetaData();
+      for (String col : _columns) {
+        try (ResultSet rs = dbmd.getColumns(_database, _schema, _table, col)) {
+          // 6 = TYPE_NAME, 7 = COLUMN_SIZE; a column without a metadata row is skipped.
+          if (rs.next() && isColumnTypeVector(rs.getString(6))) {
+            _vectorColumnsNameAndSize.put(col, rs.getInt(7));
+          }
+        }
+      }
+    } catch (SQLException e) {
+      logger.error(e.getMessage(), e);
+      abort(new Loader.ConnectionError(Utils.getCause(e)));
+    }
+  }
+
+  /** Returns true when the given metadata type name is "vector", ignoring case. */
+  private boolean isColumnTypeVector(String typeName) {
+    return "vector".equalsIgnoreCase(typeName);
+  }
+
   @Override
   public void run() {
     try {
@@ -750,6 +803,11 @@ List getColumns() {
     return this._columns;
   }
 
+  // Vector column name -> size, as collected by setVectorColumns().
+  Map<String, Integer> getVectorColumns() {
+    return this._vectorColumnsNameAndSize;
+  }
+
   String getColumnsAsString() {
     // comma separate list of column names
     StringBuilder sb = new StringBuilder("\"");
@@ -904,4 +962,38 @@ public int getSubmittedRowCount() {
   void setTestMode(boolean mode) {
     this._testMode = mode;
   }
+
+  /**
+   * Builds the select list used to copy rows from the staging table into the target table.
+   *
+   * @return "*" when no vector columns were recorded, otherwise the quoted column names with
+   *     every vector column cast to VECTOR(type, size)
+   * @throws IllegalArgumentException if vector columns exist but no vector type was set
+   */
+  public String getStageColumnsAsString() {
+    // if there are no vector columns in the target table just select * is needed from the staging
+    // table.
+    if (_vectorColumnsNameAndSize.isEmpty()) {
+      return "*";
+    }
+    if (_vectorType == null) {
+      throw new IllegalArgumentException(
+          "Target table with vector columns must use setVectorColumnType with \"INT\" or \"FLOAT\"");
+    }
+
+    StringBuilder sb = new StringBuilder();
+    for (int i = 0; i < _columns.size(); i++) {
+      String colName = _columns.get(i);
+      if (i > 0) {
+        sb.append(", ");
+      }
+      // Quote vector columns too, so every name is resolved with the same case sensitivity.
+      sb.append('"').append(colName).append('"');
+      Integer vectorSize = _vectorColumnsNameAndSize.get(colName);
+      if (vectorSize != null) {
+        sb.append("::VECTOR(").append(_vectorType).append(", ").append(vectorSize).append(")");
+      }
+    }
+    return sb.toString();
+  }
 }
diff --git a/src/test/java/net/snowflake/client/loader/LoaderLatestIT.java b/src/test/java/net/snowflake/client/loader/LoaderLatestIT.java
index 07180cffef..ff9b195ced 100644
--- a/src/test/java/net/snowflake/client/loader/LoaderLatestIT.java
+++ b/src/test/java/net/snowflake/client/loader/LoaderLatestIT.java
@@ -201,4 +201,134 @@ public void testKeyClusteringTable() throws Exception {
       }
     }
   }
+
+  // Loads one row into a table with a single FLOAT vector column.
+  @Test
+  public void testVectorColumnInTable() throws Exception {
+    String tableName = "VECTOR_TABLE";
+    try {
+      testConnection
+          .createStatement()
+          .execute(
+              String.format("CREATE OR REPLACE TABLE %s (vector_col VECTOR(FLOAT, 3))", tableName));
+
+      TestDataConfigBuilder tdcb = new TestDataConfigBuilder(testConnection, putConnection);
+      tdcb.setOperation(Operation.INSERT)
+          .setStartTransaction(true)
+          .setTruncateTable(true)
+          .setTableName(tableName)
+          .setColumns(Arrays.asList("VECTOR_COL"));
+      StreamLoader loader = tdcb.getStreamLoader();
+      TestDataConfigBuilder.ResultListener listener = tdcb.getListener();
+      loader.start();
+
+      loader.setVectorColumnType("FLOAT");
+      loader.submitRow(new Object[] {"[12, 14.0, 100]"});
+      loader.finish();
+      int submitted = listener.getSubmittedRowCount();
+      assertThat("submitted rows", submitted, equalTo(1));
+
+    } finally {
+      testConnection.createStatement().execute(String.format("DROP TABLE IF EXISTS %s", tableName));
+    }
+  }
+
+  // Loads one row into a table with two FLOAT vector columns.
+  @Test
+  public void testMultipleFloatVectorColumnsInTable() throws Exception {
+    String tableName = "VECTOR_TABLE";
+    try {
+      testConnection
+          .createStatement()
+          .execute(
+              String.format(
+                  "CREATE OR REPLACE TABLE %s (vec1 VECTOR(FLOAT, 3), vec2 VECTOR(FLOAT, 3))",
+                  tableName));
+
+      TestDataConfigBuilder tdcb = new TestDataConfigBuilder(testConnection, putConnection);
+      tdcb.setOperation(Operation.INSERT)
+          .setStartTransaction(true)
+          .setTruncateTable(true)
+          .setTableName(tableName)
+          .setColumns(Arrays.asList("VEC1", "VEC2"));
+      StreamLoader loader = tdcb.getStreamLoader();
+      TestDataConfigBuilder.ResultListener listener = tdcb.getListener();
+      loader.start();
+
+      loader.setVectorColumnType("FLOAT");
+      loader.submitRow(new Object[] {"[12, 14.0, 100]", "[12, 14.0, 100]"});
+      loader.finish();
+      int submitted = listener.getSubmittedRowCount();
+      assertThat("submitted rows", submitted, equalTo(1));
+
+    } finally {
+      testConnection.createStatement().execute(String.format("DROP TABLE IF EXISTS %s", tableName));
+    }
+  }
+
+  // Loads one row into a table with two INT vector columns.
+  @Test
+  public void testMultipleIntVectorColumnsInTable() throws Exception {
+    String tableName = "VECTOR_TABLE";
+    try {
+      testConnection
+          .createStatement()
+          .execute(
+              String.format(
+                  "CREATE OR REPLACE TABLE %s (vec1 VECTOR(INT, 3), vec2 VECTOR(INT, 3))",
+                  tableName));
+
+      TestDataConfigBuilder tdcb = new TestDataConfigBuilder(testConnection, putConnection);
+      tdcb.setOperation(Operation.INSERT)
+          .setStartTransaction(true)
+          .setTruncateTable(true)
+          .setTableName(tableName)
+          .setColumns(Arrays.asList("VEC1", "VEC2"));
+      StreamLoader loader = tdcb.getStreamLoader();
+      TestDataConfigBuilder.ResultListener listener = tdcb.getListener();
+      loader.start();
+
+      loader.setVectorColumnType("INT");
+      loader.submitRow(new Object[] {"[12, 14, 100]", "[12, 14, 100]"});
+      loader.finish();
+      int submitted = listener.getSubmittedRowCount();
+      assertThat("submitted rows", submitted, equalTo(1));
+
+    } finally {
+      testConnection.createStatement().execute(String.format("DROP TABLE IF EXISTS %s", tableName));
+    }
+  }
+
+  // Loads one row into a table that mixes FLOAT vector columns with non-vector columns.
+  @Test
+  public void testMultipleTypesWithVectorColumnsInTable() throws Exception {
+    String tableName = "VECTOR_TABLE";
+    try {
+      testConnection
+          .createStatement()
+          .execute(
+              String.format(
+                  "CREATE OR REPLACE TABLE %s (vec1 VECTOR(FLOAT, 3), vec2 VECTOR(FLOAT, 3), ID int, colA varchar(255), colB date)",
+                  tableName));
+
+      TestDataConfigBuilder tdcb = new TestDataConfigBuilder(testConnection, putConnection);
+      tdcb.setOperation(Operation.INSERT)
+          .setStartTransaction(true)
+          .setTruncateTable(true)
+          .setTableName(tableName)
+          .setColumns(Arrays.asList("VEC1", "VEC2", "ID", "COLA", "COLB"));
+      StreamLoader loader = tdcb.getStreamLoader();
+      TestDataConfigBuilder.ResultListener listener = tdcb.getListener();
+      loader.start();
+
+      loader.setVectorColumnType("FLOAT");
+      loader.submitRow(new Object[] {"[12, 14.0, 100]", "[12, 14.0, 100]", 10, "abc", new Date()});
+      loader.finish();
+      int submitted = listener.getSubmittedRowCount();
+      assertThat("submitted rows", submitted, equalTo(1));
+
+    } finally {
+      testConnection.createStatement().execute(String.format("DROP TABLE IF EXISTS %s", tableName));
+    }
+  }
 }