diff --git a/CHANGES.md b/CHANGES.md index aee9b96d9f4a..d9981d69a276 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -58,7 +58,8 @@ ## I/Os -* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Add Jms pool session to control number of connections (Java) ([#27312](https://github.com/apache/beam/issues/27312)). +* Fixed JmsIO unit tests trying to bind on a hard coded port number (Java) ([#26203](https://github.com/apache/beam/issues/26203)). * Support for Bigtable Change Streams added in Java `BigtableIO.ReadChangeStream` ([#27183](https://github.com/apache/beam/issues/27183)) ## New Features / Improvements diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 4e6620c0bbc1..bbed3850ca72 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -637,6 +637,7 @@ class BeamModulePlugin implements Plugin { commons_lang3 : "org.apache.commons:commons-lang3:3.9", commons_logging : "commons-logging:commons-logging:1.2", commons_math3 : "org.apache.commons:commons-math3:3.6.1", + commons_pool2 : "org.apache.commons:commons-pool2:2.11.1", dbcp2 : "org.apache.commons:commons-dbcp2:$dbcp2_version", error_prone_annotations : "com.google.errorprone:error_prone_annotations:$errorprone_version", flogger_system_backend : "com.google.flogger:flogger-system-backend:0.7.3", diff --git a/sdks/java/io/jdbc/build.gradle b/sdks/java/io/jdbc/build.gradle index 07262bc793f4..29d42d321226 100644 --- a/sdks/java/io/jdbc/build.gradle +++ b/sdks/java/io/jdbc/build.gradle @@ -31,7 +31,7 @@ dependencies { implementation project(path: ":sdks:java:core", configuration: "shadow") implementation library.java.dbcp2 implementation library.java.joda_time - implementation "org.apache.commons:commons-pool2:2.11.1" + implementation library.java.commons_pool2 implementation library.java.slf4j_api testImplementation "org.apache.derby:derby:10.14.2.0" testImplementation "org.apache.derby:derbyclient:10.14.2.0" diff --git a/sdks/java/io/jms/build.gradle b/sdks/java/io/jms/build.gradle index 5ecc0ec19d57..efaa7222d1bc 100644 --- a/sdks/java/io/jms/build.gradle +++ b/sdks/java/io/jms/build.gradle @@ -32,6 +32,7 @@ dependencies { implementation project(path: ":sdks:java:core", configuration: "shadow") implementation library.java.slf4j_api implementation library.java.joda_time + implementation library.java.commons_pool2 implementation "org.apache.geronimo.specs:geronimo-jms_1.1_spec:1.1.1" testImplementation library.java.activemq_amqp testImplementation library.java.activemq_broker diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java index 0528b576de72..c9606e0b6777 100644 --- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java @@ -47,6 +47,8 @@ import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; +import org.apache.beam.sdk.io.jms.pool.JmsPoolConfiguration; +import org.apache.beam.sdk.io.jms.pool.JmsSessionPool; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.ExecutorOptions; @@ -699,15 +701,15 @@ protected void finalize() { public abstract static class Write extends PTransform, WriteJmsResult> { - abstract @Nullable ConnectionFactory getConnectionFactory(); + public abstract @Nullable ConnectionFactory getConnectionFactory(); abstract @Nullable String getQueue(); abstract @Nullable String getTopic(); - abstract @Nullable String getUsername(); + public abstract @Nullable String getUsername(); - abstract @Nullable String getPassword(); + public abstract @Nullable String getPassword(); abstract @Nullable SerializableBiFunction getValueMapper(); @@ -715,6 +717,8 @@ public abstract static class Write abstract @Nullable RetryConfiguration getRetryConfiguration(); + public abstract @Nullable JmsPoolConfiguration getJmsPoolConfiguration(); + abstract Builder builder(); @AutoValue.Builder @@ -737,6 +741,8 @@ abstract Builder setTopicNameMapper( abstract Builder setRetryConfiguration(RetryConfiguration retryConfiguration); + abstract Builder setJmsPoolConfiguration(JmsPoolConfiguration jmsPoolConfiguration); + abstract Write build(); } @@ -919,6 +925,11 @@ public Write withRetryConfiguration(RetryConfiguration retryConfiguratio return builder().setRetryConfiguration(retryConfiguration).build(); } + public Write withJmsPoolConfiguration(JmsPoolConfiguration poolConfiguration) { + checkArgument(poolConfiguration != null, "poolConfiguration can not be null"); + return builder().setJmsPoolConfiguration(poolConfiguration).build(); + } + @Override public WriteJmsResult expand(PCollection input) { checkArgument(getConnectionFactory() != null, "withConnectionFactory() is required"); @@ -930,23 +941,21 @@ public WriteJmsResult expand(PCollection input) { exclusiveTopicQueue, "Only one of withQueue(queue), withTopic(topic), or withTopicNameMapper(function) must be set."); checkArgument(getValueMapper() != null, "withValueMapper() is required"); + checkArgument(getJmsPoolConfiguration() != null, "withJmsPoolConfiguration() is required"); return input.apply(new Writer<>(this)); } private boolean isExclusiveTopicQueue() { - boolean exclusiveTopicQueue = - Stream.of(getQueue() != null, getTopic() != null, getTopicNameMapper() != null) - .filter(b -> b) - .count() - == 1; - return exclusiveTopicQueue; + return Stream.of(getQueue() != null, getTopic() != null, getTopicNameMapper() != null) + .filter(b -> b) + .count() + == 1; } } - static class Writer extends PTransform, WriteJmsResult> { + public static class Writer extends PTransform, WriteJmsResult> { - public static final String CONNECTION_ERRORS_METRIC_NAME = "connectionErrors"; public static final String PUBLICATION_RETRIES_METRIC_NAME = "publicationRetries"; public static final String JMS_IO_PRODUCER_METRIC_NAME = Writer.class.getCanonicalName(); @@ -982,43 +991,15 @@ private static class JmsConnection implements Serializable { private static final long serialVersionUID = 1L; private transient @Initialized Session session; - private transient @Initialized Connection connection; private transient @Initialized Destination destination; private transient @Initialized MessageProducer producer; private final JmsIO.Write spec; - private final Counter connectionErrors = - Metrics.counter(JMS_IO_PRODUCER_METRIC_NAME, CONNECTION_ERRORS_METRIC_NAME); + private final JmsSessionPool sessionPool; JmsConnection(Write spec) { this.spec = spec; - } - - void connect() throws JMSException { - if (this.producer == null) { - ConnectionFactory connectionFactory = spec.getConnectionFactory(); - if (spec.getUsername() != null) { - this.connection = - connectionFactory.createConnection(spec.getUsername(), spec.getPassword()); - } else { - this.connection = connectionFactory.createConnection(); - } - this.connection.setExceptionListener( - exception -> { - this.connectionErrors.inc(); - }); - this.connection.start(); - // false means we don't use JMS transaction. - this.session = this.connection.createSession(false, Session.AUTO_ACKNOWLEDGE); - - if (spec.getQueue() != null) { - this.destination = session.createQueue(spec.getQueue()); - } else if (spec.getTopic() != null) { - this.destination = session.createTopic(spec.getTopic()); - } - // Create producer with null destination. Destination will be set with producer.send(). - startProducer(); - } + this.sessionPool = new JmsSessionPool<>(spec); } void publishMessage(T input) throws JMSException, JmsIOException { @@ -1038,33 +1019,48 @@ void publishMessage(T input) throws JMSException, JmsIOException { } } - void startProducer() throws JMSException { + void addSessions() throws Exception { + JmsPoolConfiguration configuration = this.spec.getJmsPoolConfiguration(); + sessionPool.addObjects(configuration.getInitialActiveConnections()); + } + + void startProducer() throws Exception { + this.session = sessionPool.borrowObject(); + if (spec.getQueue() != null) { + this.destination = this.session.createQueue(spec.getQueue()); + } else if (spec.getTopic() != null) { + this.destination = this.session.createTopic(spec.getTopic()); + } this.producer = this.session.createProducer(null); } - void closeProducer() throws JMSException { + void releaseSession() throws Exception { if (producer != null) { producer.close(); producer = null; } + sessionPool.returnObject(this.session); } - void close() { + void reconnect() throws Exception { + release(); + startProducer(); + } + + void release() { try { - closeProducer(); - if (session != null) { - session.close(); - } - if (connection != null) { - connection.close(); - } - } catch (JMSException exception) { - LOG.warn("The connection couldn't be closed", exception); + releaseSession(); + } catch (Exception exception) { + LOG.warn("The session couldn't be released", exception); } finally { session = null; - connection = null; } } + + void close() { + release(); + this.sessionPool.close(); + } } static class JmsIOProducerFn extends DoFn { @@ -1084,8 +1080,7 @@ static class JmsIOProducerFn extends DoFn { } @Setup - public void setup() throws JMSException { - this.jmsConnection.connect(); + public void setup() throws Exception { RetryConfiguration retryConfiguration = MoreObjects.firstNonNull(spec.getRetryConfiguration(), RetryConfiguration.create()); retryBackOff = @@ -1093,10 +1088,11 @@ public void setup() throws JMSException { .withInitialBackoff(checkStateNotNull(retryConfiguration.getInitialDuration())) .withMaxCumulativeBackoff(checkStateNotNull(retryConfiguration.getMaxDuration())) .withMaxRetries(retryConfiguration.getMaxAttempts()); + this.jmsConnection.addSessions(); } @StartBundle - public void startBundle() throws JMSException { + public void startBundle() throws Exception { this.jmsConnection.startProducer(); } @@ -1104,7 +1100,7 @@ public void startBundle() throws JMSException { public void processElement(@Element T input, ProcessContext context) { try { publishMessage(input); - } catch (JMSException | JmsIOException | IOException | InterruptedException exception) { + } catch (Exception exception) { LOG.error("Error while publishing the message", exception); context.output(this.failedMessagesTags, input); if (exception instanceof InterruptedException) { @@ -1113,8 +1109,7 @@ public void processElement(@Element T input, ProcessContext context) { } } - private void publishMessage(T input) - throws JMSException, JmsIOException, IOException, InterruptedException { + private void publishMessage(T input) throws Exception { Sleeper sleeper = Sleeper.DEFAULT; BackOff backoff = checkStateNotNull(retryBackOff).backoff(); while (true) { @@ -1126,14 +1121,15 @@ private void publishMessage(T input) throw exception; } else { publicationRetries.inc(); + this.jmsConnection.reconnect(); } } } } @FinishBundle - public void finishBundle() throws JMSException { - this.jmsConnection.closeProducer(); + public void finishBundle() { + this.jmsConnection.release(); } @Teardown diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/JmsPoolConfiguration.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/JmsPoolConfiguration.java new file mode 100644 index 000000000000..59941dd23fcf --- /dev/null +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/JmsPoolConfiguration.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jms.pool; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; + +import com.google.auto.value.AutoValue; +import java.io.Serializable; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; + +@AutoValue +public abstract class JmsPoolConfiguration implements Serializable { + private static final Integer DEFAULT_MAX_ACTIVE_CONNECTIONS = 100; + private static final Integer DEFAULT_INITIAL_ACTIVE_CONNECTIONS = 20; + private static final Duration DEFAULT_MAX_TIMEOUT_DURATION = Duration.standardMinutes(5); + + abstract int getMaxActiveConnections(); + + public abstract int getInitialActiveConnections(); + + public abstract Duration getMaxTimeout(); + + public static JmsPoolConfiguration create() { + return create(DEFAULT_MAX_ACTIVE_CONNECTIONS, DEFAULT_INITIAL_ACTIVE_CONNECTIONS, null); + } + + public static JmsPoolConfiguration create( + int maxActiveConnections, int initialActiveConnections, @Nullable Duration maxTimeout) { + checkArgument(maxActiveConnections > 0, "maxActiveConnections should be greater than 0"); + checkArgument( + initialActiveConnections > 0, "initialActiveConnections should be greater than 0"); + + if (maxTimeout == null || maxTimeout.equals(Duration.ZERO)) { + maxTimeout = DEFAULT_MAX_TIMEOUT_DURATION; + } + + return new AutoValue_JmsPoolConfiguration.Builder() + .setMaxActiveConnections(maxActiveConnections) + .setInitialActiveConnections(initialActiveConnections) + .setMaxTimeout(maxTimeout) + .build(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract JmsPoolConfiguration.Builder setMaxActiveConnections(int maxActiveConnections); + + abstract JmsPoolConfiguration.Builder setInitialActiveConnections(int initialActiveConnections); + + abstract JmsPoolConfiguration.Builder setMaxTimeout(Duration maxTimeout); + + abstract JmsPoolConfiguration build(); + } +} diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/JmsSessionFactory.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/JmsSessionFactory.java new file mode 100644 index 000000000000..40bf99ceed73 --- /dev/null +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/JmsSessionFactory.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jms.pool; + +import static org.apache.beam.sdk.io.jms.JmsIO.Writer.JMS_IO_PRODUCER_METRIC_NAME; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import javax.jms.Connection; +import javax.jms.ConnectionFactory; +import javax.jms.JMSException; +import javax.jms.Session; +import org.apache.beam.sdk.io.jms.JmsIO; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.commons.pool2.BasePooledObjectFactory; +import org.apache.commons.pool2.PooledObject; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class JmsSessionFactory extends BasePooledObjectFactory { + + private static final Logger LOG = LoggerFactory.getLogger(JmsSessionFactory.class); + public static final String CONNECTION_ERRORS_METRIC_NAME = "connectionErrors"; + + private final Counter connectionErrors = + Metrics.counter(JMS_IO_PRODUCER_METRIC_NAME, CONNECTION_ERRORS_METRIC_NAME); + + private final JmsIO.Write spec; + private boolean isConnectionClosed; + + public JmsSessionFactory(JmsIO.Write spec) { + this.spec = spec; + } + + @Override + public Session create() throws JMSException { + Connection connection; + // reset the connection flag + this.isConnectionClosed = false; + ConnectionFactory connectionFactory = spec.getConnectionFactory(); + if (spec.getUsername() != null) { + connection = connectionFactory.createConnection(spec.getUsername(), spec.getPassword()); + } else { + connection = connectionFactory.createConnection(); + } + connection.setExceptionListener( + exception -> { + this.isConnectionClosed = true; + this.connectionErrors.inc(); + }); + LOG.debug("creating new connection: {}", connection); + connection.start(); + // false means we don't use JMS transaction. + return connection.createSession(false, Session.AUTO_ACKNOWLEDGE); + } + + @Override + public boolean validateObject(PooledObject pooledObject) { + return !isConnectionClosed && !callSessionMethod(pooledObject.getObject(), "isClosed", true); + } + + @Override + public void destroyObject(PooledObject pooledObject) throws JMSException { + Session session = pooledObject.getObject(); + session.close(); + Connection connection = callSessionMethod(session, "getConnection", null); + if (connection != null) { + connection.close(); + } + } + + @Override + public PooledObject wrap(Session session) { + return new DefaultPooledObject<>(session); + } + + private U callSessionMethod(Session session, String methodName, U defaultValue) { + Method isClosed = getSessionMethod(session, methodName); + if (isClosed != null) { + try { + return (U) isClosed.invoke(session); + } catch (IllegalAccessException | InvocationTargetException exception) { + LOG.debug( + "The class {} couldn't allow access to the function 'isClosed' with the following error {}", + session.getClass(), + exception.getMessage()); + } + } + return defaultValue; + } + + private Method getSessionMethod(Object session, String methodName) { + try { + return session.getClass().getDeclaredMethod(methodName); + } catch (NoSuchMethodException e) { + LOG.debug( + "The class {} implemented JMS Session doesn't contain a function to check if session is closed", + session.getClass()); + } + return null; + } +} diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/JmsSessionPool.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/JmsSessionPool.java new file mode 100644 index 000000000000..fedc3e595c48 --- /dev/null +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/JmsSessionPool.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jms.pool; + +import java.time.Duration; +import java.util.Objects; +import java.util.UUID; +import javax.jms.Session; +import org.apache.beam.sdk.io.jms.JmsIO; +import org.apache.commons.pool2.ObjectPool; +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.commons.pool2.impl.GenericObjectPoolConfig; + +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class JmsSessionPool extends SerializableSessionPool { + + private final String id = UUID.randomUUID().toString(); + + private final JmsIO.Write spec; + + public JmsSessionPool(JmsIO.Write spec) { + super(spec.getJmsPoolConfiguration().getMaxActiveConnections()); + this.spec = spec; + } + + @Override + public ObjectPool createDelegate() { + JmsPoolConfiguration configuration = spec.getJmsPoolConfiguration(); + JmsSessionFactory factory = new JmsSessionFactory<>(spec); + GenericObjectPoolConfig config = new GenericObjectPoolConfig<>(); + config.setLifo(false); + config.setFairness(true); + config.setJmxEnabled(false); + config.setTestOnCreate(true); + config.setTestOnBorrow(true); + config.setTestWhileIdle(true); + config.setTestOnReturn(true); + config.setNumTestsPerEvictionRun(-1); + config.setMinEvictableIdleTime(Duration.ofSeconds(60)); + config.setTimeBetweenEvictionRuns(Duration.ofSeconds(15)); + config.setSoftMinEvictableIdleTime(Duration.ofSeconds(60)); + config.setMaxWait( + Duration.ofSeconds( + Objects.requireNonNull(configuration.getMaxTimeout()).getStandardSeconds())); + config.setMaxTotal(configuration.getMaxActiveConnections()); + return new GenericObjectPool<>(factory, config); + } + + @Override + public String toString() { + return String.format("Session Pool id: %s", id); + } +} diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/SerializableSessionPool.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/SerializableSessionPool.java new file mode 100644 index 000000000000..afca138b1f96 --- /dev/null +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/SerializableSessionPool.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jms.pool; + +import java.io.Closeable; +import java.io.Serializable; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.commons.pool2.ObjectPool; + +public abstract class SerializableSessionPool implements ObjectPool, Serializable, Closeable { + + private final int maxConnections; + private final AtomicReference> delegate = new AtomicReference<>(); + + protected SerializableSessionPool(int maxConnections) { + this.maxConnections = maxConnections; + } + + protected abstract ObjectPool createDelegate(); + + public ObjectPool getDelegate() { + ObjectPool value = this.delegate.get(); + if (value == null) { + synchronized (this.delegate) { + value = this.delegate.get(); + if (value == null) { + final ObjectPool actualValue = createDelegate(); + value = actualValue; + this.delegate.set(actualValue); + } + } + } + return value; + } + + @Override + public T borrowObject() throws Exception { + ObjectPool pool = getDelegate(); + if (pool.getNumIdle() == 0 && pool.getNumActive() < maxConnections) { + pool.addObject(); + } + return pool.borrowObject(); + } + + @Override + public void returnObject(T obj) throws Exception { + getDelegate().returnObject(obj); + } + + @Override + public void invalidateObject(T obj) throws Exception { + getDelegate().invalidateObject(obj); + } + + @Override + public void addObject() throws Exception { + getDelegate().addObject(); + } + + @Override + public int getNumIdle() { + return getDelegate().getNumIdle(); + } + + @Override + public int getNumActive() { + return getDelegate().getNumActive(); + } + + @Override + public void clear() throws Exception { + getDelegate().clear(); + } + + @Override + public void close() { + getDelegate().close(); + } +} diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/package-info.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/package-info.java new file mode 100644 index 000000000000..f9bfde1476c7 --- /dev/null +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/pool/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Transforms for reading and writing from Jms. */ +package org.apache.beam.sdk.io.jms.pool; diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java index fb6a08384f27..75ec82235714 100644 --- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java +++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java @@ -41,6 +41,7 @@ import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.io.common.IOITHelper; import org.apache.beam.sdk.io.common.IOTestPipelineOptions; +import org.apache.beam.sdk.io.jms.pool.JmsPoolConfiguration; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.Default; @@ -241,7 +242,9 @@ private PipelineResult publishingMessages() { .withUsername(USERNAME) .withPassword(PASSWORD) .withValueMapper(new TextMessageMapper()) - .withConnectionFactory(connectionFactory)); + .withConnectionFactory(connectionFactory) + .withJmsPoolConfiguration( + JmsPoolConfiguration.create(50, 10, Duration.standardSeconds(15)))); return pipelineWrite.run(); } diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java index 10f3ec7317cb..07cfd1bc0305 100644 --- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java +++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java @@ -49,7 +49,6 @@ import static org.mockito.Mockito.when; import java.io.IOException; -import java.io.Serializable; import java.lang.reflect.Proxy; import java.nio.charset.StandardCharsets; import java.time.Instant; @@ -78,6 +77,8 @@ import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.io.jms.pool.JmsPoolConfiguration; +import org.apache.beam.sdk.io.jms.utils.JmsIOTestUtils; import org.apache.beam.sdk.metrics.MetricNameFilter; import org.apache.beam.sdk.metrics.MetricQueryResults; import org.apache.beam.sdk.metrics.MetricsFilter; @@ -117,16 +118,25 @@ public class JmsIOTest { private static final Logger LOG = LoggerFactory.getLogger(JmsIOTest.class); private final RetryConfiguration retryConfiguration = RetryConfiguration.create(1, Duration.standardSeconds(1), null); + private final JmsPoolConfiguration jmsPoolConfiguration = + JmsPoolConfiguration.create(10, 5, null); + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); @Parameterized.Parameters(name = "with client class {3}") public static Collection connectionFactories() { return Arrays.asList( new Object[] { - "vm://localhost", 5672, "jms.sendAcksAsync=false", ActiveMQConnectionFactory.class + "vm://localhost", + JmsIOTestUtils.getUnusedPort(), + "jms.sendAcksAsync=false", + ActiveMQConnectionFactory.class }, new Object[] { - "amqp://localhost", 5672, "jms.forceAsyncAcks=false", JmsConnectionFactory.class + "amqp://localhost", + JmsIOTestUtils.getUnusedPort(), + "jms.forceAsyncAcks=false", + JmsConnectionFactory.class }); } @@ -286,7 +296,8 @@ public void testWriteMessage() throws Exception { .withRetryConfiguration(retryConfiguration) .withQueue(QUEUE) .withUsername(USERNAME) - .withPassword(PASSWORD)); + .withPassword(PASSWORD) + .withJmsPoolConfiguration(jmsPoolConfiguration)); pipeline.run(); @@ -314,11 +325,12 @@ public void testWriteMessageWithError() throws Exception { .apply( JmsIO.write() .withConnectionFactory(connectionFactory) - .withValueMapper(new TextMessageMapperWithError()) + .withValueMapper(new JmsIOTestUtils.TextMessageMapperWithError()) .withRetryConfiguration(retryConfiguration) .withQueue(QUEUE) .withUsername(USERNAME) - .withPassword(PASSWORD)); + .withPassword(PASSWORD) + .withJmsPoolConfiguration(jmsPoolConfiguration)); PAssert.that(output.getFailedMessages()).containsInAnyOrder("Message 1", "Message 2"); @@ -342,22 +354,23 @@ public void testWriteDynamicMessage() throws Exception { Session session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE); MessageConsumer consumerOne = session.createConsumer(session.createTopic("Topic_One")); MessageConsumer consumerTwo = session.createConsumer(session.createTopic("Topic_Two")); - ArrayList data = new ArrayList<>(); + ArrayList data = new ArrayList<>(); for (int i = 0; i < 50; i++) { - data.add(new TestEvent("Topic_One", "Message One " + i)); + data.add(new JmsIOTestUtils.TestEvent("Topic_One", "Message One " + i)); } for (int i = 0; i < 100; i++) { - data.add(new TestEvent("Topic_Two", "Message Two " + i)); + data.add(new JmsIOTestUtils.TestEvent("Topic_Two", "Message Two " + i)); } pipeline .apply(Create.of(data)) .apply( - JmsIO.write() + JmsIO.write() .withConnectionFactory(connectionFactory) .withUsername(USERNAME) .withPassword(PASSWORD) .withRetryConfiguration(retryConfiguration) .withTopicNameMapper(e -> e.getTopicName()) + .withJmsPoolConfiguration(jmsPoolConfiguration) .withValueMapper( (e, s) -> { try { @@ -767,11 +780,12 @@ public void testWriteMessageWithRetryPolicy() throws Exception { .apply( JmsIO.write() .withConnectionFactory(connectionFactory) - .withValueMapper(new TextMessageMapperWithErrorCounter()) + .withValueMapper(new JmsIOTestUtils.TextMessageMapperWithErrorCounter()) .withRetryConfiguration(retryPolicy) .withQueue(QUEUE) .withUsername(USERNAME) - .withPassword(PASSWORD)); + .withPassword(PASSWORD) + .withJmsPoolConfiguration(jmsPoolConfiguration)); PAssert.that(output.getFailedMessages()).empty(); pipeline.run(); @@ -813,7 +827,8 @@ public void testWriteMessageWithRetryPolicyReachesLimit() throws Exception { .withRetryConfiguration(retryConfiguration) .withQueue(QUEUE) .withUsername(USERNAME) - .withPassword(PASSWORD)); + .withPassword(PASSWORD) + .withJmsPoolConfiguration(jmsPoolConfiguration)); PAssert.that(output.getFailedMessages()).containsInAnyOrder(messageText); PipelineResult pipelineResult = pipeline.run(); @@ -866,11 +881,12 @@ public void testWriteMessagesWithErrors() throws Exception { .apply( JmsIO.write() .withConnectionFactory(connectionFactory) - .withValueMapper(new TextMessageMapperWithErrorAndCounter()) + .withValueMapper(new JmsIOTestUtils.TextMessageMapperWithErrorAndCounter()) .withRetryConfiguration(retryConfiguration) .withQueue(QUEUE) .withUsername(USERNAME) - .withPassword(PASSWORD)); + .withPassword(PASSWORD) + .withJmsPoolConfiguration(jmsPoolConfiguration)); PAssert.that(output.getFailedMessages()).containsInAnyOrder("Message 2"); pipeline.run(); @@ -906,7 +922,8 @@ public void testWriteMessageToStaticTopicWithoutRetryPolicy() throws Exception { .withValueMapper(new TextMessageMapper()) .withTopic(TOPIC) .withUsername(USERNAME) - .withPassword(PASSWORD)); + .withPassword(PASSWORD) + .withJmsPoolConfiguration(jmsPoolConfiguration)); PAssert.that(output.getFailedMessages()).empty(); pipeline.run(); Message message = consumer.receive(1000); @@ -914,6 +931,36 @@ public void testWriteMessageToStaticTopicWithoutRetryPolicy() throws Exception { assertNull(consumer.receiveNoWait()); } + @Test + public void testWriteMessageWithClosedSession() throws Exception { + List data = Arrays.asList("Message 1", "Message 2", "Message 3", "Message 4"); + + Connection connection = connectionFactory.createConnection(USERNAME, PASSWORD); + connection.start(); + Session session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE); + MessageConsumer consumer = session.createConsumer(session.createTopic(TOPIC)); + + WriteJmsResult output = + pipeline + .apply(Create.of(data)) + .apply( + JmsIO.write() + .withConnectionFactory(connectionFactory) + .withValueMapper(new JmsIOTestUtils.TextMessageMapperWithSessionClosed()) + .withTopic(TOPIC) + .withUsername(USERNAME) + .withPassword(PASSWORD) + .withJmsPoolConfiguration(JmsPoolConfiguration.create(1, 1, null))); + PAssert.that(output.getFailedMessages()).empty(); + pipeline.run(); + + int count = 0; + while (consumer.receive(1000) != null) { + count++; + } + assertEquals(4, count); + } + private int count(String queue) throws Exception { Connection connection = connectionFactory.createConnection(USERNAME, PASSWORD); connection.start(); @@ -978,89 +1025,4 @@ private T proxyMethod( return result; }); } - - private static class TestEvent implements Serializable { - private final String topicName; - private final String value; - - private TestEvent(String topicName, String value) { - this.topicName = topicName; - this.value = value; - } - - private String getTopicName() { - return this.topicName; - } - - private String getValue() { - return this.value; - } - } - - private static class TextMessageMapperWithError - implements SerializableBiFunction { - @Override - public Message apply(String value, Session session) { - try { - if (value.equals("Message 1") || value.equals("Message 2")) { - throw new JMSException("Error!!"); - } - TextMessage msg = session.createTextMessage(); - msg.setText(value); - return msg; - } catch (JMSException e) { - throw new JmsIOException("Error creating TextMessage", e); - } - } - } - - private static class TextMessageMapperWithErrorCounter - implements SerializableBiFunction { - - private static int errorCounter; - - TextMessageMapperWithErrorCounter() { - errorCounter = 0; - } - - @Override - public Message apply(String value, Session session) { - try { - if (errorCounter == 0) { - errorCounter++; - throw new JMSException("Error!!"); - } - TextMessage msg = session.createTextMessage(); - msg.setText(value); - return msg; - } catch (JMSException e) { - throw new JmsIOException("Error creating TextMessage", e); - } - } - } - - private static class TextMessageMapperWithErrorAndCounter - implements SerializableBiFunction { - private static int errorCounter = 0; - - @Override - public Message apply(String value, Session session) { - try { - if (value.equals("Message 1") || value.equals("Message 2")) { - if (errorCounter != 0 && value.equals("Message 1")) { - TextMessage msg = session.createTextMessage(); - msg.setText(value); - return msg; - } - errorCounter++; - throw new JMSException("Error!!"); - } - TextMessage msg = session.createTextMessage(); - msg.setText(value); - return msg; - } catch (JMSException e) { - throw new JmsIOException("Error creating TextMessage", e); - } - } - } } diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/utils/JmsIOTestUtils.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/utils/JmsIOTestUtils.java new file mode 100644 index 000000000000..6b2bb48c45e5 --- /dev/null +++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/utils/JmsIOTestUtils.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jms.utils; + +import java.io.IOException; +import java.io.Serializable; +import java.net.ServerSocket; +import java.util.Random; +import javax.jms.*; +import org.apache.beam.sdk.io.jms.JmsIOException; +import org.apache.beam.sdk.transforms.SerializableBiFunction; + +public class JmsIOTestUtils { + + private JmsIOTestUtils() {} + + public static class TestEvent implements Serializable { + private final String topicName; + private final String value; + + public TestEvent(String topicName, String value) { + this.topicName = topicName; + this.value = value; + } + + public String getTopicName() { + return this.topicName; + } + + public String getValue() { + return this.value; + } + } + + public static class TextMessageMapperWithError + implements SerializableBiFunction { + @Override + public Message apply(String value, Session session) { + try { + if (value.equals("Message 1") || value.equals("Message 2")) { + throw new JMSException("Error!!"); + } + TextMessage msg = session.createTextMessage(); + msg.setText(value); + return msg; + } catch (JMSException e) { + throw new JmsIOException("Error creating TextMessage", e); + } + } + } + + public static class TextMessageMapperWithErrorCounter + implements SerializableBiFunction { + + private static int errorCounter = 0; + + @Override + public Message apply(String value, Session session) { + try { + if (errorCounter == 0) { + errorCounter++; + throw new JMSException("Error!!"); + } + TextMessage msg = session.createTextMessage(); + msg.setText(value); + return msg; + } catch (JMSException e) { + throw new JmsIOException("Error creating TextMessage", e); + } + } + } + + public static class TextMessageMapperWithErrorAndCounter + implements SerializableBiFunction { + private static int errorCounter = 0; + + @Override + public Message apply(String value, Session session) { + try { + if (value.equals("Message 1") || value.equals("Message 2")) { + if (errorCounter != 0 && value.equals("Message 1")) { + TextMessage msg = session.createTextMessage(); + msg.setText(value); + return msg; + } + errorCounter++; + throw new JMSException("Error!!"); + } + TextMessage msg = session.createTextMessage(); + msg.setText(value); + return msg; + } catch (JMSException e) { + throw new JmsIOException("Error creating TextMessage", e); + } + } + } + + public static class TextMessageMapperWithSessionClosed + implements SerializableBiFunction { + private static int errorCounter = 0; + + @Override + public Message apply(String value, Session session) { + try { + if (value.equals("Message 1") || value.equals("Message 2")) { + if (errorCounter == 0) { + errorCounter++; + session.close(); + } + } + TextMessage msg = session.createTextMessage(); + msg.setText(value); + return msg; + } catch (JMSException e) { + throw new JmsIOException("Error creating TextMessage", e); + } + } + } + + private static int getRandomNumber() { + int max = 65535, min = 1024; + return new Random().nextInt(max - min + 1) + min; + } + + private static boolean isPortUnused(int portNumber) { + try (ServerSocket serverSocket = new ServerSocket(portNumber)) { + serverSocket.close(); + return true; + } catch (IOException exception) { + return false; + } + } + + public static int getUnusedPort() { + int portNumber = getRandomNumber(); + while (true) { + if (isPortUnused(portNumber)) { + return portNumber; + } else { + portNumber = getRandomNumber(); + } + } + } +}