diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 4f7133838794..3e594cf76d42 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -814,6 +814,9 @@ def split(element): mapping_transform = mapping_transform.with_outputs(*output_set) splits = pcoll | mapping_transform.with_input_types(T).with_output_types(T) result = {out: getattr(splits, out) for out in output_set} + for tag, out in result.items(): + if tag != error_output: + out.element_type = pcoll.element_type if error_output: result[error_output] = result[error_output] | map_errors_to_standard_format( pcoll.element_type) diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py b/sdks/python/apache_beam/yaml/yaml_mapping_test.py index 1054f73dc130..d5179d385caf 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py @@ -212,6 +212,7 @@ def test_partition(self): language: python outputs: [even, odd] ''') + self.assertEqual(result['even'].element_type, elements.element_type) assert_that( result['even'] | beam.Map(lambda x: x.element), equal_to(['banana', 'orange']),