@@ -21,6 +21,7 @@ import (
21
21
"io"
22
22
"reflect"
23
23
"sort"
24
+ "strings"
24
25
25
26
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
26
27
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
@@ -108,12 +109,40 @@ func (h *runner) handleFlatten(tid string, t *pipepb.PTransform, comps *pipepb.C
108
109
// they're written out to the runner in the same fashion.
109
110
// This may stop being necessary once Flatten Unzipping happens in the optimizer.
110
111
outPCol := comps .GetPcollections ()[outColID ]
112
+ outCoderID := outPCol .CoderId
113
+ outCoder := comps .GetCoders ()[outCoderID ]
114
+ coderSubs := map [string ]* pipepb.Coder {}
111
115
pcollSubs := map [string ]* pipepb.PCollection {}
116
+
117
+ if ! strings .HasPrefix (outCoderID , "cf_" ) {
118
+ // Create a new coder id for the flatten output PCollection and use
119
+ // this coder id for all input PCollections
120
+ outCoderID = "cf_" + outColID
121
+ outCoder = proto .Clone (outCoder ).(* pipepb.Coder )
122
+ coderSubs [outCoderID ] = outCoder
123
+
124
+ pcollSubs [outColID ] = proto .Clone (outPCol ).(* pipepb.PCollection )
125
+ pcollSubs [outColID ].CoderId = outCoderID
126
+
127
+ outPCol = pcollSubs [outColID ]
128
+ }
129
+
112
130
for _ , p := range t .GetInputs () {
113
131
inPCol := comps .GetPcollections ()[p ]
114
132
if inPCol .CoderId != outPCol .CoderId {
115
- pcollSubs [p ] = proto .Clone (inPCol ).(* pipepb.PCollection )
116
- pcollSubs [p ].CoderId = outPCol .CoderId
133
+ if strings .HasPrefix (inPCol .CoderId , "cf_" ) {
134
+ // The input pcollection is the output of another flatten:
135
+ // e.g. [[a, b] | Flatten], c] | Flatten
136
+ // In this case, we just point the input coder id to the new flatten
137
+ // output coder, so any upstream input pcollections will use the new
138
+ // output coder.
139
+ coderSubs [inPCol .CoderId ] = outCoder
140
+ } else {
141
+ // Create a substitute PCollection for this input with the flatten
142
+ // output coder id
143
+ pcollSubs [p ] = proto .Clone (inPCol ).(* pipepb.PCollection )
144
+ pcollSubs [p ].CoderId = outPCol .CoderId
145
+ }
117
146
}
118
147
}
119
148
@@ -125,6 +154,7 @@ func (h *runner) handleFlatten(tid string, t *pipepb.PTransform, comps *pipepb.C
125
154
tid : t ,
126
155
},
127
156
Pcollections : pcollSubs ,
157
+ Coders : coderSubs ,
128
158
},
129
159
RemovedLeaves : nil ,
130
160
ForcedRoots : forcedRoots ,
0 commit comments