|
| 1 | +import logging |
| 2 | +from typing import Tuple |
| 3 | + |
| 4 | +from apache_beam import Create |
| 5 | +from apache_beam import DoFn |
| 6 | +from apache_beam import FlatMap |
| 7 | +from apache_beam import ParDo |
| 8 | +from apache_beam import Pipeline |
| 9 | +from apache_beam import TimeDomain |
| 10 | +from apache_beam import WithKeys |
| 11 | +from apache_beam import typehints |
| 12 | +from apache_beam.coders import BytesCoder |
| 13 | +from apache_beam.options.pipeline_options import PipelineOptions |
| 14 | +from apache_beam.options.pipeline_options import SetupOptions |
| 15 | +from apache_beam.transforms.userstate import BagStateSpec |
| 16 | +from apache_beam.transforms.userstate import TimerSpec |
| 17 | +from apache_beam.transforms.userstate import on_timer |
| 18 | + |
| 19 | +# The total bytes processed is NUM_SHARDS * NUM_ELEMENTS_PER_SHARD * ELEMENT_BYTES ~= 3 GiB |
| 20 | +NUM_SHARDS = 100 |
| 21 | +NUM_ELEMENTS_PER_SHARD = 10 |
| 22 | +ELEMENT_BYTES = 3 * 1024 * 1024 # 3 MiB |
| 23 | + |
| 24 | + |
| 25 | +@typehints.with_input_types(Tuple[str, bytes]) |
| 26 | +@typehints.with_output_types(None) |
| 27 | +class BigBagDoFn(DoFn): |
| 28 | + VALUES_STATE = BagStateSpec('values', BytesCoder()) |
| 29 | + END_OF_WINDOW_TIMER = TimerSpec('end_of_window', TimeDomain.WATERMARK) |
| 30 | + |
| 31 | + def process(self, element: Tuple[str, bytes], window=DoFn.WindowParam, |
| 32 | + values_state=DoFn.StateParam(VALUES_STATE), |
| 33 | + end_of_window_timer=DoFn.TimerParam(END_OF_WINDOW_TIMER)): |
| 34 | + logging.info('start process.') |
| 35 | + key, value = element |
| 36 | + end_of_window_timer.set(window.end) |
| 37 | + values_state.add(value) |
| 38 | + logging.info('end process.') |
| 39 | + |
| 40 | + @on_timer(END_OF_WINDOW_TIMER) |
| 41 | + def end_of_window(self, values_state=DoFn.StateParam(VALUES_STATE)): |
| 42 | + logging.info('start end_of_window.') |
| 43 | + |
| 44 | + read_count = 0 |
| 45 | + read_bytes = 0 |
| 46 | + values = values_state.read() |
| 47 | + for value in values: |
| 48 | + read_count += 1 |
| 49 | + read_bytes += len(value) |
| 50 | + |
| 51 | + logging.info('read_count: %s, read_bytes: %s', read_count, read_bytes) |
| 52 | + logging.info('end end_of_window.') |
| 53 | + |
| 54 | + |
| 55 | +def main(): |
| 56 | + options = PipelineOptions() |
| 57 | + options.view_as(SetupOptions).save_main_session = True |
| 58 | + |
| 59 | + p = Pipeline(options=options) |
| 60 | + (p |
| 61 | + | Create(list(range(NUM_SHARDS))) |
| 62 | + | FlatMap(lambda _: |
| 63 | + (bytes(ELEMENT_BYTES) for _ in range(NUM_ELEMENTS_PER_SHARD))) |
| 64 | + | WithKeys('') |
| 65 | + | ParDo(BigBagDoFn())) |
| 66 | + |
| 67 | + p.run() |
| 68 | + |
| 69 | + |
| 70 | +if __name__ == '__main__': |
| 71 | + logging.getLogger().setLevel(logging.INFO) |
| 72 | + main() |
0 commit comments