|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, software |
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | + * See the License for the specific language governing permissions and |
| 16 | + * limitations under the License. |
| 17 | + */ |
| 18 | +package org.apache.beam.sdk.extensions.yaml; |
| 19 | + |
| 20 | +import java.util.HashSet; |
| 21 | +import java.util.Map; |
| 22 | +import java.util.Set; |
| 23 | +import java.util.stream.Collectors; |
| 24 | +import org.apache.beam.sdk.extensions.python.PythonExternalTransform; |
| 25 | +import org.apache.beam.sdk.transforms.PTransform; |
| 26 | +import org.apache.beam.sdk.values.PBegin; |
| 27 | +import org.apache.beam.sdk.values.PCollection; |
| 28 | +import org.apache.beam.sdk.values.PCollectionRowTuple; |
| 29 | +import org.apache.beam.sdk.values.PInput; |
| 30 | +import org.apache.beam.sdk.values.POutput; |
| 31 | +import org.apache.beam.sdk.values.PValue; |
| 32 | +import org.apache.beam.sdk.values.Row; |
| 33 | +import org.apache.beam.sdk.values.TupleTag; |
| 34 | +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner; |
| 35 | +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; |
| 36 | +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; |
| 37 | +import org.checkerframework.checker.nullness.qual.Nullable; |
| 38 | + |
| 39 | +/** |
| 40 | + * Allows one to invoke <a href="https://beam.apache.org/documentation/sdks/yaml/">Beam YAML</a> |
| 41 | + * transforms from Java. |
| 42 | + * |
| 43 | + * <p>This leverages Beam's cross-langauge transforms. Although python is required to parse and |
| 44 | + * expand the given transforms, the actual implementation may still be in Java. |
| 45 | + * |
| 46 | + * @param <InputT> the type of the input to this PTransform |
| 47 | + * @param <OutputT> the type of the output to this PTransform |
| 48 | + */ |
| 49 | +public class YamlTransform<InputT extends PInput, OutputT extends POutput> |
| 50 | + extends PTransform<InputT, OutputT> { |
| 51 | + |
| 52 | + /** The YAML definition of this transform. */ |
| 53 | + private final String yamlDefinition; |
| 54 | + /** |
| 55 | + * If non-null, the set of input tags that are expected to be passed to this transform. |
| 56 | + * |
| 57 | + * <p>If null, a {@literal PCollection<Row>} or PBegin is expected. |
| 58 | + */ |
| 59 | + private final @Nullable Set<String> inputTags; |
| 60 | + |
| 61 | + /** |
| 62 | + * If non-null, the set of output tags that are expected to be produced by this transform. |
| 63 | + * |
| 64 | + * <p>If null, exactly one output is expected and will be returned as a {@literal |
| 65 | + * PCollection<Row>}. |
| 66 | + */ |
| 67 | + private final @Nullable Set<String> outputTags; |
| 68 | + |
| 69 | + private YamlTransform( |
| 70 | + String yamlDefinition, |
| 71 | + @Nullable Iterable<String> inputTags, |
| 72 | + @Nullable Iterable<String> outputTags) { |
| 73 | + this.yamlDefinition = yamlDefinition; |
| 74 | + this.inputTags = inputTags == null ? null : ImmutableSet.copyOf(inputTags); |
| 75 | + this.outputTags = outputTags == null ? null : ImmutableSet.copyOf(outputTags); |
| 76 | + } |
| 77 | + |
| 78 | + /** |
| 79 | + * Creates a new YamlTransform mapping a single input {@literal PCollection<Row>} to a single |
| 80 | + * {@literal PCollection<Row>} output. |
| 81 | + * |
| 82 | + * <p>Use {@link #withMultipleInputs} or {@link #withMultipleOutputs} to indicate that this |
| 83 | + * transform has multiple inputs and/or outputs. |
| 84 | + * |
| 85 | + * @param yamlDefinition a YAML string defining this transform. |
| 86 | + * @return a PTransform that applies this YAML to its inputs. |
| 87 | + */ |
| 88 | + public static YamlTransform<PCollection<Row>, PCollection<Row>> of(String yamlDefinition) { |
| 89 | + return new YamlTransform<PCollection<Row>, PCollection<Row>>(yamlDefinition, null, null); |
| 90 | + } |
| 91 | + |
| 92 | + /** |
| 93 | + * Creates a new YamlTransform PBegin a single {@literal PCollection<Row>} output. |
| 94 | + * |
| 95 | + * @param yamlDefinition a YAML string defining this source. |
| 96 | + * @return a PTransform that applies this YAML as a root transform. |
| 97 | + */ |
| 98 | + public static YamlTransform<PBegin, PCollection<Row>> source(String yamlDefinition) { |
| 99 | + return new YamlTransform<PBegin, PCollection<Row>>(yamlDefinition, null, null); |
| 100 | + } |
| 101 | + |
| 102 | + /** |
| 103 | + * Creates a new YamlTransform mapping a single input {@literal PCollection<Row>} to a single |
| 104 | + * {@literal PCollection<Row>} output. |
| 105 | + * |
| 106 | + * <p>Use {@link #withMultipleOutputs} to indicate that this sink has multiple (or no) or outputs. |
| 107 | + * |
| 108 | + * @param yamlDefinition a YAML string defining this sink. |
| 109 | + * @return a PTransform that applies this YAML to its inputs. |
| 110 | + */ |
| 111 | + public static YamlTransform<PCollection<Row>, PCollection<Row>> sink(String yamlDefinition) { |
| 112 | + return of(yamlDefinition); |
| 113 | + } |
| 114 | + |
| 115 | + /** |
| 116 | + * Indicates that this YamlTransform expects multiple, named inputs. |
| 117 | + * |
| 118 | + * @param inputTags the set of expected input tags to this transform |
| 119 | + * @return a PTransform like this but with a {@link PCollectionRowTuple} input type. |
| 120 | + */ |
| 121 | + public YamlTransform<PCollectionRowTuple, OutputT> withMultipleInputs(String... inputTags) { |
| 122 | + return new YamlTransform<PCollectionRowTuple, OutputT>( |
| 123 | + yamlDefinition, ImmutableSet.copyOf(inputTags), outputTags); |
| 124 | + } |
| 125 | + |
| 126 | + /** |
| 127 | + * Indicates that this YamlTransform expects multiple, named outputs. |
| 128 | + * |
| 129 | + * @param outputTags the set of expected output tags to this transform |
| 130 | + * @return a PTransform like this but with a {@link PCollectionRowTuple} output type. |
| 131 | + */ |
| 132 | + public YamlTransform<InputT, PCollectionRowTuple> withMultipleOutputs(String... outputTags) { |
| 133 | + return new YamlTransform<InputT, PCollectionRowTuple>( |
| 134 | + yamlDefinition, inputTags, ImmutableSet.copyOf(outputTags)); |
| 135 | + } |
| 136 | + |
| 137 | + @Override |
| 138 | + public OutputT expand(InputT input) { |
| 139 | + if (inputTags != null) { |
| 140 | + Set<String> actualInputTags = |
| 141 | + input.expand().keySet().stream() |
| 142 | + .map(TupleTag::getId) |
| 143 | + .collect(Collectors.toCollection(HashSet::new)); |
| 144 | + if (!inputTags.equals(actualInputTags)) { |
| 145 | + throw new IllegalArgumentException( |
| 146 | + "Input has tags " |
| 147 | + + Joiner.on(", ").join(actualInputTags) |
| 148 | + + " but expected input tags " |
| 149 | + + Joiner.on(", ").join(inputTags)); |
| 150 | + } |
| 151 | + } |
| 152 | + |
| 153 | + // There is no generic apply... |
| 154 | + POutput output; |
| 155 | + @SuppressWarnings("rawtypes") |
| 156 | + PTransform externalTransform = |
| 157 | + PythonExternalTransform.from("apache_beam.yaml.yaml_transform.YamlTransform") |
| 158 | + .withArgs(yamlDefinition) |
| 159 | + .withExtraPackages(ImmutableList.of("jinja2", "pyyaml", "virtualenv-clone")); |
| 160 | + if (input instanceof PBegin) { |
| 161 | + output = ((PBegin) input).apply(externalTransform); |
| 162 | + } else if (input instanceof PCollection) { |
| 163 | + output = ((PCollection<?>) input).apply(externalTransform); |
| 164 | + } else if (input instanceof PCollection) { |
| 165 | + output = ((PCollection<?>) input).apply(externalTransform); |
| 166 | + } else if (input instanceof PCollectionRowTuple) { |
| 167 | + output = ((PCollectionRowTuple) input).apply(externalTransform); |
| 168 | + } else { |
| 169 | + throw new IllegalArgumentException("Unrecognized input type: " + input); |
| 170 | + } |
| 171 | + |
| 172 | + if (outputTags == null) { |
| 173 | + if (!(output instanceof PCollection)) { |
| 174 | + throw new IllegalArgumentException( |
| 175 | + "Expected a single PCollection output, but got " |
| 176 | + + output |
| 177 | + + ". Perhaps withMultipleOutputs() needs to be specified?"); |
| 178 | + } |
| 179 | + return (OutputT) output; |
| 180 | + } else { |
| 181 | + if (output instanceof PCollection) { |
| 182 | + // ExternalPythonTransform always returns single outputs as PCollections. |
| 183 | + if (outputTags.size() != 1) { |
| 184 | + throw new IllegalArgumentException( |
| 185 | + "Expected " + outputTags.size() + " outputs, but got exactly one."); |
| 186 | + } |
| 187 | + return (OutputT) |
| 188 | + PCollectionRowTuple.of(outputTags.iterator().next(), (PCollection<Row>) output); |
| 189 | + } else { |
| 190 | + Map<TupleTag<?>, PValue> expandedOutputs = output.expand(); |
| 191 | + Set<String> actualOutputTags = |
| 192 | + expandedOutputs.keySet().stream() |
| 193 | + .map(TupleTag::getId) |
| 194 | + .collect(Collectors.toCollection(HashSet::new)); |
| 195 | + if (!outputTags.equals(actualOutputTags)) { |
| 196 | + throw new IllegalArgumentException( |
| 197 | + "Output has tags " |
| 198 | + + Joiner.on(", ").join(actualOutputTags) |
| 199 | + + " but expected output tags " |
| 200 | + + Joiner.on(", ").join(outputTags)); |
| 201 | + } |
| 202 | + PCollectionRowTuple result = PCollectionRowTuple.empty(input.getPipeline()); |
| 203 | + for (Map.Entry<TupleTag<?>, PValue> subOutput : expandedOutputs.entrySet()) { |
| 204 | + result = result.and(subOutput.getKey().getId(), (PCollection<Row>) subOutput.getValue()); |
| 205 | + } |
| 206 | + return (OutputT) result; |
| 207 | + } |
| 208 | + } |
| 209 | + } |
| 210 | +} |
0 commit comments