diff --git a/.gitignore b/.gitignore index bc6c02c1..03b19a31 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,7 @@ mise.toml # Ignore Gradle build output directory workers/java/bin -workers/java/build +workers/java/**/build /loadgen/kitchen-sink-gen/target/ diff --git a/cmd/dev/test.go b/cmd/dev/test.go index cf0f67b6..f4f0b2f6 100644 --- a/cmd/dev/test.go +++ b/cmd/dev/test.go @@ -81,6 +81,10 @@ func runTestWorker(ctx context.Context, language string) error { if err := runPythonHarnessTests(ctx, repoDir); err != nil { return err } + } else if language == "java" { + if err := runJavaHarnessTests(ctx, repoDir); err != nil { + return err + } } if language == "dotnet" { if err := runDotnetHarnessTests(ctx, repoDir); err != nil { @@ -158,6 +162,21 @@ func runRubyHarnessTests(ctx context.Context, repoDir string) error { return nil } +func runJavaHarnessTests(ctx context.Context, repoDir string) error { + harnessDir := filepath.Join(repoDir, "workers", "java") + fmt.Println("Running Java harness tests...") + if err := runCommandInDir( + ctx, + harnessDir, + "./gradlew", + ":harness:test", + ); err != nil { + return fmt.Errorf("failed Java harness tests: %w", err) + } + fmt.Println("✅ Java harness tests completed successfully!") + return nil +} + func testWorkerLocally(ctx context.Context, repoDir, language, sdkVersion string) error { args := []string{ "go", "run", "./cmd", "run-scenario-with-worker", diff --git a/dockerfiles/java.Dockerfile b/dockerfiles/java.Dockerfile index bf5cad54..817b38c9 100644 --- a/dockerfiles/java.Dockerfile +++ b/dockerfiles/java.Dockerfile @@ -34,6 +34,7 @@ ARG SDK_DIR=.gitignore COPY ${SDK_DIR} ./repo # Copy the worker files +COPY workers/proto ./workers/proto COPY workers/java ./workers/java # Download Gradle using wrapper to cache it in build layer @@ -51,6 +52,7 @@ RUN apt-get update && apt-get install --no-install-recommends --assume-yes git & ENV GRADLE_USER_HOME="/gradle" COPY --from=build /app/temporal-omes /app/temporal-omes +COPY --from=build /app/workers/proto/harness /app/workers/proto/harness COPY --from=build /app/workers/java /app/workers/java COPY --from=build /app/repo /app/repo COPY --from=build /gradle /gradle diff --git a/workers/java/build.gradle b/workers/java/build.gradle index 24ed5158..2306987c 100644 --- a/workers/java/build.gradle +++ b/workers/java/build.gradle @@ -3,22 +3,27 @@ plugins { id 'com.diffplug.spotless' version '6.18.0' } -group 'io.temporal' -version '0.1.0' +allprojects { + group 'io.temporal' + version '0.1.0' -java { - sourceCompatibility = JavaVersion.VERSION_1_10 - targetCompatibility = JavaVersion.VERSION_1_10 -} + repositories { + mavenCentral() + } -repositories { - mavenCentral() + plugins.withType(JavaPlugin) { + java { + sourceCompatibility = JavaVersion.VERSION_1_10 + targetCompatibility = JavaVersion.VERSION_1_10 + } + } } spotless { java { target project.fileTree(project.rootDir) { include '**/*.java' + exclude '**/build/**' exclude 'io/temporal/omes/KitchenSink.java' } googleJavaFormat('1.22.0') @@ -28,28 +33,24 @@ spotless { compileJava.dependsOn spotlessJava dependencies { - implementation 'ch.qos.logback:logback-classic:1.2.13' + implementation project(':harness') implementation 'com.google.guava:guava:31.0.1-jre' implementation 'com.google.code.gson:gson:2.8.9' implementation 'com.jayway.jsonpath:json-path:2.6.0' - implementation 'info.picocli:picocli:4.6.2' implementation 'io.temporal:temporal-sdk:1.34.0' - implementation 'org.junit.jupiter:junit-jupiter-api:5.8.1' implementation 'org.reflections:reflections:0.10.2' - implementation 'net.logstash.logback:logstash-logback-encoder:7.4' - implementation "io.micrometer:micrometer-registry-prometheus" implementation(platform("com.fasterxml.jackson:jackson-bom:2.15.2")) implementation "com.fasterxml.jackson.core:jackson-databind" implementation "com.fasterxml.jackson.core:jackson-core" - implementation 'com.google.protobuf:protobuf-java:3.25.0' + implementation 'com.google.protobuf:protobuf-java:3.25.5' + compileOnly 'javax.annotation:javax.annotation-api:1.3.2' } sourceSets { main { java { srcDirs = ['io/temporal/omes'] - exclude '**/build/**' } } } diff --git a/workers/java/harness/build.gradle b/workers/java/harness/build.gradle new file mode 100644 index 00000000..ff60a5f6 --- /dev/null +++ b/workers/java/harness/build.gradle @@ -0,0 +1,54 @@ +plugins { + id 'java' + id 'com.google.protobuf' version '0.9.4' +} + +compileJava.dependsOn generateProto + +dependencies { + compileOnly 'io.temporal:temporal-sdk:1.34.0' + + implementation 'ch.qos.logback:logback-classic:1.2.13' + implementation 'com.google.protobuf:protobuf-java:3.25.5' + implementation 'info.picocli:picocli:4.6.2' + implementation 'net.logstash.logback:logstash-logback-encoder:7.4' + implementation 'io.micrometer:micrometer-registry-prometheus' + + testImplementation 'io.temporal:temporal-testing:1.34.0' + testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1' +} + +sourceSets { + main { + proto { + srcDirs = ['../../proto/harness/api'] + include 'api.proto' + } + } +} + +protobuf { + protoc { + artifact = 'com.google.protobuf:protoc:3.25.5' + } + plugins { + grpc { + artifact = 'io.grpc:protoc-gen-grpc-java:1.75.0' + } + } + generateProtoTasks { + all().configureEach { task -> + task.builtins { + java {} + } + task.plugins { + grpc {} + } + } + } +} + +test { + useJUnitPlatform() +} diff --git a/workers/java/harness/src/main/java/io/temporal/omes/harness/Harness.java b/workers/java/harness/src/main/java/io/temporal/omes/harness/Harness.java new file mode 100644 index 00000000..f74ad328 --- /dev/null +++ b/workers/java/harness/src/main/java/io/temporal/omes/harness/Harness.java @@ -0,0 +1,53 @@ +package io.temporal.omes.harness; + +import java.util.Arrays; +import java.util.Objects; + +public final class Harness { + private Harness() {} + + public static void run(App app, String... argv) throws Exception { + if (argv.length == 0) { + throw new IllegalArgumentException( + "No command specified. Expected 'worker' or 'project-server'"); + } + + if ("worker".equals(argv[0])) { + WorkerHarness.runWorkerCli( + app.worker, app.clientFactory, Arrays.copyOfRange(argv, 1, argv.length)); + return; + } + + if ("project-server".equals(argv[0])) { + if (app.project == null) { + throw new IllegalStateException( + "Wanted project-server but no project handlers registered for this app"); + } + ProjectHarness.runProjectServerCli( + app.project, app.clientFactory, Arrays.copyOfRange(argv, 1, argv.length)); + return; + } + + throw new IllegalArgumentException( + String.format("Unknown command: [%s]. Expected 'worker' or 'project-server'", argv[0])); + } + + public static final class App { + public final WorkerHarness.WorkerRegistrar worker; + public final HarnessClients.ClientFactory clientFactory; + public final ProjectHarness.ProjectHandlers project; + + public App(WorkerHarness.WorkerRegistrar worker, HarnessClients.ClientFactory clientFactory) { + this(worker, clientFactory, null); + } + + public App( + WorkerHarness.WorkerRegistrar worker, + HarnessClients.ClientFactory clientFactory, + ProjectHarness.ProjectHandlers project) { + this.worker = Objects.requireNonNull(worker); + this.clientFactory = Objects.requireNonNull(clientFactory); + this.project = project; + } + } +} diff --git a/workers/java/harness/src/main/java/io/temporal/omes/harness/HarnessClients.java b/workers/java/harness/src/main/java/io/temporal/omes/harness/HarnessClients.java new file mode 100644 index 00000000..f1ce37de --- /dev/null +++ b/workers/java/harness/src/main/java/io/temporal/omes/harness/HarnessClients.java @@ -0,0 +1,200 @@ +package io.temporal.omes.harness; + +import com.sun.net.httpserver.HttpServer; +import com.uber.m3.tally.RootScopeBuilder; +import com.uber.m3.tally.Scope; +import com.uber.m3.tally.StatsReporter; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; +import io.micrometer.core.instrument.Meter; +import io.micrometer.core.instrument.config.NamingConvention; +import io.micrometer.prometheus.PrometheusConfig; +import io.micrometer.prometheus.PrometheusMeterRegistry; +import io.micrometer.prometheus.PrometheusNamingConvention; +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowClientOptions; +import io.temporal.common.converter.DataConverter; +import io.temporal.common.converter.GlobalDataConverter; +import io.temporal.common.reporter.MicrometerClientStatsReporter; +import io.temporal.serviceclient.SimpleSslContextBuilder; +import io.temporal.serviceclient.WorkflowServiceStubs; +import io.temporal.serviceclient.WorkflowServiceStubsOptions; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import javax.net.ssl.SSLException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class HarnessClients { + private static final Logger logger = LoggerFactory.getLogger(HarnessClients.class); + + private HarnessClients() {} + + static ClientConfig buildClientConfig( + String serverAddress, + String namespace, + String authHeader, + boolean tls, + String tlsCertPath, + String tlsKeyPath, + String tlsServerName, + boolean disableHostVerification, + String promListenAddress, + String promHandlerPath) { + if (disableHostVerification) { + logger.warn("disable_host_verification is not supported by the Java SDK harness; ignoring"); + } + + return new ClientConfig( + serverAddress, + namespace, + buildApiKey(authHeader), + buildTlsContext(tls, tlsCertPath, tlsKeyPath), + normalizeEmpty(promListenAddress), + normalizeEmpty(promHandlerPath), + normalizeEmpty(tlsServerName)); + } + + public static WorkflowClient defaultClientFactory(ClientConfig config) throws Exception { + return newWorkflowClient(config, GlobalDataConverter.get()); + } + + public static WorkflowClient newWorkflowClient(ClientConfig config, DataConverter dataConverter) + throws Exception { + WorkflowServiceStubsOptions.Builder serviceOptionsBuilder = + WorkflowServiceStubsOptions.newBuilder().setTarget(config.targetHost); + + if (config.tls != null) { + serviceOptionsBuilder.setSslContext(config.tls); + if (config.tlsServerName != null) { + serviceOptionsBuilder.setChannelInitializer( + channelBuilder -> channelBuilder.overrideAuthority(config.tlsServerName)); + } + } + + if (config.apiKey != null) { + serviceOptionsBuilder.addApiKey(() -> config.apiKey); + } + + Scope metricsScope = maybeCreateMetricsScope(config.promListenAddress, config.promHandlerPath); + if (metricsScope != null) { + serviceOptionsBuilder.setMetricsScope(metricsScope); + } + + WorkflowServiceStubs service = + WorkflowServiceStubs.newServiceStubs(serviceOptionsBuilder.build()); + + return WorkflowClient.newInstance( + service, + WorkflowClientOptions.newBuilder() + .setDataConverter(dataConverter) + .setNamespace(config.namespace) + .build()); + } + + public interface ClientFactory { + WorkflowClient create(ClientConfig config) throws Exception; + } + + public static final class ClientConfig { + public final String targetHost; + public final String namespace; + public final String apiKey; + public final SslContext tls; + final String promListenAddress; + final String promHandlerPath; + final String tlsServerName; + + public ClientConfig( + String targetHost, + String namespace, + String apiKey, + SslContext tls, + String promListenAddress, + String promHandlerPath, + String tlsServerName) { + this.targetHost = targetHost; + this.namespace = namespace; + this.apiKey = apiKey; + this.tls = tls; + this.promListenAddress = promListenAddress; + this.promHandlerPath = promHandlerPath; + this.tlsServerName = tlsServerName; + } + } + + private static String buildApiKey(String authHeader) { + if (authHeader == null || authHeader.isEmpty()) { + return null; + } + if (authHeader.startsWith("Bearer ")) { + return authHeader.substring("Bearer ".length()); + } + return authHeader; + } + + private static SslContext buildTlsContext(boolean tls, String tlsCertPath, String tlsKeyPath) { + String certPath = normalizeEmpty(tlsCertPath); + String keyPath = normalizeEmpty(tlsKeyPath); + + if (certPath != null) { + if (keyPath == null) { + throw new IllegalArgumentException("Client cert specified, but not client key!"); + } + try (InputStream clientCert = new FileInputStream(certPath); + InputStream clientKey = new FileInputStream(keyPath)) { + return SimpleSslContextBuilder.forPKCS8(clientCert, clientKey).build(); + } catch (IOException e) { + throw new IllegalArgumentException("Unable to load TLS credentials", e); + } + } + if (keyPath != null) { + throw new IllegalArgumentException("Client key specified, but not client cert!"); + } + if (tls) { + try { + return SimpleSslContextBuilder.noKeyOrCertChain().build(); + } catch (SSLException e) { + throw new IllegalArgumentException("Unable to build TLS context", e); + } + } + return null; + } + + private static Scope maybeCreateMetricsScope(String promListenAddress, String promHandlerPath) { + if (promListenAddress == null || promListenAddress.isEmpty()) { + return null; + } + + PrometheusMeterRegistry registry = new PrometheusMeterRegistry(PrometheusConfig.DEFAULT); + registry + .config() + .namingConvention( + new PrometheusNamingConvention() { + @Override + public String name(String name, Meter.Type type, String baseUnit) { + return NamingConvention.snakeCase.name(name, type, null); + } + }); + + StatsReporter reporter = new MicrometerClientStatsReporter(registry); + Scope scope = + new RootScopeBuilder() + .reporter(reporter) + .reportEvery(com.uber.m3.util.Duration.ofSeconds(1)); + HttpServer scrapeEndpoint = + HarnessMetricsUtils.startPrometheusScrapeEndpoint( + registry, promHandlerPath, promListenAddress); + Runtime.getRuntime() + .addShutdownHook( + new Thread(() -> scrapeEndpoint.stop(1), "omes-java-harness-metrics-shutdown")); + return scope; + } + + private static String normalizeEmpty(String value) { + if (value == null || value.isEmpty()) { + return null; + } + return value; + } +} diff --git a/workers/java/harness/src/main/java/io/temporal/omes/harness/HarnessHelpers.java b/workers/java/harness/src/main/java/io/temporal/omes/harness/HarnessHelpers.java new file mode 100644 index 00000000..235cc188 --- /dev/null +++ b/workers/java/harness/src/main/java/io/temporal/omes/harness/HarnessHelpers.java @@ -0,0 +1,80 @@ +package io.temporal.omes.harness; + +import ch.qos.logback.classic.Level; +import ch.qos.logback.classic.LoggerContext; +import ch.qos.logback.classic.encoder.PatternLayoutEncoder; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.ConsoleAppender; +import java.util.Locale; +import net.logstash.logback.encoder.LogstashEncoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +final class HarnessHelpers { + private HarnessHelpers() {} + + static Logger configure(String logLevel, String logEncoding) { + Logger logger = LoggerFactory.getLogger(org.slf4j.Logger.ROOT_LOGGER_NAME); + if (!(logger instanceof ch.qos.logback.classic.Logger)) { + return logger; + } + + ch.qos.logback.classic.Logger rootLogger = (ch.qos.logback.classic.Logger) logger; + rootLogger.setLevel(resolveLogLevel(logLevel)); + + if ("json".equals(logEncoding)) { + rootLogger.detachAndStopAllAppenders(); + rootLogger.addAppender(jsonAppender(rootLogger.getLoggerContext())); + } else if (!rootLogger.iteratorForAppenders().hasNext()) { + rootLogger.addAppender(consoleAppender(rootLogger.getLoggerContext())); + } + + return rootLogger; + } + + private static Level resolveLogLevel(String logLevel) { + switch (logLevel.toUpperCase(Locale.ROOT)) { + case "PANIC": + case "FATAL": + case "ERROR": + return Level.ERROR; + case "WARN": + return Level.WARN; + case "INFO": + return Level.INFO; + case "DEBUG": + return Level.DEBUG; + case "NOTSET": + return Level.ALL; + default: + throw new IllegalArgumentException( + "Invalid log level: " + + logLevel + + ". Expected one of: debug, info, warn, error, panic, fatal, notset"); + } + } + + private static ConsoleAppender jsonAppender(LoggerContext context) { + LogstashEncoder encoder = new LogstashEncoder(); + encoder.setContext(context); + encoder.start(); + return appender(context, encoder); + } + + private static ConsoleAppender consoleAppender(LoggerContext context) { + PatternLayoutEncoder encoder = new PatternLayoutEncoder(); + encoder.setContext(context); + encoder.setPattern("%d{HH:mm:ss.SSS} %-5level %logger{36} - %msg%n"); + encoder.start(); + return appender(context, encoder); + } + + private static ConsoleAppender appender( + LoggerContext context, ch.qos.logback.core.encoder.Encoder encoder) { + ConsoleAppender appender = new ConsoleAppender<>(); + appender.setContext(context); + appender.setEncoder(encoder); + appender.start(); + return appender; + } +} diff --git a/workers/java/io/temporal/omes/MetricsUtils.java b/workers/java/harness/src/main/java/io/temporal/omes/harness/HarnessMetricsUtils.java similarity index 70% rename from workers/java/io/temporal/omes/MetricsUtils.java rename to workers/java/harness/src/main/java/io/temporal/omes/harness/HarnessMetricsUtils.java index 62152a0c..86d4e3f3 100644 --- a/workers/java/io/temporal/omes/MetricsUtils.java +++ b/workers/java/harness/src/main/java/io/temporal/omes/harness/HarnessMetricsUtils.java @@ -1,4 +1,4 @@ -package io.temporal.omes; +package io.temporal.omes.harness; import static java.nio.charset.StandardCharsets.UTF_8; @@ -8,13 +8,10 @@ import java.io.OutputStream; import java.net.InetSocketAddress; -public class MetricsUtils { +final class HarnessMetricsUtils { + private HarnessMetricsUtils() {} - /** - * Starts HttpServer to expose a scrape endpoint. See - * https://micrometer.io/docs/registry/prometheus for more info. - */ - public static HttpServer startPrometheusScrapeEndpoint( + static HttpServer startPrometheusScrapeEndpoint( PrometheusMeterRegistry registry, String path, String address) { try { String[] parts = address.split(":"); @@ -22,10 +19,7 @@ public static HttpServer startPrometheusScrapeEndpoint( throw new IllegalArgumentException("Invalid address: " + address); } String host = parts[0]; - int port = 0; - if (parts.length == 2) { - port = Integer.parseInt(parts[1]); - } + int port = parts.length == 2 ? Integer.parseInt(parts[1]) : 0; HttpServer server = HttpServer.create(new InetSocketAddress(host, port), 0); server.createContext( @@ -36,8 +30,8 @@ public static HttpServer startPrometheusScrapeEndpoint( .getResponseHeaders() .set("Content-Type", "text/plain; version=0.0.4; charset=utf-8"); httpExchange.sendResponseHeaders(200, response.getBytes(UTF_8).length); - try (OutputStream os = httpExchange.getResponseBody()) { - os.write(response.getBytes(UTF_8)); + try (OutputStream output = httpExchange.getResponseBody()) { + output.write(response.getBytes(UTF_8)); } }); diff --git a/workers/java/harness/src/main/java/io/temporal/omes/harness/ProjectHarness.java b/workers/java/harness/src/main/java/io/temporal/omes/harness/ProjectHarness.java new file mode 100644 index 00000000..95205a5c --- /dev/null +++ b/workers/java/harness/src/main/java/io/temporal/omes/harness/ProjectHarness.java @@ -0,0 +1,235 @@ +package io.temporal.omes.harness; + +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.Status; +import io.grpc.stub.StreamObserver; +import io.temporal.client.WorkflowClient; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import picocli.CommandLine; +import temporal.omes.projects.v1.Api; +import temporal.omes.projects.v1.ProjectServiceGrpc; + +public final class ProjectHarness { + private static final Logger logger = LoggerFactory.getLogger(ProjectHarness.class); + + private ProjectHarness() {} + + static void runProjectServerCli( + ProjectHandlers handlers, HarnessClients.ClientFactory clientFactory, String... argv) + throws Exception { + Arguments args = new Arguments(); + new CommandLine(args).parseArgs(argv); + Server server = + ServerBuilder.forPort(args.port) + .addService(new ProjectServiceServer(handlers, clientFactory)) + .build(); + server.start(); + logger.info("Project server listening on port {}", server.getPort()); + server.awaitTermination(); + } + + public static final class ProjectRunMetadata { + public final String runId; + public final String executionId; + + public ProjectRunMetadata(String runId, String executionId) { + this.runId = runId; + this.executionId = executionId; + } + } + + public static final class ProjectInitContext { + public final Logger logger; + public final ProjectRunMetadata run; + public final String taskQueue; + public final byte[] configJson; + + public ProjectInitContext( + Logger logger, ProjectRunMetadata run, String taskQueue, byte[] configJson) { + this.logger = logger; + this.run = run; + this.taskQueue = taskQueue; + this.configJson = configJson; + } + } + + public static final class ProjectExecuteContext { + public final Logger logger; + public final ProjectRunMetadata run; + public final String taskQueue; + public final long iteration; + public final byte[] payload; + + public ProjectExecuteContext( + Logger logger, ProjectRunMetadata run, String taskQueue, long iteration, byte[] payload) { + this.logger = logger; + this.run = run; + this.taskQueue = taskQueue; + this.iteration = iteration; + this.payload = payload; + } + } + + @FunctionalInterface + public interface ProjectInitHandler { + void init(WorkflowClient client, ProjectInitContext context) throws Exception; + } + + @FunctionalInterface + public interface ProjectExecuteHandler { + void execute(WorkflowClient client, ProjectExecuteContext context) throws Exception; + } + + public static final class ProjectHandlers { + public final ProjectExecuteHandler execute; + public final ProjectInitHandler init; + + public ProjectHandlers(ProjectExecuteHandler execute) { + this(execute, null); + } + + public ProjectHandlers(ProjectExecuteHandler execute, ProjectInitHandler init) { + this.execute = Objects.requireNonNull(execute); + this.init = init; + } + } + + static final class ProjectServiceServer extends ProjectServiceGrpc.ProjectServiceImplBase { + private static final Logger logger = LoggerFactory.getLogger(ProjectServiceServer.class); + + private final ProjectHandlers handlers; + private final HarnessClients.ClientFactory clientFactory; + private volatile WorkflowClient client; + private volatile ProjectRunMetadata run; + + ProjectServiceServer(ProjectHandlers handlers, HarnessClients.ClientFactory clientFactory) { + this.handlers = Objects.requireNonNull(handlers); + this.clientFactory = Objects.requireNonNull(clientFactory); + } + + @Override + public void init(Api.InitRequest request, StreamObserver responseObserver) { + if (request.getTaskQueue().isEmpty()) { + abort(responseObserver, Status.INVALID_ARGUMENT, "task_queue required"); + return; + } + if (request.getExecutionId().isEmpty()) { + abort(responseObserver, Status.INVALID_ARGUMENT, "execution_id required"); + return; + } + if (request.getRunId().isEmpty()) { + abort(responseObserver, Status.INVALID_ARGUMENT, "run_id required"); + return; + } + + Api.ConnectOptions conn = request.getConnectOptions(); + if (conn.getServerAddress().isEmpty()) { + abort(responseObserver, Status.INVALID_ARGUMENT, "server_address required"); + return; + } + if (conn.getNamespace().isEmpty()) { + abort(responseObserver, Status.INVALID_ARGUMENT, "namespace required"); + return; + } + + HarnessClients.ClientConfig config; + try { + config = + HarnessClients.buildClientConfig( + conn.getServerAddress(), + conn.getNamespace(), + conn.getAuthHeader(), + conn.getEnableTls(), + conn.getTlsCertPath(), + conn.getTlsKeyPath(), + conn.getTlsServerName(), + conn.getDisableHostVerification(), + null, + null); + } catch (IllegalArgumentException e) { + abort(responseObserver, Status.INVALID_ARGUMENT, messageOf(e)); + return; + } + + WorkflowClient createdClient; + try { + createdClient = clientFactory.create(config); + } catch (Exception e) { + abort(responseObserver, Status.INTERNAL, "failed to create client: " + messageOf(e)); + return; + } + + ProjectRunMetadata createdRun = + new ProjectRunMetadata(request.getRunId(), request.getExecutionId()); + + if (handlers.init != null) { + try { + handlers.init.init( + createdClient, + new ProjectInitContext( + logger, + createdRun, + request.getTaskQueue(), + request.getConfigJson().toByteArray())); + } catch (Exception e) { + abort(responseObserver, Status.INTERNAL, "init handler failed: " + messageOf(e)); + return; + } + } + + client = createdClient; + run = createdRun; + responseObserver.onNext(Api.InitResponse.getDefaultInstance()); + responseObserver.onCompleted(); + } + + @Override + public void execute( + Api.ExecuteRequest request, StreamObserver responseObserver) { + if (request.getTaskQueue().isEmpty()) { + abort(responseObserver, Status.INVALID_ARGUMENT, "task_queue required"); + return; + } + + WorkflowClient cachedClient = client; + ProjectRunMetadata cachedRun = run; + if (cachedClient == null || cachedRun == null) { + abort(responseObserver, Status.FAILED_PRECONDITION, "Init must be called before Execute"); + return; + } + + try { + handlers.execute.execute( + cachedClient, + new ProjectExecuteContext( + logger, + cachedRun, + request.getTaskQueue(), + request.getIteration(), + request.getPayload().toByteArray())); + } catch (Exception e) { + abort(responseObserver, Status.INTERNAL, "execute handler failed: " + messageOf(e)); + return; + } + + responseObserver.onNext(Api.ExecuteResponse.getDefaultInstance()); + responseObserver.onCompleted(); + } + } + + private static void abort(StreamObserver responseObserver, Status status, String description) { + responseObserver.onError(status.withDescription(description).asRuntimeException()); + } + + private static String messageOf(Exception error) { + return error.getMessage() == null ? error.toString() : error.getMessage(); + } + + private static final class Arguments { + @CommandLine.Option(names = "--port", description = "gRPC listen port", defaultValue = "8080") + private int port; + } +} diff --git a/workers/java/harness/src/main/java/io/temporal/omes/harness/WorkerHarness.java b/workers/java/harness/src/main/java/io/temporal/omes/harness/WorkerHarness.java new file mode 100644 index 00000000..67a948f1 --- /dev/null +++ b/workers/java/harness/src/main/java/io/temporal/omes/harness/WorkerHarness.java @@ -0,0 +1,339 @@ +package io.temporal.omes.harness; + +import io.temporal.client.WorkflowClient; +import io.temporal.serviceclient.ServiceStubs; +import io.temporal.worker.Worker; +import io.temporal.worker.WorkerFactory; +import io.temporal.worker.WorkerFactoryOptions; +import io.temporal.worker.WorkerOptions; +import io.temporal.worker.tuning.PollerBehaviorAutoscaling; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.slf4j.Logger; +import picocli.CommandLine; + +public final class WorkerHarness { + + @FunctionalInterface + public interface WorkerRegistrar { + void register(WorkflowClient client, Worker worker, WorkerContext context) throws Exception; + } + + public static final class WorkerContext { + public final Logger logger; + public final String taskQueue; + public final boolean errOnUnimplemented; + + public WorkerContext(Logger logger, String taskQueue, boolean errOnUnimplemented) { + this.logger = logger; + this.taskQueue = taskQueue; + this.errOnUnimplemented = errOnUnimplemented; + } + } + + private WorkerHarness() {} + + static void runWorkerCli( + WorkerRegistrar workerRegistrar, HarnessClients.ClientFactory clientFactory, String... argv) + throws Exception { + Arguments args = parseArguments(argv); + Logger logger = HarnessHelpers.configure(args.logLevel, args.logEncoding); + WorkflowClient client = + clientFactory.create( + HarnessClients.buildClientConfig( + args.serverAddress, + args.namespace, + args.authHeader, + args.tls, + args.tlsCertPath, + args.tlsKeyPath, + null, + false, + args.promListenAddress, + args.promHandlerPath)); + WorkerFactory workerFactory = + WorkerFactory.newInstance( + client, WorkerFactoryOptions.newBuilder().setMaxWorkflowThreadCount(1000).build()); + AtomicBoolean shutdown = new AtomicBoolean(false); + CountDownLatch stopSignal = new CountDownLatch(1); + Thread shutdownHook = + new Thread( + () -> { + stopSignal.countDown(); + shutdownWorkersAndClient(workerFactory, client, shutdown); + }, + "omes-java-harness-worker-shutdown"); + boolean shutdownHookAdded = false; + + try { + registerWorkers( + client, + workerFactory, + workerRegistrar, + logger, + buildTaskQueues( + logger, args.taskQueue, args.taskQueueSuffixIndexStart, args.taskQueueSuffixIndexEnd), + args.errOnUnimplemented, + buildWorkerOptions(args)); + + Runtime.getRuntime().addShutdownHook(shutdownHook); + shutdownHookAdded = true; + runWorkerFactory(workerFactory, client, stopSignal, shutdown); + } finally { + if (shutdownHookAdded) { + try { + Runtime.getRuntime().removeShutdownHook(shutdownHook); + } catch (IllegalStateException ignored) { + // JVM shutdown is already in progress. + } + } + } + } + + static Arguments parseArguments(String... argv) { + Arguments args = new Arguments(); + new CommandLine(args).parseArgs(argv); + if (args.taskQueueSuffixIndexStart > args.taskQueueSuffixIndexEnd) { + throw new IllegalArgumentException("Task queue suffix start after end"); + } + return args; + } + + static void registerWorkers( + WorkflowClient client, + WorkerFactory workerFactory, + WorkerRegistrar workerRegistrar, + Logger logger, + List taskQueues, + boolean errOnUnimplemented, + WorkerOptions workerOptions) + throws Exception { + for (String taskQueue : taskQueues) { + Worker worker = workerFactory.newWorker(taskQueue, workerOptions); + workerRegistrar.register( + client, worker, new WorkerContext(logger, taskQueue, errOnUnimplemented)); + } + } + + static List buildTaskQueues( + Logger logger, String taskQueue, int suffixStart, int suffixEnd) { + if (suffixEnd == 0) { + logger.info("Java worker will run on task queue {}", taskQueue); + return List.of(taskQueue); + } + + List taskQueues = new ArrayList<>(suffixEnd - suffixStart + 1); + for (int index = suffixStart; index <= suffixEnd; index++) { + taskQueues.add(String.format("%s-%d", taskQueue, index)); + } + logger.info("Java worker will run on {} task queue(s)", taskQueues.size()); + return taskQueues; + } + + static WorkerOptions buildWorkerOptions(Arguments args) { + WorkerOptions.Builder workerOptions = WorkerOptions.newBuilder(); + if (args.workflowPollerAutoscaleMax != null) { + workerOptions.setWorkflowTaskPollersBehavior( + new PollerBehaviorAutoscaling(null, args.workflowPollerAutoscaleMax, null)); + } else if (args.maxConcurrentWorkflowPollers != null) { + workerOptions.setMaxConcurrentWorkflowTaskPollers(args.maxConcurrentWorkflowPollers); + } + + if (args.activityPollerAutoscaleMax != null) { + workerOptions.setActivityTaskPollersBehavior( + new PollerBehaviorAutoscaling(null, args.activityPollerAutoscaleMax, null)); + } else if (args.maxConcurrentActivityPollers != null) { + workerOptions.setMaxConcurrentActivityTaskPollers(args.maxConcurrentActivityPollers); + } + + if (args.maxConcurrentActivities != null) { + workerOptions.setMaxConcurrentActivityExecutionSize(args.maxConcurrentActivities); + } + if (args.maxConcurrentWorkflowTasks != null) { + workerOptions.setMaxConcurrentWorkflowTaskExecutionSize(args.maxConcurrentWorkflowTasks); + } + if (args.activitiesPerSecond != null) { + workerOptions.setMaxWorkerActivitiesPerSecond(args.activitiesPerSecond); + } + return workerOptions.build(); + } + + static void runWorkerFactory( + WorkerFactory workerFactory, WorkflowClient client, CountDownLatch stopSignal) + throws InterruptedException { + runWorkerFactory(workerFactory, client, stopSignal, new AtomicBoolean(false)); + } + + private static void runWorkerFactory( + WorkerFactory workerFactory, + WorkflowClient client, + CountDownLatch stopSignal, + AtomicBoolean shutdown) + throws InterruptedException { + try { + workerFactory.start(); + stopSignal.await(); + } finally { + shutdownWorkersAndClient(workerFactory, client, shutdown); + } + } + + private static void shutdownWorkersAndClient( + WorkerFactory workerFactory, WorkflowClient client, AtomicBoolean shutdown) { + if (!shutdown.compareAndSet(false, true)) { + return; + } + + workerFactory.shutdownNow(); + workerFactory.awaitTermination(5, TimeUnit.SECONDS); + ServiceStubs serviceStubs = client.getWorkflowServiceStubs(); + serviceStubs.shutdownNow(); + serviceStubs.awaitTermination(5, TimeUnit.SECONDS); + } + + static final class Arguments { + @CommandLine.Option( + names = {"-q", "--task-queue"}, + description = "Task queue to use", + defaultValue = "omes") + String taskQueue; + + @CommandLine.Option( + names = "--task-queue-suffix-index-start", + description = "Inclusive start for task queue suffix range", + defaultValue = "0") + int taskQueueSuffixIndexStart; + + @CommandLine.Option( + names = "--task-queue-suffix-index-end", + description = "Inclusive end for task queue suffix range", + defaultValue = "0") + int taskQueueSuffixIndexEnd; + + @CommandLine.Option( + names = "--max-concurrent-activity-pollers", + description = "Max concurrent activity pollers") + Integer maxConcurrentActivityPollers; + + @CommandLine.Option( + names = "--max-concurrent-workflow-pollers", + description = "Max concurrent workflow pollers") + Integer maxConcurrentWorkflowPollers; + + @CommandLine.Option( + names = "--activity-poller-autoscale-max", + description = + "Max for activity poller autoscaling (overrides max-concurrent-activity-pollers)") + Integer activityPollerAutoscaleMax; + + @CommandLine.Option( + names = "--workflow-poller-autoscale-max", + description = + "Max for workflow poller autoscaling (overrides max-concurrent-workflow-pollers)") + Integer workflowPollerAutoscaleMax; + + @CommandLine.Option( + names = "--max-concurrent-activities", + description = "Max concurrent activities") + Integer maxConcurrentActivities; + + @CommandLine.Option( + names = "--max-concurrent-workflow-tasks", + description = "Max concurrent workflow tasks") + Integer maxConcurrentWorkflowTasks; + + @CommandLine.Option( + names = "--activities-per-second", + description = "Per-worker activity rate limit") + Double activitiesPerSecond; + + @CommandLine.Option( + names = "--err-on-unimplemented", + description = + "Error when receiving unimplemented actions (currently only affects concurrent client actions)", + arity = "0..1", + fallbackValue = "true", + defaultValue = "false", + converter = BooleanFlagConverter.class) + boolean errOnUnimplemented; + + @CommandLine.Option( + names = "--log-level", + description = "(debug info warn error panic fatal)", + defaultValue = "info") + String logLevel; + + @CommandLine.Option( + names = "--log-encoding", + description = "(console json)", + defaultValue = "console") + String logEncoding; + + @CommandLine.Option( + names = {"-n", "--namespace"}, + description = "The namespace to use", + defaultValue = "default") + String namespace; + + @CommandLine.Option( + names = {"-a", "--server-address"}, + description = "The host:port of the server", + defaultValue = "localhost:7233") + String serverAddress; + + @CommandLine.Option( + names = "--tls", + description = "Enable TLS", + arity = "0..1", + fallbackValue = "true", + defaultValue = "false", + converter = BooleanFlagConverter.class) + boolean tls; + + @CommandLine.Option( + names = "--tls-cert-path", + description = "Path to a client cert for TLS", + defaultValue = "") + String tlsCertPath; + + @CommandLine.Option( + names = "--tls-key-path", + description = "Path to a client key for TLS", + defaultValue = "") + String tlsKeyPath; + + @CommandLine.Option(names = "--prom-listen-address", description = "Prometheus listen address") + String promListenAddress; + + @CommandLine.Option( + names = "--prom-handler-path", + description = "Prometheus handler path", + defaultValue = "/metrics") + String promHandlerPath; + + @CommandLine.Option( + names = "--auth-header", + description = "Authorization header value", + defaultValue = "") + String authHeader; + + @CommandLine.Option(names = "--build-id", description = "Build ID", defaultValue = "") + String buildId; + } + + private static final class BooleanFlagConverter implements CommandLine.ITypeConverter { + @Override + public Boolean convert(String value) { + if ("true".equalsIgnoreCase(value) || "1".equals(value) || "yes".equalsIgnoreCase(value)) { + return true; + } + if ("false".equalsIgnoreCase(value) || "0".equals(value) || "no".equalsIgnoreCase(value)) { + return false; + } + throw new IllegalArgumentException("Invalid boolean value: " + value); + } + } +} diff --git a/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessEchoWorkflow.java b/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessEchoWorkflow.java new file mode 100644 index 00000000..67099c17 --- /dev/null +++ b/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessEchoWorkflow.java @@ -0,0 +1,10 @@ +package io.temporal.omes.harness; + +import io.temporal.workflow.WorkflowInterface; +import io.temporal.workflow.WorkflowMethod; + +@WorkflowInterface +public interface ProjectHarnessEchoWorkflow { + @WorkflowMethod + String run(String payload); +} diff --git a/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessEchoWorkflowImpl.java b/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessEchoWorkflowImpl.java new file mode 100644 index 00000000..e65b253a --- /dev/null +++ b/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessEchoWorkflowImpl.java @@ -0,0 +1,10 @@ +package io.temporal.omes.harness; + +public final class ProjectHarnessEchoWorkflowImpl implements ProjectHarnessEchoWorkflow { + public ProjectHarnessEchoWorkflowImpl() {} + + @Override + public String run(String payload) { + return payload; + } +} diff --git a/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessTest.java b/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessTest.java new file mode 100644 index 00000000..43473667 --- /dev/null +++ b/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessTest.java @@ -0,0 +1,85 @@ +package io.temporal.omes.harness; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; + +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowOptions; +import io.temporal.testing.TestWorkflowEnvironment; +import io.temporal.worker.Worker; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; + +class ProjectHarnessTest { + @Test + void projectServerExecutesWorkflowAgainstRealTemporalServer() throws Exception { + String taskQueue = "project-harness-e2e"; + List events = new ArrayList<>(); + CapturedProjectCall captured = new CapturedProjectCall(); + + try (TestWorkflowEnvironment environment = TestWorkflowEnvironment.newInstance()) { + Worker worker = environment.newWorker(taskQueue); + worker.registerWorkflowImplementationTypes(ProjectHarnessEchoWorkflowImpl.class); + environment.start(); + + try (ProjectHarnessTestSupport.TestServer server = + ProjectHarnessTestSupport.startServer( + new ProjectHarness.ProjectServiceServer( + new ProjectHarness.ProjectHandlers( + (client, context) -> { + ProjectHarnessEchoWorkflow workflow = + client.newWorkflowStub( + ProjectHarnessEchoWorkflow.class, + WorkflowOptions.newBuilder() + .setWorkflowId( + String.format( + "%s-%d", context.run.executionId, context.iteration)) + .setTaskQueue(context.taskQueue) + .build()); + captured.executeClient = client; + captured.executeContext = context; + captured.executeResult = workflow.run(new String(context.payload, UTF_8)); + events.add("execute"); + }, + (client, context) -> { + captured.initClient = client; + captured.initContext = context; + events.add("init"); + }), + config -> environment.getWorkflowClient()))) { + server.stub.init( + ProjectHarnessTestSupport.makeInitRequest().toBuilder() + .setTaskQueue(taskQueue) + .build()); + server.stub.execute( + ProjectHarnessTestSupport.makeExecuteRequest().toBuilder() + .setTaskQueue(taskQueue) + .build()); + } + } + + assertEquals(List.of("init", "execute"), events); + assertSame(captured.initClient, captured.executeClient); + assertEquals("run-id", captured.initContext.run.runId); + assertEquals("exec-id", captured.initContext.run.executionId); + assertEquals(taskQueue, captured.initContext.taskQueue); + assertArrayEquals("{\"hello\":\"world\"}".getBytes(UTF_8), captured.initContext.configJson); + assertEquals("run-id", captured.executeContext.run.runId); + assertEquals("exec-id", captured.executeContext.run.executionId); + assertEquals(taskQueue, captured.executeContext.taskQueue); + assertEquals(7L, captured.executeContext.iteration); + assertArrayEquals("payload".getBytes(UTF_8), captured.executeContext.payload); + assertEquals("payload", captured.executeResult); + } + + private static final class CapturedProjectCall { + private WorkflowClient initClient; + private WorkflowClient executeClient; + private ProjectHarness.ProjectInitContext initContext; + private ProjectHarness.ProjectExecuteContext executeContext; + private String executeResult; + } +} diff --git a/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessTestSupport.java b/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessTestSupport.java new file mode 100644 index 00000000..799f21e5 --- /dev/null +++ b/workers/java/harness/src/test/java/io/temporal/omes/harness/ProjectHarnessTestSupport.java @@ -0,0 +1,67 @@ +package io.temporal.omes.harness; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.protobuf.ByteString; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import java.io.IOException; +import temporal.omes.projects.v1.Api; +import temporal.omes.projects.v1.ProjectServiceGrpc; + +final class ProjectHarnessTestSupport { + private ProjectHarnessTestSupport() {} + + static Api.InitRequest makeInitRequest() { + return Api.InitRequest.newBuilder() + .setExecutionId("exec-id") + .setRunId("run-id") + .setTaskQueue("task-queue") + .setConnectOptions( + Api.ConnectOptions.newBuilder() + .setNamespace("default") + .setServerAddress("localhost:7233") + .build()) + .setConfigJson(ByteString.copyFrom("{\"hello\":\"world\"}".getBytes(UTF_8))) + .build(); + } + + static Api.ExecuteRequest makeExecuteRequest() { + return Api.ExecuteRequest.newBuilder() + .setIteration(7) + .setTaskQueue("task-queue") + .setPayload(ByteString.copyFrom("payload".getBytes(UTF_8))) + .build(); + } + + static TestServer startServer(ProjectHarness.ProjectServiceServer service) throws IOException { + String serverName = InProcessServerBuilder.generateName(); + Server server = InProcessServerBuilder.forName(serverName).addService(service).build(); + server.start(); + ManagedChannel channel = InProcessChannelBuilder.forName(serverName).build(); + return new TestServer(server, channel, ProjectServiceGrpc.newBlockingStub(channel)); + } + + static final class TestServer implements AutoCloseable { + final Server server; + final ManagedChannel channel; + final ProjectServiceGrpc.ProjectServiceBlockingStub stub; + + private TestServer( + Server server, ManagedChannel channel, ProjectServiceGrpc.ProjectServiceBlockingStub stub) { + this.server = server; + this.channel = channel; + this.stub = stub; + } + + @Override + public void close() throws Exception { + channel.shutdownNow(); + server.shutdownNow(); + channel.awaitTermination(5, java.util.concurrent.TimeUnit.SECONDS); + server.awaitTermination(5, java.util.concurrent.TimeUnit.SECONDS); + } + } +} diff --git a/workers/java/harness/src/test/java/io/temporal/omes/harness/WorkerHarnessTest.java b/workers/java/harness/src/test/java/io/temporal/omes/harness/WorkerHarnessTest.java new file mode 100644 index 00000000..b9d03e67 --- /dev/null +++ b/workers/java/harness/src/test/java/io/temporal/omes/harness/WorkerHarnessTest.java @@ -0,0 +1,135 @@ +package io.temporal.omes.harness; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.temporal.client.WorkflowClient; +import io.temporal.common.SimplePlugin; +import io.temporal.testing.TestWorkflowEnvironment; +import io.temporal.worker.Worker; +import io.temporal.worker.WorkerFactory; +import io.temporal.worker.WorkerFactoryOptions; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; +import org.junit.jupiter.api.Test; + +class WorkerHarnessTest { + @Test + void runWorkerFactoryShutsDownAllWorkersWhenStartFails() { + LifecyclePlugin lifecycle = new LifecyclePlugin("omes-1", 2); + + try (TestWorkflowEnvironment environment = TestWorkflowEnvironment.newInstance()) { + environment.start(); + WorkflowClient client = environment.getWorkflowClient(); + WorkerFactory workerFactory = newWorkerFactory(client, lifecycle); + addWorker(workerFactory, "omes-1"); + addWorker(workerFactory, "omes-2"); + + IllegalStateException error = + assertThrows( + IllegalStateException.class, + () -> WorkerHarness.runWorkerFactory(workerFactory, client, new CountDownLatch(1))); + + assertEquals("boom", error.getMessage()); + assertEquals(Set.of("omes-1", "omes-2"), lifecycle.shutdownTaskQueues()); + } + } + + @Test + void runWorkerFactoryShutsDownAllWorkersWhenStopped() throws Exception { + LifecyclePlugin lifecycle = new LifecyclePlugin(null, 2); + CountDownLatch stopSignal = new CountDownLatch(1); + + try (TestWorkflowEnvironment environment = TestWorkflowEnvironment.newInstance()) { + environment.start(); + WorkflowClient client = environment.getWorkflowClient(); + WorkerFactory workerFactory = newWorkerFactory(client, lifecycle); + addWorker(workerFactory, "omes-1"); + addWorker(workerFactory, "omes-2"); + + CompletableFuture runTask = + CompletableFuture.runAsync(() -> runWorkerFactory(workerFactory, client, stopSignal)); + + assertTrue(lifecycle.awaitStarted(5, TimeUnit.SECONDS)); + stopSignal.countDown(); + runTask.join(); + + assertEquals(Set.of("omes-1", "omes-2"), lifecycle.startedTaskQueues()); + assertEquals(Set.of("omes-1", "omes-2"), lifecycle.shutdownTaskQueues()); + } + } + + private static WorkerFactory newWorkerFactory(WorkflowClient client, LifecyclePlugin lifecycle) { + return WorkerFactory.newInstance( + client, + WorkerFactoryOptions.newBuilder() + .setMaxWorkflowThreadCount(4) + .setPlugins(lifecycle) + .build()); + } + + private static void addWorker(WorkerFactory workerFactory, String taskQueue) { + Worker worker = workerFactory.newWorker(taskQueue); + worker.registerWorkflowImplementationTypes(ProjectHarnessEchoWorkflowImpl.class); + } + + private static void runWorkerFactory( + WorkerFactory workerFactory, WorkflowClient client, CountDownLatch stopSignal) { + try { + WorkerHarness.runWorkerFactory(workerFactory, client, stopSignal); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AssertionError(e); + } + } + + private static final class LifecyclePlugin extends SimplePlugin { + private final String failOnStartTaskQueue; + private final CountDownLatch startedWorkers; + private final List startedTaskQueues = new ArrayList<>(); + private final List shutdownTaskQueues = new ArrayList<>(); + + private LifecyclePlugin(String failOnStartTaskQueue, int expectedWorkerCount) { + super("io.temporal.omes.harness-test.lifecycle"); + this.failOnStartTaskQueue = failOnStartTaskQueue; + this.startedWorkers = new CountDownLatch(expectedWorkerCount); + } + + @Override + public synchronized void startWorker( + String taskQueue, Worker worker, BiConsumer next) { + next.accept(taskQueue, worker); + startedTaskQueues.add(taskQueue); + startedWorkers.countDown(); + if (taskQueue.equals(failOnStartTaskQueue)) { + throw new IllegalStateException("boom"); + } + } + + @Override + public synchronized void shutdownWorker( + String taskQueue, Worker worker, BiConsumer next) { + shutdownTaskQueues.add(taskQueue); + next.accept(taskQueue, worker); + } + + private boolean awaitStarted(long timeout, TimeUnit unit) throws InterruptedException { + return startedWorkers.await(timeout, unit); + } + + private synchronized Set startedTaskQueues() { + return new HashSet<>(startedTaskQueues); + } + + private synchronized Set shutdownTaskQueues() { + return new HashSet<>(shutdownTaskQueues); + } + } +} diff --git a/workers/java/io/temporal/omes/Main.java b/workers/java/io/temporal/omes/Main.java index 1bbe035a..5f4c869a 100644 --- a/workers/java/io/temporal/omes/Main.java +++ b/workers/java/io/temporal/omes/Main.java @@ -1,242 +1,31 @@ package io.temporal.omes; -import ch.qos.logback.classic.Level; -import ch.qos.logback.classic.Logger; -import ch.qos.logback.classic.spi.ILoggingEvent; -import ch.qos.logback.core.ConsoleAppender; -import com.sun.net.httpserver.HttpServer; -import com.uber.m3.tally.RootScopeBuilder; -import com.uber.m3.tally.Scope; -import com.uber.m3.tally.StatsReporter; -import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; -import io.micrometer.core.instrument.Meter; -import io.micrometer.core.instrument.config.NamingConvention; -import io.micrometer.core.instrument.util.StringUtils; -import io.micrometer.prometheus.PrometheusConfig; -import io.micrometer.prometheus.PrometheusMeterRegistry; -import io.micrometer.prometheus.PrometheusNamingConvention; import io.temporal.client.WorkflowClient; -import io.temporal.client.WorkflowClientOptions; -import io.temporal.common.converter.*; -import io.temporal.common.reporter.MicrometerClientStatsReporter; -import io.temporal.serviceclient.SimpleSslContextBuilder; -import io.temporal.serviceclient.WorkflowServiceStubs; -import io.temporal.serviceclient.WorkflowServiceStubsOptions; +import io.temporal.common.converter.ByteArrayPayloadConverter; +import io.temporal.common.converter.DefaultDataConverter; +import io.temporal.common.converter.JacksonJsonPayloadConverter; +import io.temporal.common.converter.NullPayloadConverter; +import io.temporal.common.converter.PayloadConverter; +import io.temporal.common.converter.ProtobufJsonPayloadConverter; +import io.temporal.common.converter.ProtobufPayloadConverter; +import io.temporal.omes.harness.Harness; +import io.temporal.omes.harness.HarnessClients; +import io.temporal.omes.harness.WorkerHarness; import io.temporal.worker.Worker; -import io.temporal.worker.WorkerFactory; -import io.temporal.worker.WorkerFactoryOptions; -import io.temporal.worker.WorkerOptions; -import io.temporal.worker.tuning.PollerBehaviorAutoscaling; -import java.io.FileInputStream; -import java.io.FileNotFoundException; -import java.io.InputStream; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import javax.net.ssl.SSLException; -import net.logstash.logback.encoder.LogstashEncoder; -import picocli.CommandLine; -@CommandLine.Command(name = "features", description = "Runs Java features") -public class Main implements Runnable { - @CommandLine.Option( - names = {"-q", "--task-queue"}, - description = "Task queue to use", - defaultValue = "omes") - private String taskQueue; +public final class Main { + private Main() {} - @CommandLine.Option( - names = "--task-queue-suffix-index-start", - description = "Inclusive start for task queue suffix range", - defaultValue = "0") - private Integer taskQueueIndexStart; - - @CommandLine.Option( - names = "--task-queue-suffix-index-end", - description = "Inclusive end for task queue suffix range", - defaultValue = "0") - private Integer taskQueueIndexEnd; - - // Log arguments - @CommandLine.Option(names = "--log-level", description = "Log level", defaultValue = "info") - private String logLevel; - - @CommandLine.Option( - names = "--log-encoding", - description = "Log encoding", - defaultValue = "console") - private String logEncoding; - - // Client arguments - @CommandLine.Option( - names = {"-n", "--namespace"}, - description = "The namespace to use", - defaultValue = "default") - private String namespace; - - @CommandLine.Option( - names = {"-a", "--server-address"}, - description = "The host:port of the server", - defaultValue = "localhost:7233") - private String serverAddress; - - @CommandLine.Option(names = "--tls", description = "Enable TLS") - private boolean isTlsEnabled; - - @CommandLine.Option(names = "--tls-cert-path", description = "Path to a client cert for TLS") - private String clientCertPath; - - @CommandLine.Option(names = "--tls-key-path", description = "Path to a client key for TLS") - private String clientKeyPath; - - @CommandLine.Option(names = "--auth-header", description = "Authorization header value") - private String authHeader; - - @CommandLine.Option(names = "--build-id", description = "Build ID") - private String buildId; - - // Metric parameters - @CommandLine.Option( - names = "--prom-listen-address", - description = "Prometheus listen address", - defaultValue = "localhost") - private String promListenAddress; - - @CommandLine.Option( - names = "--prom-handler-path", - description = "Prometheus handler path", - defaultValue = "/metrics") - private String promHandlerPath; - - // Worker parameters - @CommandLine.Option( - names = "--max-concurrent-activity-pollers", - description = "Max concurrent activity pollers") - private int maxConcurrentActivityPollers; - - @CommandLine.Option( - names = "--max-concurrent-workflow-pollers", - description = "Max concurrent workflow pollers") - private int maxConcurrentWorkflowPollers; - - @CommandLine.Option( - names = "--activity-poller-autoscale-max", - description = - "Max for activity poller autoscaling (overrides max-concurrent-activity-pollers)") - private int activityPollerAutoscaleMax; - - @CommandLine.Option( - names = "--workflow-poller-autoscale-max", - description = - "Max for workflow poller autoscaling (overrides max-concurrent-workflow-pollers)") - private int workflowPollerAutoscaleMax; - - @CommandLine.Option( - names = "--max-concurrent-activities", - description = "Max concurrent activities") - private int maxConcurrentActivities; - - @CommandLine.Option( - names = "--max-concurrent-workflow-tasks", - description = "Max concurrent workflow tasks") - private int maxConcurrentWorkflowTasks; - - @CommandLine.Option( - names = "--activities-per-second", - description = "Per-worker activity rate limit") - private double workerActivitiesPerSecond; - - @CommandLine.Option( - names = "--err-on-unimplemented", - description = - "Error when receiving unimplemented actions (currently only affects concurrent client actions)", - defaultValue = "false") - private boolean errOnUnimplemented; - - @Override - public void run() { - // Configure TLS - SslContext sslContext = null; - if (StringUtils.isNotEmpty(clientCertPath)) { - if (StringUtils.isEmpty(clientKeyPath)) { - throw new RuntimeException("Client key path must be specified since cert path is"); - } - - try { - InputStream clientCert = new FileInputStream(clientCertPath); - InputStream clientKey = new FileInputStream(clientKeyPath); - sslContext = SimpleSslContextBuilder.forPKCS8(clientCert, clientKey).build(); - } catch (FileNotFoundException | SSLException e) { - throw new RuntimeException("Error loading certs", e); - } - - } else if (StringUtils.isNotEmpty(clientKeyPath) && StringUtils.isEmpty(clientCertPath)) { - throw new RuntimeException("Client cert path must be specified since key path is"); - } else if (isTlsEnabled) { - try { - sslContext = SimpleSslContextBuilder.noKeyOrCertChain().build(); - } catch (SSLException e) { - throw new RuntimeException(e); - } - } - - // Configure logging - Logger logger = - (Logger) org.slf4j.LoggerFactory.getLogger(ch.qos.logback.classic.Logger.ROOT_LOGGER_NAME); - logger.setLevel(Level.valueOf(logLevel)); - if (logEncoding == "json") { - ConsoleAppender appender = new ConsoleAppender(); - LogstashEncoder encoder = new LogstashEncoder(); - appender.setEncoder(encoder); - logger.addAppender(appender); - } - // Configure metrics - // Use a custom naming convention that doesn't add _seconds suffix to timers, - // for consistency with other Temporal SDKs - PrometheusMeterRegistry registry = new PrometheusMeterRegistry(PrometheusConfig.DEFAULT); - registry - .config() - .namingConvention( - new PrometheusNamingConvention() { - @Override - public String name(String name, Meter.Type type, String baseUnit) { - // Don't add unit suffix - Temporal SDKs report duration values in seconds - // but don't include _seconds in the metric name - return NamingConvention.snakeCase.name(name, type, null); - } - }); - StatsReporter reporter = new MicrometerClientStatsReporter(registry); - // set up a new scope, report every 10 seconds - Scope scope = - new RootScopeBuilder() - .reporter(reporter) - .reportEvery(com.uber.m3.util.Duration.ofSeconds(1)); - // Start the prometheus scrape endpoint for starter metrics - HttpServer scrapeEndpoint = - MetricsUtils.startPrometheusScrapeEndpoint(registry, promHandlerPath, promListenAddress); - // Stopping the starter will stop the http server that exposes the - // scrape endpoint. - Runtime.getRuntime().addShutdownHook(new Thread(() -> scrapeEndpoint.stop(1))); - // Configure API key - String apiKey = - (authHeader != null && authHeader.startsWith("Bearer ")) - ? authHeader.substring("Bearer ".length()) - : authHeader; + public static void main(String... args) throws Exception { + Harness.run(app(), args); + } - // Configure client - WorkflowServiceStubsOptions.Builder serviceOptionsBuilder = - WorkflowServiceStubsOptions.newBuilder() - .setTarget(serverAddress) - .setSslContext(sslContext) - .setMetricsScope(scope); - if (apiKey != null && !apiKey.isEmpty()) { - serviceOptionsBuilder.addApiKey(() -> apiKey); - } - WorkflowServiceStubs service = - WorkflowServiceStubs.newServiceStubs(serviceOptionsBuilder.build()); + private static Harness.App app() { + return new Harness.App(Main::configureWorker, Main::createClient); + } - PayloadConverter[] arr = { + private static WorkflowClient createClient(HarnessClients.ClientConfig config) throws Exception { + PayloadConverter[] converters = { new NullPayloadConverter(), new ByteArrayPayloadConverter(), new PassthroughDataConverter(), @@ -244,75 +33,13 @@ public String name(String name, Meter.Type type, String baseUnit) { new ProtobufPayloadConverter(), new JacksonJsonPayloadConverter() }; - - WorkflowClient client = - WorkflowClient.newInstance( - service, - WorkflowClientOptions.newBuilder() - .setDataConverter(new DefaultDataConverter(arr)) - .setNamespace(namespace) - .build()); - - // Collect task queues to run workers for (if there is a suffix end, we run multiple) - List taskQueues; - if (taskQueueIndexStart == 0) { - taskQueues = Collections.singletonList(taskQueue); - } else { - taskQueues = new ArrayList<>(taskQueueIndexEnd - taskQueueIndexStart); - for (int i = taskQueueIndexStart; i <= taskQueueIndexEnd; i++) { - taskQueues.add(String.format("%s-%d", taskQueue, i)); - } - } - // Create worker factory - WorkerFactory workerFactory = - WorkerFactory.newInstance( - client, WorkerFactoryOptions.newBuilder().setMaxWorkflowThreadCount(1000).build()); - // Create the base worker options - WorkerOptions.Builder workerOptions = WorkerOptions.newBuilder(); - // Workflow options - if (workflowPollerAutoscaleMax > 0) { - workerOptions.setWorkflowTaskPollersBehavior( - new PollerBehaviorAutoscaling(null, workflowPollerAutoscaleMax, null)); - } else if (maxConcurrentWorkflowPollers > 0) { - workerOptions.setMaxConcurrentWorkflowTaskPollers(maxConcurrentWorkflowPollers); - } - workerOptions.setMaxConcurrentWorkflowTaskExecutionSize(maxConcurrentWorkflowTasks); - // Activity options - if (activityPollerAutoscaleMax > 0) { - workerOptions.setActivityTaskPollersBehavior( - new PollerBehaviorAutoscaling(null, activityPollerAutoscaleMax, null)); - } else if (maxConcurrentActivityPollers > 0) { - workerOptions.setMaxConcurrentActivityTaskPollers(maxConcurrentActivityPollers); - } - workerOptions.setMaxConcurrentActivityExecutionSize(maxConcurrentActivities); - workerOptions.setMaxWorkerActivitiesPerSecond(workerActivitiesPerSecond); - // Start all workers, throwing on first exception - for (String taskQueue : taskQueues) { - Worker worker = workerFactory.newWorker(taskQueue, workerOptions.build()); - worker.registerWorkflowImplementationTypes(KitchenSinkWorkflowImpl.class); - worker.registerActivitiesImplementations(new ActivitiesImpl(client, errOnUnimplemented)); - } - workerFactory.start(); - CountDownLatch latch = new CountDownLatch(1); - - Runtime.getRuntime() - .addShutdownHook( - new Thread( - () -> { - scrapeEndpoint.stop(1); - // Shut all workers down - workerFactory.shutdownNow(); - latch.countDown(); - })); - try { - latch.await(); - System.exit(0); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } + return HarnessClients.newWorkflowClient(config, new DefaultDataConverter(converters)); } - public static void main(String... args) { - System.exit(new CommandLine(new Main()).execute(args)); + private static void configureWorker( + WorkflowClient client, Worker worker, WorkerHarness.WorkerContext context) { + worker.registerWorkflowImplementationTypes(KitchenSinkWorkflowImpl.class); + worker.registerActivitiesImplementations( + new ActivitiesImpl(client, context.errOnUnimplemented)); } } diff --git a/workers/java/io/temporal/omes/PassthroughDataConverter.java b/workers/java/io/temporal/omes/PassthroughDataConverter.java index dfac7ffb..447c802f 100644 --- a/workers/java/io/temporal/omes/PassthroughDataConverter.java +++ b/workers/java/io/temporal/omes/PassthroughDataConverter.java @@ -7,17 +7,13 @@ import io.temporal.common.converter.EncodingKeys; import io.temporal.common.converter.GlobalDataConverter; import io.temporal.common.converter.PayloadConverter; -import io.temporal.workflow.Workflow; import java.lang.reflect.Type; import java.nio.charset.StandardCharsets; import java.util.Optional; -import org.slf4j.Logger; -public final class PassthroughDataConverter implements PayloadConverter { - public static final Logger log = Workflow.getLogger(PassthroughDataConverter.class); - - static final String METADATA_ENCODING_NAME = "_passthrough"; - static final ByteString METADATA_ENCODING = +final class PassthroughDataConverter implements PayloadConverter { + private static final String METADATA_ENCODING_NAME = "_passthrough"; + private static final ByteString METADATA_ENCODING = ByteString.copyFrom(METADATA_ENCODING_NAME, StandardCharsets.UTF_8); @Override diff --git a/workers/java/settings.gradle b/workers/java/settings.gradle index 22929586..e56ddf45 100644 --- a/workers/java/settings.gradle +++ b/workers/java/settings.gradle @@ -6,3 +6,4 @@ */ rootProject.name = 'omes' +include 'harness' diff --git a/workers/run.go b/workers/run.go index 9d2da136..e23d66b3 100644 --- a/workers/run.go +++ b/workers/run.go @@ -130,8 +130,8 @@ func (r *Runner) Run(ctx context.Context, baseDir string) error { case clioptions.LangTypeScript: // Node also needs module before the harness subcommand. args = append(args, "./tslib/omes.js", "worker") - case clioptions.LangDotNet, clioptions.LangRuby, clioptions.LangGo: - // .NET, Ruby, and Go just need the harness worker subcommand + case clioptions.LangDotNet, clioptions.LangRuby, clioptions.LangJava, clioptions.LangGo: + // .NET, Ruby, Java and Go just need the harness worker subcommand args = append(args, "worker") }