diff --git a/database-commons/src/main/java/io/cdap/plugin/db/sink/AbstractDBSink.java b/database-commons/src/main/java/io/cdap/plugin/db/sink/AbstractDBSink.java index 797abfc23..0bb4bf123 100644 --- a/database-commons/src/main/java/io/cdap/plugin/db/sink/AbstractDBSink.java +++ b/database-commons/src/main/java/io/cdap/plugin/db/sink/AbstractDBSink.java @@ -228,6 +228,7 @@ public void prepareRun(BatchSinkContext context) { configAccessor.setInitQueries(dbSinkConfig.getInitQueries()); configAccessor.getConfiguration().set(DBConfiguration.DRIVER_CLASS_PROPERTY, driverClass.getName()); configAccessor.getConfiguration().set(DBConfiguration.URL_PROPERTY, connectionString); + configAccessor.getConfiguration().set(ETLDBOutputFormat.STAGE_NAME, context.getStageName()); String fullyQualifiedTableName = dbSchemaName == null ? dbSinkConfig.getEscapedTableName() : dbSinkConfig.getEscapedDbSchemaName() + "." + dbSinkConfig.getEscapedTableName(); configAccessor.getConfiguration().set(DBConfiguration.OUTPUT_TABLE_NAME_PROPERTY, fullyQualifiedTableName); diff --git a/database-commons/src/main/java/io/cdap/plugin/db/sink/ETLDBOutputFormat.java b/database-commons/src/main/java/io/cdap/plugin/db/sink/ETLDBOutputFormat.java index ad2b91ab1..ad196386c 100644 --- a/database-commons/src/main/java/io/cdap/plugin/db/sink/ETLDBOutputFormat.java +++ b/database-commons/src/main/java/io/cdap/plugin/db/sink/ETLDBOutputFormat.java @@ -25,6 +25,8 @@ import io.cdap.plugin.db.TransactionIsolationLevel; import io.cdap.plugin.util.DBUtils; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.JobContext; +import org.apache.hadoop.mapreduce.OutputCommitter; import org.apache.hadoop.mapreduce.RecordWriter; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.db.DBConfiguration; @@ -43,6 +45,7 @@ import java.sql.Statement; import java.util.Map; import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; import static io.cdap.plugin.db.ConnectionConfigAccessor.OPERATION_NAME; import static io.cdap.plugin.db.ConnectionConfigAccessor.RELATION_TABLE_KEY; @@ -56,15 +59,92 @@ public class ETLDBOutputFormat extends DBOutputFormat { // Batch size before submitting a batch to the SQL engine. If set to 0, no batches will be submitted until commit. public static final String COMMIT_BATCH_SIZE = "io.cdap.plugin.db.output.commit.batch.size"; + public static final String STAGE_NAME = "io.cdap.plugin.db.output.stage_name"; public static final int DEFAULT_COMMIT_BATCH_SIZE = 1000; private static final Character ESCAPE_CHAR = '"'; + // Format for connection map's key will be "taskAttemptId_stageName" + private static final String CONNECTION_MAP_KEY_FORMAT = "%s_%s"; + + // CONNECTION_MAP will be used to store connections with "taskAttemptId_stageName" as key and + // connection object as value. Making it static to be accessed from multiple task attempts within same executor. + private static final Map CONNECTION_MAP = new ConcurrentHashMap<>(); private static final Logger LOG = LoggerFactory.getLogger(ETLDBOutputFormat.class); private Configuration conf; private Driver driver; private JDBCDriverShim driverShim; + @Override + public OutputCommitter getOutputCommitter(TaskAttemptContext context) + throws IOException, InterruptedException { + return new OutputCommitter() { + @Override + public void setupJob(JobContext jobContext) throws IOException { + // do nothing + } + + @Override + public void setupTask(TaskAttemptContext taskContext) throws IOException { + // do nothing + } + + @Override + public boolean needsTaskCommit(TaskAttemptContext taskContext) throws IOException { + return true; + } + + @Override + public void commitTask(TaskAttemptContext taskContext) throws IOException { + conf = context.getConfiguration(); + String stageName = conf.get(STAGE_NAME); + String connectionId = getConnectionMapKeyFormat(context.getTaskAttemptID().toString(), stageName); + Connection connection; + if ((connection = CONNECTION_MAP.remove(connectionId)) != null) { + try { + connection.commit(); + } catch (SQLException e) { + try { + connection.rollback(); + } catch (SQLException ex) { + LOG.warn(StringUtils.stringifyException(ex)); + } + throw new IOException(e); + } finally { + try { + connection.close(); + LOG.debug("Connection Closed after committing the task with taskAttemptId {}", connectionId); + } catch (SQLException ex) { + LOG.warn(StringUtils.stringifyException(ex)); + } + } + } + } + + @Override + public void abortTask(TaskAttemptContext taskContext) throws IOException { + conf = context.getConfiguration(); + String stageName = conf.get(STAGE_NAME); + String connectionId = getConnectionMapKeyFormat(context.getTaskAttemptID().toString(), stageName); + Connection connection; + if ((connection = CONNECTION_MAP.remove(connectionId)) != null) { + try { + connection.rollback(); + } catch (SQLException e) { + throw new IOException(e); + } finally { + try { + connection.close(); + LOG.debug("Connection Closed after rollback the task with taskAttemptId {}", connectionId); + } catch (SQLException ex) { + LOG.warn(StringUtils.stringifyException(ex)); + } + } + } + } + }; + } + @Override public RecordWriter getRecordWriter(TaskAttemptContext context) throws IOException { conf = context.getConfiguration(); @@ -81,6 +161,11 @@ public RecordWriter getRecordWriter(TaskAttemptContext context) throws IOE try { Connection connection = getConnection(conf); + String stageName = conf.get(STAGE_NAME); + // If using multiple sinks, task attemptID can be same in that case, appending stage in the end for uniqueness. + String connectionId = getConnectionMapKeyFormat(context.getTaskAttemptID().toString(), stageName); + CONNECTION_MAP.put(connectionId, connection); + LOG.debug("Connection Added to the map with connectionId : {}", connectionId); PreparedStatement statement = connection.prepareStatement(constructQueryOnOperation(tableName, fieldNames, operationName, listKeys)); return new DBRecordWriter(connection, statement) { @@ -98,23 +183,15 @@ public void close(TaskAttemptContext context) throws IOException { if (!emptyData) { getStatement().executeBatch(); } - getConnection().commit(); } catch (SQLException e) { - try { - getConnection().rollback(); - } catch (SQLException ex) { - LOG.warn(StringUtils.stringifyException(ex)); - } throw new IOException(e); } finally { try { getStatement().close(); - getConnection().close(); } catch (SQLException ex) { throw new IOException(ex); } } - try { DriverManager.deregisterDriver(driverShim); } catch (SQLException e) { @@ -298,4 +375,8 @@ public String constructUpdateQuery(String table, String[] fieldNames, String[] l return query.toString(); } } + + private String getConnectionMapKeyFormat(String taskAttemptId, String stageName) { + return String.format(CONNECTION_MAP_KEY_FORMAT, taskAttemptId, stageName); + } }