25
25
import io .cdap .plugin .db .TransactionIsolationLevel ;
26
26
import io .cdap .plugin .util .DBUtils ;
27
27
import org .apache .hadoop .conf .Configuration ;
28
+ import org .apache .hadoop .mapreduce .JobContext ;
29
+ import org .apache .hadoop .mapreduce .OutputCommitter ;
28
30
import org .apache .hadoop .mapreduce .RecordWriter ;
29
31
import org .apache .hadoop .mapreduce .TaskAttemptContext ;
30
32
import org .apache .hadoop .mapreduce .lib .db .DBConfiguration ;
43
45
import java .sql .Statement ;
44
46
import java .util .Map ;
45
47
import java .util .Properties ;
48
+ import java .util .concurrent .ConcurrentHashMap ;
46
49
47
50
import static io .cdap .plugin .db .ConnectionConfigAccessor .OPERATION_NAME ;
48
51
import static io .cdap .plugin .db .ConnectionConfigAccessor .RELATION_TABLE_KEY ;
56
59
public class ETLDBOutputFormat <K extends DBWritable , V > extends DBOutputFormat <K , V > {
57
60
// Batch size before submitting a batch to the SQL engine. If set to 0, no batches will be submitted until commit.
58
61
public static final String COMMIT_BATCH_SIZE = "io.cdap.plugin.db.output.commit.batch.size" ;
62
+ public static final String STAGE_NAME = "io.cdap.plugin.db.output.stage_name" ;
59
63
public static final int DEFAULT_COMMIT_BATCH_SIZE = 1000 ;
60
64
private static final Character ESCAPE_CHAR = '"' ;
61
65
66
+ // Format for connection map's key will be "taskAttemptId_stageName"
67
+ private static final String CONNECTION_MAP_KEY_FORMAT = "%s_%s" ;
68
+
69
+ // CONNECTION_MAP will be used to store connections with "taskAttemptId_stageName" as key and
70
+ // connection object as value. Making it static to be accessed from multiple task attempts within same executor.
71
+ private static final Map <String , Connection > CONNECTION_MAP = new ConcurrentHashMap <>();
62
72
private static final Logger LOG = LoggerFactory .getLogger (ETLDBOutputFormat .class );
63
73
64
74
private Configuration conf ;
65
75
private Driver driver ;
66
76
private JDBCDriverShim driverShim ;
67
77
78
+ @ Override
79
+ public OutputCommitter getOutputCommitter (TaskAttemptContext context )
80
+ throws IOException , InterruptedException {
81
+ return new OutputCommitter () {
82
+ @ Override
83
+ public void setupJob (JobContext jobContext ) throws IOException {
84
+ // do nothing
85
+ }
86
+
87
+ @ Override
88
+ public void setupTask (TaskAttemptContext taskContext ) throws IOException {
89
+ // do nothing
90
+ }
91
+
92
+ @ Override
93
+ public boolean needsTaskCommit (TaskAttemptContext taskContext ) throws IOException {
94
+ return true ;
95
+ }
96
+
97
+ @ Override
98
+ public void commitTask (TaskAttemptContext taskContext ) throws IOException {
99
+ conf = context .getConfiguration ();
100
+ String stageName = conf .get (STAGE_NAME );
101
+ String connectionId = getConnectionMapKeyFormat (context .getTaskAttemptID ().toString (), stageName );
102
+ Connection connection ;
103
+ if ((connection = CONNECTION_MAP .remove (connectionId )) != null ) {
104
+ try {
105
+ connection .commit ();
106
+ } catch (SQLException e ) {
107
+ try {
108
+ connection .rollback ();
109
+ } catch (SQLException ex ) {
110
+ LOG .warn (StringUtils .stringifyException (ex ));
111
+ }
112
+ throw new IOException (e );
113
+ } finally {
114
+ try {
115
+ connection .close ();
116
+ LOG .debug ("Connection Closed after committing the task with taskAttemptId {}" , connectionId );
117
+ } catch (SQLException ex ) {
118
+ LOG .warn (StringUtils .stringifyException (ex ));
119
+ }
120
+ }
121
+ }
122
+ }
123
+
124
+ @ Override
125
+ public void abortTask (TaskAttemptContext taskContext ) throws IOException {
126
+ conf = context .getConfiguration ();
127
+ String stageName = conf .get (STAGE_NAME );
128
+ String connectionId = getConnectionMapKeyFormat (context .getTaskAttemptID ().toString (), stageName );
129
+ Connection connection ;
130
+ if ((connection = CONNECTION_MAP .remove (connectionId )) != null ) {
131
+ try {
132
+ connection .rollback ();
133
+ } catch (SQLException e ) {
134
+ throw new IOException (e );
135
+ } finally {
136
+ try {
137
+ connection .close ();
138
+ LOG .debug ("Connection Closed after rollback the task with taskAttemptId {}" , connectionId );
139
+ } catch (SQLException ex ) {
140
+ LOG .warn (StringUtils .stringifyException (ex ));
141
+ }
142
+ }
143
+ }
144
+ }
145
+ };
146
+ }
147
+
68
148
@ Override
69
149
public RecordWriter <K , V > getRecordWriter (TaskAttemptContext context ) throws IOException {
70
150
conf = context .getConfiguration ();
@@ -81,6 +161,11 @@ public RecordWriter<K, V> getRecordWriter(TaskAttemptContext context) throws IOE
81
161
82
162
try {
83
163
Connection connection = getConnection (conf );
164
+ String stageName = conf .get (STAGE_NAME );
165
+ // When multiple sinks are used, the task attempt ID can be the same across them, so the stage name is appended to keep the key unique.
166
+ String connectionId = getConnectionMapKeyFormat (context .getTaskAttemptID ().toString (), stageName );
167
+ CONNECTION_MAP .put (connectionId , connection );
168
+ LOG .debug ("Connection Added to the map with connectionId : {}" , connectionId );
84
169
PreparedStatement statement = connection .prepareStatement (constructQueryOnOperation (tableName , fieldNames ,
85
170
operationName , listKeys ));
86
171
return new DBRecordWriter (connection , statement ) {
@@ -98,23 +183,15 @@ public void close(TaskAttemptContext context) throws IOException {
98
183
if (!emptyData ) {
99
184
getStatement ().executeBatch ();
100
185
}
101
- getConnection ().commit ();
102
186
} catch (SQLException e ) {
103
- try {
104
- getConnection ().rollback ();
105
- } catch (SQLException ex ) {
106
- LOG .warn (StringUtils .stringifyException (ex ));
107
- }
108
187
throw new IOException (e );
109
188
} finally {
110
189
try {
111
190
getStatement ().close ();
112
- getConnection ().close ();
113
191
} catch (SQLException ex ) {
114
192
throw new IOException (ex );
115
193
}
116
194
}
117
-
118
195
try {
119
196
DriverManager .deregisterDriver (driverShim );
120
197
} catch (SQLException e ) {
@@ -298,4 +375,8 @@ public String constructUpdateQuery(String table, String[] fieldNames, String[] l
298
375
return query .toString ();
299
376
}
300
377
}
378
+
379
+ private String getConnectionMapKeyFormat (String taskAttemptId , String stageName ) {
380
+ return String .format (CONNECTION_MAP_KEY_FORMAT , taskAttemptId , stageName );
381
+ }
301
382
}
0 commit comments