Skip to content

Commit 4ead940

Browse files
authored
Inject SDK-side flattens while handling input/output coder mismatch in flattens. (#34641)
* A more general fix for handling flattens by injecting SDK-side flattens.
* Re-enable a previously failing flatten test in Java.
* Add a new test to cover another case that would crash Prism prior to the fix. The test is also included in the test suites of Flink, Samza, and Spark, but without transcoding until their corresponding FRs are resolved.
* Fix a flaky test by sorting the keys of internal producers.
* Skip the new flatten-gbk test in the Flink runner.
1 parent 8fb42c7 commit 4ead940

File tree

7 files changed

+69
-43
lines changed

7 files changed

+69
-43
lines changed

runners/prism/java/build.gradle

-5
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,6 @@ def sickbayTests = [
109109
// ShardedKey not yet implemented.
110110
'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testWithShardedKeyInGlobalWindow',
111111

112-
// Java side dying during execution.
113-
// Stream corruption error java side: failed:java.io.StreamCorruptedException: invalid stream header: 206E6F74
114-
// Likely due to Prism's coder changes.
115-
'org.apache.beam.sdk.transforms.FlattenTest.testFlattenWithDifferentInputAndOutputCoders2',
116-
117112
// java.lang.IllegalStateException: Output with tag Tag<output> must have a schema in order to call getRowReceiver
118113
// Ultimately because the getRowReceiver code path on the SDK side isn't friendly to LengthPrefix wrapping of row coders.
119114
// https://github.com/apache/beam/issues/32931

sdks/go/pkg/beam/runners/prism/internal/handlerunner.go

+31-36
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@ func (h *runner) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipep
8888
}
8989

9090
func (h *runner) handleFlatten(tid string, t *pipepb.PTransform, comps *pipepb.Components) prepareResult {
91-
if !h.config.SDKFlatten {
92-
t.EnvironmentId = "" // force the flatten to be a runner transform due to configuration.
91+
if !h.config.SDKFlatten && !strings.HasPrefix(tid, "ft_") {
9392
forcedRoots := []string{tid} // Have runner side transforms be roots.
9493

9594
// Force runner flatten consumers to be roots.
@@ -109,52 +108,48 @@ func (h *runner) handleFlatten(tid string, t *pipepb.PTransform, comps *pipepb.C
109108
// they're written out to the runner in the same fashion.
110109
// This may stop being necessary once Flatten Unzipping happens in the optimizer.
111110
outPCol := comps.GetPcollections()[outColID]
112-
outCoderID := outPCol.CoderId
113-
outCoder := comps.GetCoders()[outCoderID]
114-
coderSubs := map[string]*pipepb.Coder{}
115111
pcollSubs := map[string]*pipepb.PCollection{}
112+
tSubs := map[string]*pipepb.PTransform{}
116113

117-
if !strings.HasPrefix(outCoderID, "cf_") {
118-
// Create a new coder id for the flatten output PCollection and use
119-
// this coder id for all input PCollections
120-
outCoderID = "cf_" + outColID
121-
outCoder = proto.Clone(outCoder).(*pipepb.Coder)
122-
coderSubs[outCoderID] = outCoder
123-
124-
pcollSubs[outColID] = proto.Clone(outPCol).(*pipepb.PCollection)
125-
pcollSubs[outColID].CoderId = outCoderID
126-
127-
outPCol = pcollSubs[outColID]
128-
}
129-
130-
for _, p := range t.GetInputs() {
114+
ts := proto.Clone(t).(*pipepb.PTransform)
115+
ts.EnvironmentId = "" // force the flatten to be a runner transform due to configuration.
116+
for localID, p := range t.GetInputs() {
131117
inPCol := comps.GetPcollections()[p]
132118
if inPCol.CoderId != outPCol.CoderId {
133-
if strings.HasPrefix(inPCol.CoderId, "cf_") {
134-
// The input pcollection is the output of another flatten:
135-
// e.g. [[a, b] | Flatten], c] | Flatten
136-
// In this case, we just point the input coder id to the new flatten
137-
// output coder, so any upstream input pcollections will use the new
138-
// output coder.
139-
coderSubs[inPCol.CoderId] = outCoder
140-
} else {
141-
// Create a substitute PCollection for this input with the flatten
142-
// output coder id
143-
pcollSubs[p] = proto.Clone(inPCol).(*pipepb.PCollection)
144-
pcollSubs[p].CoderId = outPCol.CoderId
145-
}
119+
// TODO: do the following injection conditionally.
120+
// Now we inject an SDK-side flatten between the upstream transform and
121+
// the flatten.
122+
// Before: upstream -> [upstream out] -> runner flatten
123+
// After: upstream -> [upstream out] -> SDK-side flatten -> [SDK-side flatten out] -> runner flatten
124+
// Create a PCollection sub
125+
fColID := "fc_" + p + "_to_" + outColID
126+
fPCol := proto.Clone(outPCol).(*pipepb.PCollection)
127+
fPCol.CoderId = outPCol.CoderId // same coder as runner flatten
128+
pcollSubs[fColID] = fPCol
129+
130+
// Create a PTransform sub
131+
ftID := "ft_" + p + "_to_" + outColID
132+
ft := proto.Clone(t).(*pipepb.PTransform)
133+
ft.EnvironmentId = t.EnvironmentId // Set environment to ensure it is a SDK-side transform
134+
ft.Inputs = map[string]string{"0": p}
135+
ft.Outputs = map[string]string{"0": fColID}
136+
tSubs[ftID] = ft
137+
138+
// Replace the input of runner flatten with the output of SDK-side flatten
139+
ts.Inputs[localID] = fColID
140+
141+
// Force sdk-side flattens to be roots
142+
forcedRoots = append(forcedRoots, ftID)
146143
}
147144
}
145+
tSubs[tid] = ts
148146

149147
// Return the new components which is the transforms consumer
150148
return prepareResult{
151149
// We sub this flatten with itself, to not drop it.
152150
SubbedComps: &pipepb.Components{
153-
Transforms: map[string]*pipepb.PTransform{
154-
tid: t,
155-
},
151+
Transforms: tSubs,
156152
Pcollections: pcollSubs,
157-
Coders: coderSubs,
158153
},
159154
RemovedLeaves: nil,
160155
ForcedRoots: forcedRoots,

sdks/go/pkg/beam/runners/prism/internal/preprocess.go

+3
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,9 @@ func finalizeStage(stg *stage, comps *pipepb.Components, pipelineFacts *fusionFa
492492
}
493493

494494
stg.internalCols = internal
495+
// Sort the keys of internal producers (from stageFacts.PcolProducers)
496+
// to ensure deterministic order for stable tests.
497+
sort.Strings(stg.internalCols)
495498
stg.outputs = maps.Values(outputs)
496499
stg.sideInputs = sideInputs
497500

sdks/python/apache_beam/runners/portability/flink_runner_test.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,16 @@ def test_sql(self):
299299

300300
def test_flattened_side_input(self):
301301
# Blocked on support for transcoding
302-
# https://jira.apache.org/jira/browse/BEAM-6523
302+
# https://github.com/apache/beam/issues/19365
303303
super().test_flattened_side_input(with_transcoding=False)
304304

305+
def test_flatten_and_gbk(self):
306+
# Blocked on support for transcoding
307+
# https://github.com/apache/beam/issues/19365
308+
# Also blocked on support of flatten and groupby sharing the same input
309+
# https://github.com/apache/beam/issues/34647
310+
raise unittest.SkipTest("https://github.com/apache/beam/issues/34647")
311+
305312
def test_metrics(self):
306313
super().test_metrics(check_gauge=False, check_bounded_trie=False)
307314

sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py

+16
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,22 @@ def test_flattened_side_input(self, with_transcoding=True):
538538
equal_to([('a', 1), ('b', 2)] + third_element),
539539
label='CheckFlattenOfSideInput')
540540

541+
def test_flatten_and_gbk(self, with_transcoding=True):
542+
with self.create_pipeline() as p:
543+
side1 = p | 'side1' >> beam.Create([('a', 1)])
544+
if with_transcoding:
545+
# Also test non-matching coder types (transcoding required)
546+
second_element = [('another_type')]
547+
else:
548+
second_element = [('b', 2)]
549+
side2 = p | 'side2' >> beam.Create(second_element)
550+
551+
flatten_out = (side1, side2) | beam.Flatten()
552+
gbk_out = side1 | beam.GroupByKey()
553+
554+
assert_that(flatten_out, equal_to([('a', 1)] + second_element))
555+
assert_that(gbk_out, equal_to([('a', [1])]))
556+
541557
def test_gbk_side_input(self):
542558
with self.create_pipeline() as p:
543559
main = p | 'main' >> beam.Create([None])

sdks/python/apache_beam/runners/portability/samza_runner_test.py

+5
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ def test_flattened_side_input(self):
142142
# https://github.com/apache/beam/issues/20984
143143
super().test_flattened_side_input(with_transcoding=False)
144144

145+
def test_flatten_and_gbk(self):
146+
# Blocked on support for transcoding
147+
# https://github.com/apache/beam/issues/20984
148+
super().test_flatten_and_gbk(with_transcoding=False)
149+
145150
def test_pack_combiners(self):
146151
# Stages produced by translations.pack_combiners are fused
147152
# by translations.greedily_fuse, which prevent the stages

sdks/python/apache_beam/runners/portability/spark_runner_test.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,14 @@ def test_pardo_dynamic_timer(self):
174174

175175
def test_flattened_side_input(self):
176176
# Blocked on support for transcoding
177-
# https://jira.apache.org/jira/browse/BEAM-7236
177+
# https://github.com/apache/beam/issues/19504
178178
super().test_flattened_side_input(with_transcoding=False)
179179

180+
def test_flatten_and_gbk(self):
181+
# Blocked on support for transcoding
182+
# https://github.com/apache/beam/issues/19504
183+
super().test_flatten_and_gbk(with_transcoding=False)
184+
180185
def test_custom_merging_window(self):
181186
raise unittest.SkipTest("https://github.com/apache/beam/issues/20641")
182187

0 commit comments

Comments
 (0)