Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#27312] fix(JmsIO): create a session pool for JmsIO #27312 #27313

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@

## I/Os

* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Add Jms pool session to control number of connections (Java) ([#27312](https://github.com/apache/beam/issues/27312)).
* Fixed JmsIO unit tests trying to bind on a hard coded port number (Java) ([#26203](https://github.com/apache/beam/issues/26203)).
* Support for Bigtable Change Streams added in Java `BigtableIO.ReadChangeStream` ([#27183](https://github.com/apache/beam/issues/27183))

## New Features / Improvements
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ class BeamModulePlugin implements Plugin<Project> {
commons_lang3 : "org.apache.commons:commons-lang3:3.9",
commons_logging : "commons-logging:commons-logging:1.2",
commons_math3 : "org.apache.commons:commons-math3:3.6.1",
commons_pool2 : "org.apache.commons:commons-pool2:2.11.1",
dbcp2 : "org.apache.commons:commons-dbcp2:$dbcp2_version",
error_prone_annotations : "com.google.errorprone:error_prone_annotations:$errorprone_version",
flogger_system_backend : "com.google.flogger:flogger-system-backend:0.7.3",
Expand Down
2 changes: 1 addition & 1 deletion sdks/java/io/jdbc/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies {
implementation project(path: ":sdks:java:core", configuration: "shadow")
implementation library.java.dbcp2
implementation library.java.joda_time
implementation "org.apache.commons:commons-pool2:2.11.1"
implementation library.java.commons_pool2
implementation library.java.slf4j_api
testImplementation "org.apache.derby:derby:10.14.2.0"
testImplementation "org.apache.derby:derbyclient:10.14.2.0"
Expand Down
1 change: 1 addition & 0 deletions sdks/java/io/jms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies {
implementation project(path: ":sdks:java:core", configuration: "shadow")
implementation library.java.slf4j_api
implementation library.java.joda_time
implementation library.java.commons_pool2
implementation "org.apache.geronimo.specs:geronimo-jms_1.1_spec:1.1.1"
testImplementation library.java.activemq_amqp
testImplementation library.java.activemq_broker
Expand Down
120 changes: 58 additions & 62 deletions sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark;
import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader;
import org.apache.beam.sdk.io.jms.pool.JmsPoolConfiguration;
import org.apache.beam.sdk.io.jms.pool.JmsSessionPool;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.ExecutorOptions;
Expand Down Expand Up @@ -699,22 +701,24 @@ protected void finalize() {
public abstract static class Write<EventT>
extends PTransform<PCollection<EventT>, WriteJmsResult<EventT>> {

abstract @Nullable ConnectionFactory getConnectionFactory();
public abstract @Nullable ConnectionFactory getConnectionFactory();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are these public?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I still remember, I was trying to access to it from the new package pool. We can move the pool classes to the main package to keep all functions as they were


abstract @Nullable String getQueue();

abstract @Nullable String getTopic();

abstract @Nullable String getUsername();
public abstract @Nullable String getUsername();

abstract @Nullable String getPassword();
public abstract @Nullable String getPassword();

abstract @Nullable SerializableBiFunction<EventT, Session, Message> getValueMapper();

abstract @Nullable SerializableFunction<EventT, String> getTopicNameMapper();

abstract @Nullable RetryConfiguration getRetryConfiguration();

public abstract @Nullable JmsPoolConfiguration getJmsPoolConfiguration();

abstract Builder<EventT> builder();

@AutoValue.Builder
Expand All @@ -737,6 +741,8 @@ abstract Builder<EventT> setTopicNameMapper(

abstract Builder<EventT> setRetryConfiguration(RetryConfiguration retryConfiguration);

abstract Builder<EventT> setJmsPoolConfiguration(JmsPoolConfiguration jmsPoolConfiguration);

abstract Write<EventT> build();
}

Expand Down Expand Up @@ -919,6 +925,11 @@ public Write<EventT> withRetryConfiguration(RetryConfiguration retryConfiguratio
return builder().setRetryConfiguration(retryConfiguration).build();
}

public Write<EventT> withJmsPoolConfiguration(JmsPoolConfiguration poolConfiguration) {
checkArgument(poolConfiguration != null, "poolConfiguration can not be null");
return builder().setJmsPoolConfiguration(poolConfiguration).build();
}

@Override
public WriteJmsResult<EventT> expand(PCollection<EventT> input) {
checkArgument(getConnectionFactory() != null, "withConnectionFactory() is required");
Expand All @@ -930,23 +941,21 @@ public WriteJmsResult<EventT> expand(PCollection<EventT> input) {
exclusiveTopicQueue,
"Only one of withQueue(queue), withTopic(topic), or withTopicNameMapper(function) must be set.");
checkArgument(getValueMapper() != null, "withValueMapper() is required");
checkArgument(getJmsPoolConfiguration() != null, "withJmsPoolConfiguration() is required");

return input.apply(new Writer<>(this));
}

private boolean isExclusiveTopicQueue() {
boolean exclusiveTopicQueue =
Stream.of(getQueue() != null, getTopic() != null, getTopicNameMapper() != null)
.filter(b -> b)
.count()
== 1;
return exclusiveTopicQueue;
return Stream.of(getQueue() != null, getTopic() != null, getTopicNameMapper() != null)
.filter(b -> b)
.count()
== 1;
}
}

static class Writer<T> extends PTransform<PCollection<T>, WriteJmsResult<T>> {
public static class Writer<T> extends PTransform<PCollection<T>, WriteJmsResult<T>> {

public static final String CONNECTION_ERRORS_METRIC_NAME = "connectionErrors";
public static final String PUBLICATION_RETRIES_METRIC_NAME = "publicationRetries";
public static final String JMS_IO_PRODUCER_METRIC_NAME = Writer.class.getCanonicalName();

Expand Down Expand Up @@ -982,43 +991,15 @@ private static class JmsConnection<T> implements Serializable {
private static final long serialVersionUID = 1L;

private transient @Initialized Session session;
private transient @Initialized Connection connection;
private transient @Initialized Destination destination;
private transient @Initialized MessageProducer producer;

private final JmsIO.Write<T> spec;
private final Counter connectionErrors =
Metrics.counter(JMS_IO_PRODUCER_METRIC_NAME, CONNECTION_ERRORS_METRIC_NAME);
private final JmsSessionPool<T> sessionPool;

JmsConnection(Write<T> spec) {
this.spec = spec;
}

void connect() throws JMSException {
if (this.producer == null) {
ConnectionFactory connectionFactory = spec.getConnectionFactory();
if (spec.getUsername() != null) {
this.connection =
connectionFactory.createConnection(spec.getUsername(), spec.getPassword());
} else {
this.connection = connectionFactory.createConnection();
}
this.connection.setExceptionListener(
exception -> {
this.connectionErrors.inc();
});
this.connection.start();
// false means we don't use JMS transaction.
this.session = this.connection.createSession(false, Session.AUTO_ACKNOWLEDGE);

if (spec.getQueue() != null) {
this.destination = session.createQueue(spec.getQueue());
} else if (spec.getTopic() != null) {
this.destination = session.createTopic(spec.getTopic());
}
// Create producer with null destination. Destination will be set with producer.send().
startProducer();
}
this.sessionPool = new JmsSessionPool<>(spec);
}

void publishMessage(T input) throws JMSException, JmsIOException {
Expand All @@ -1038,33 +1019,48 @@ void publishMessage(T input) throws JMSException, JmsIOException {
}
}

void startProducer() throws JMSException {
void addSessions() throws Exception {
JmsPoolConfiguration configuration = this.spec.getJmsPoolConfiguration();
sessionPool.addObjects(configuration.getInitialActiveConnections());
}

void startProducer() throws Exception {
this.session = sessionPool.borrowObject();
if (spec.getQueue() != null) {
this.destination = this.session.createQueue(spec.getQueue());
} else if (spec.getTopic() != null) {
this.destination = this.session.createTopic(spec.getTopic());
}
this.producer = this.session.createProducer(null);
}

void closeProducer() throws JMSException {
void releaseSession() throws Exception {
if (producer != null) {
producer.close();
producer = null;
}
sessionPool.returnObject(this.session);
}

void close() {
void reconnect() throws Exception {
release();
startProducer();
}

void release() {
try {
closeProducer();
if (session != null) {
session.close();
}
if (connection != null) {
connection.close();
}
} catch (JMSException exception) {
LOG.warn("The connection couldn't be closed", exception);
releaseSession();
} catch (Exception exception) {
LOG.warn("The session couldn't be released", exception);
} finally {
session = null;
connection = null;
}
}

void close() {
release();
this.sessionPool.close();
}
}

static class JmsIOProducerFn<T> extends DoFn<T, T> {
Expand All @@ -1084,27 +1080,27 @@ static class JmsIOProducerFn<T> extends DoFn<T, T> {
}

@Setup
public void setup() throws JMSException {
this.jmsConnection.connect();
public void setup() throws Exception {
RetryConfiguration retryConfiguration =
MoreObjects.firstNonNull(spec.getRetryConfiguration(), RetryConfiguration.create());
retryBackOff =
FluentBackoff.DEFAULT
.withInitialBackoff(checkStateNotNull(retryConfiguration.getInitialDuration()))
.withMaxCumulativeBackoff(checkStateNotNull(retryConfiguration.getMaxDuration()))
.withMaxRetries(retryConfiguration.getMaxAttempts());
this.jmsConnection.addSessions();
}

@StartBundle
public void startBundle() throws JMSException {
public void startBundle() throws Exception {
this.jmsConnection.startProducer();
}

@ProcessElement
public void processElement(@Element T input, ProcessContext context) {
try {
publishMessage(input);
} catch (JMSException | JmsIOException | IOException | InterruptedException exception) {
} catch (Exception exception) {
LOG.error("Error while publishing the message", exception);
context.output(this.failedMessagesTags, input);
if (exception instanceof InterruptedException) {
Expand All @@ -1113,8 +1109,7 @@ public void processElement(@Element T input, ProcessContext context) {
}
}

private void publishMessage(T input)
throws JMSException, JmsIOException, IOException, InterruptedException {
private void publishMessage(T input) throws Exception {
Sleeper sleeper = Sleeper.DEFAULT;
BackOff backoff = checkStateNotNull(retryBackOff).backoff();
while (true) {
Expand All @@ -1126,14 +1121,15 @@ private void publishMessage(T input)
throw exception;
} else {
publicationRetries.inc();
this.jmsConnection.reconnect();
}
}
}
}

@FinishBundle
public void finishBundle() throws JMSException {
this.jmsConnection.closeProducer();
public void finishBundle() {
this.jmsConnection.release();
}

@Teardown
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.jms.pool;

import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;

import com.google.auto.value.AutoValue;
import java.io.Serializable;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;

@AutoValue
public abstract class JmsPoolConfiguration implements Serializable {
private static final Integer DEFAULT_MAX_ACTIVE_CONNECTIONS = 100;
private static final Integer DEFAULT_INITIAL_ACTIVE_CONNECTIONS = 20;
private static final Duration DEFAULT_MAX_TIMEOUT_DURATION = Duration.standardMinutes(5);

abstract int getMaxActiveConnections();

public abstract int getInitialActiveConnections();

public abstract Duration getMaxTimeout();

public static JmsPoolConfiguration create() {
return create(DEFAULT_MAX_ACTIVE_CONNECTIONS, DEFAULT_INITIAL_ACTIVE_CONNECTIONS, null);
}

public static JmsPoolConfiguration create(
int maxActiveConnections, int initialActiveConnections, @Nullable Duration maxTimeout) {
checkArgument(maxActiveConnections > 0, "maxActiveConnections should be greater than 0");
checkArgument(
initialActiveConnections > 0, "initialActiveConnections should be greater than 0");

if (maxTimeout == null || maxTimeout.equals(Duration.ZERO)) {
maxTimeout = DEFAULT_MAX_TIMEOUT_DURATION;
}

return new AutoValue_JmsPoolConfiguration.Builder()
.setMaxActiveConnections(maxActiveConnections)
.setInitialActiveConnections(initialActiveConnections)
.setMaxTimeout(maxTimeout)
.build();
}

@AutoValue.Builder
abstract static class Builder {
abstract JmsPoolConfiguration.Builder setMaxActiveConnections(int maxActiveConnections);

abstract JmsPoolConfiguration.Builder setInitialActiveConnections(int initialActiveConnections);

abstract JmsPoolConfiguration.Builder setMaxTimeout(Duration maxTimeout);

abstract JmsPoolConfiguration build();
}
}
Loading