Skip to content

Commit 56b286c

Browse files
committed
Add tests.
1 parent 3422fba commit 56b286c

File tree

2 files changed

+45
-6
lines changed

2 files changed

+45
-6
lines changed

Diff for: sdks/python/apache_beam/io/gcp/bigquery_file_loads.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,15 @@
3737

3838
import apache_beam as beam
3939
from apache_beam import pvalue
40-
from apache_beam.transforms import util
4140
from apache_beam.io import filesystems as fs
4241
from apache_beam.io.gcp import bigquery_tools
4342
from apache_beam.io.gcp.bigquery_io_metadata import create_bigquery_io_metadata
4443
from apache_beam.metrics.metric import Lineage
4544
from apache_beam.options import value_provider as vp
4645
from apache_beam.options.pipeline_options import GoogleCloudOptions
4746
from apache_beam.transforms import trigger
47+
from apache_beam.transforms import util
4848
from apache_beam.transforms.display import DisplayDataItem
49-
from apache_beam.transforms.util import GroupIntoBatches
5049
from apache_beam.transforms.window import GlobalWindows
5150

5251
# Protect against environments where bigquery library is not available.
@@ -1063,7 +1062,7 @@ def _write_files_with_auto_sharding(
10631062
destination_data_kv_pc
10641063
|
10651064
'ToHashableTableRef' >> beam.Map(bigquery_tools.to_hashable_table_ref)
1066-
| 'WithAutoSharding' >> GroupIntoBatches.WithShardedKey(
1065+
| 'WithAutoSharding' >> util.GroupIntoBatches.WithShardedKey(
10671066
batch_size=_FILE_TRIGGERING_RECORD_COUNT,
10681067
max_buffering_duration_secs=_FILE_TRIGGERING_BATCHING_DURATION_SECS,
10691068
clock=clock)
@@ -1102,9 +1101,11 @@ def _load_data(
11021101
of the load jobs would fail but not others. If any of them fails, then
11031102
copy jobs are not triggered.
11041103
"""
1105-
# Ensure that TriggerLoadJob retry inputs are deterministic by breaking
1106-
# fusion for inputs.
1107-
if not util.is_compat_version_prior_to(p.options, "2.65.0"):
1104+
self.reshuffle_before_load = not util.is_compat_version_prior_to(
1105+
p.options, "2.65.0")
1106+
if self.reshuffle_before_load:
1107+
# Ensure that TriggerLoadJob retry inputs are deterministic by breaking
1108+
# fusion for inputs.
11081109
partitions_using_temp_tables = (
11091110
partitions_using_temp_tables
11101111
| "ReshuffleBeforeLoadWithTempTables" >> beam.Reshuffle())

Diff for: sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py

+38
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,44 @@ def test_records_traverse_transform_with_mocks(self):
478478

479479
assert_that(jobs, equal_to([job_reference]), label='CheckJobs')
480480

481+
@parameterized.expand([
482+
param(compat_version=None),
483+
param(compat_version="2.64.0"),
484+
])
485+
def test_reshuffle_before_load(self, compat_version):
486+
destination = 'project1:dataset1.table1'
487+
488+
job_reference = bigquery_api.JobReference()
489+
job_reference.projectId = 'project1'
490+
job_reference.jobId = 'job_name1'
491+
result_job = bigquery_api.Job()
492+
result_job.jobReference = job_reference
493+
494+
mock_job = mock.Mock()
495+
mock_job.status.state = 'DONE'
496+
mock_job.status.errorResult = None
497+
mock_job.jobReference = job_reference
498+
499+
bq_client = mock.Mock()
500+
bq_client.jobs.Get.return_value = mock_job
501+
502+
bq_client.jobs.Insert.return_value = result_job
503+
504+
transform = bqfl.BigQueryBatchFileLoads(
505+
destination,
506+
custom_gcs_temp_location=self._new_tempdir(),
507+
test_client=bq_client,
508+
validate=False,
509+
temp_file_format=bigquery_tools.FileFormat.JSON)
510+
511+
options = PipelineOptions(update_compatibility_version=compat_version)
512+
# Need to test this with the DirectRunner to avoid serializing mocks
513+
with TestPipeline('DirectRunner', options=options) as p:
514+
_ = p | beam.Create(_ELEMENTS) | transform
515+
516+
reshuffle_before_load = compat_version is None
517+
assert transform.reshuffle_before_load == reshuffle_before_load
518+
481519
def test_load_job_id_used(self):
482520
job_reference = bigquery_api.JobReference()
483521
job_reference.projectId = 'loadJobProject'

0 commit comments

Comments
 (0)