|
| 1 | +package pl.touk.nussknacker.engine.flink |
| 2 | + |
| 3 | +import com.typesafe.config.ConfigFactory |
| 4 | +import pl.touk.nussknacker.engine.api.component.ComponentDefinition |
| 5 | +import pl.touk.nussknacker.engine.api.process.SourceFactory |
| 6 | +import pl.touk.nussknacker.engine.canonicalgraph.CanonicalProcess |
| 7 | +import pl.touk.nussknacker.engine.flink.test.FlinkSpec |
| 8 | +import pl.touk.nussknacker.engine.flink.test.ScalatestMiniClusterJobStatusCheckingOps.miniClusterWithServicesToOps |
| 9 | +import pl.touk.nussknacker.engine.flink.util.source.EmitWatermarkAfterEachElementCollectionSource |
| 10 | +import pl.touk.nussknacker.engine.flink.util.transformer.FlinkBaseComponentProvider |
| 11 | +import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.AggregateWindowsConfig |
| 12 | +import pl.touk.nussknacker.engine.process.helpers.ConfigCreatorWithCollectingListener |
| 13 | +import pl.touk.nussknacker.engine.process.runner.FlinkScenarioUnitTestJob |
| 14 | +import pl.touk.nussknacker.engine.testing.LocalModelData |
| 15 | +import pl.touk.nussknacker.engine.testmode.{ResultsCollectingListener, ResultsCollectingListenerHolder} |
| 16 | +import pl.touk.nussknacker.engine.testmode.TestProcess.{NodeTransition, TestResults} |
| 17 | +import pl.touk.nussknacker.engine.util.config.DocsConfig |
| 18 | +import pl.touk.nussknacker.test.ProcessUtils.convertToAnyShouldWrapper |
| 19 | + |
| 20 | +import java.time.{Duration, Instant} |
| 21 | +import scala.jdk.CollectionConverters._ |
| 22 | +import scala.util.Try |
| 23 | + |
| 24 | +trait FlinkMiniClusterTestRunner { _: FlinkSpec => |
| 25 | + |
| 26 | + protected def sourcesWithMockedData: Map[String, List[Int]] |
| 27 | + |
| 28 | + protected def withCollectingTestResults( |
| 29 | + canonicalProcess: CanonicalProcess, |
| 30 | + assertions: TestResults[Any] => Unit, |
| 31 | + allowEndingScenarioWithoutSink: Boolean = false, |
| 32 | + ): Unit = { |
| 33 | + ResultsCollectingListenerHolder.withListener { collectingListener => |
| 34 | + val model = modelData(collectingListener, AggregateWindowsConfig.Default, allowEndingScenarioWithoutSink) |
| 35 | + flinkMiniCluster.withDetachedStreamExecutionEnvironment { env => |
| 36 | + val executionResult = new FlinkScenarioUnitTestJob(model).run(canonicalProcess, env) |
| 37 | + flinkMiniCluster.waitForJobIsFinished(executionResult.getJobID) |
| 38 | + assertions(collectingListener.results) |
| 39 | + } |
| 40 | + } |
| 41 | + } |
| 42 | + |
| 43 | + private def modelData( |
| 44 | + collectingListener: => ResultsCollectingListener[Any], |
| 45 | + aggregateWindowsConfig: AggregateWindowsConfig, |
| 46 | + allowEndingScenarioWithoutSink: Boolean, |
| 47 | + ): LocalModelData = { |
| 48 | + def sourceComponent(data: List[Int]) = SourceFactory.noParamUnboundedStreamFactory[Int]( |
| 49 | + EmitWatermarkAfterEachElementCollectionSource |
| 50 | + .create[Int](data, _ => Instant.now.toEpochMilli, Duration.ofHours(1)) |
| 51 | + ) |
| 52 | + val config = |
| 53 | + if (allowEndingScenarioWithoutSink) { |
| 54 | + ConfigFactory.parseString("""allowEndingScenarioWithoutSink: true""") |
| 55 | + } else { |
| 56 | + ConfigFactory.empty() |
| 57 | + } |
| 58 | + LocalModelData( |
| 59 | + config, |
| 60 | + sourcesWithMockedData.toList.map { case (name, data) => |
| 61 | + ComponentDefinition(name, sourceComponent(data)) |
| 62 | + } ::: |
| 63 | + FlinkBaseUnboundedComponentProvider.create( |
| 64 | + DocsConfig.Default, |
| 65 | + aggregateWindowsConfig |
| 66 | + ) ::: FlinkBaseComponentProvider.Components, |
| 67 | + configCreator = new ConfigCreatorWithCollectingListener(collectingListener), |
| 68 | + ) |
| 69 | + } |
| 70 | + |
| 71 | + protected def transitionVariables( |
| 72 | + testResults: TestResults[Any], |
| 73 | + fromNodeId: String, |
| 74 | + toNodeId: Option[String] |
| 75 | + ): Set[Map[String, Any]] = |
| 76 | + testResults |
| 77 | + .nodeTransitionResults(NodeTransition(fromNodeId, toNodeId)) |
| 78 | + .map(_.variables) |
| 79 | + .toSet[Map[String, Any]] |
| 80 | + .map(_.map { case (key, value) => (key, scalaMap(value)) }) |
| 81 | + |
| 82 | + private def scalaMap(value: Any): Any = { |
| 83 | + value match { |
| 84 | + case hashMap: java.util.HashMap[_, _] => hashMap.asScala.toMap |
| 85 | + case other => other |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | + protected def assertNumberOfSamplesThatFinishedInNode(testResults: TestResults[Any], sinkId: String, expected: Int) = |
| 90 | + testResults.nodeTransitionResults.get(NodeTransition(sinkId, None)).map(_.length) shouldBe Some(expected) |
| 91 | + |
| 92 | + protected def catchExceptionMessage(f: => Any): String = Try(f).failed.get.getMessage |
| 93 | + |
| 94 | +} |
0 commit comments