Skip to content

Commit 659cc4d

Browse files
authored
Add pipeline option to enforce gbek (apache#36321)
* Add pipeline option to enforce gbek * option description * lint * typing * Fix test mocks * Don't depend on secretmanager in test_gbk_actually_does_encryption * gemini feedback
1 parent c0774c9 commit 659cc4d

File tree

4 files changed

+159
-8
lines changed

4 files changed

+159
-8
lines changed

sdks/python/apache_beam/options/pipeline_options.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,6 +1716,21 @@ def _add_argparse_args(cls, parser):
17161716
help=(
17171717
'Docker registry url to use for tagging and pushing the prebuilt '
17181718
'sdk worker container image.'))
1719+
parser.add_argument(
1720+
'--gbek',
1721+
default=None,
1722+
help=(
1723+
'When set, will replace all GroupByKey transforms in the pipeline '
1724+
'with EncryptedGroupByKey transforms using the secret passed in '
1725+
'the option. Beam will infer the secret type and value based on '
1726+
'secret itself. This guarantees that any data at rest during the '
1727+
'GBK will be encrypted. Many runners only store data at rest when '
1728+
'performing a GBK, so this can be used to guarantee that data is '
1729+
'not unencrypted. Runners with this behavior include the '
1730+
'Dataflow, Flink, and Spark runners. The option should be '
1731+
'structured like: '
1732+
'--gbek=type:<secret_type>;<secret_param>:<value>, for example '
1733+
'--gbek=type:GcpSecret;version_name:my_secret/versions/latest'))
17191734
parser.add_argument(
17201735
'--user_agent',
17211736
default=None,

sdks/python/apache_beam/transforms/core.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from apache_beam.coders import typecoders
4040
from apache_beam.internal import pickler
4141
from apache_beam.internal import util
42+
from apache_beam.options.pipeline_options import SetupOptions
4243
from apache_beam.options.pipeline_options import TypeOptions
4344
from apache_beam.portability import common_urns
4445
from apache_beam.portability import python_urns
@@ -3324,6 +3325,10 @@ class GroupByKey(PTransform):
33243325
33253326
The implementation here is used only when run on the local direct runner.
33263327
"""
3328+
def __init__(self, label=None):
  """Initializes the transform and its GroupByEncryptedKey bookkeeping.

  Args:
    label: Optional transform label, forwarded to the base transform's
      constructor.
  """
  super().__init__(label)
  # Set to True by expand() when the gbek pipeline option causes this
  # transform to expand into a GroupByEncryptedKey; to_runner_api_parameter()
  # reads it to decide which URN to report.
  self._replaced_by_gbek = False
  # Set to True by GroupByEncryptedKey on the GroupByKey it applies
  # internally, so expand() does not replace that inner transform again.
  self._inside_gbek = False
3331+
33273332
class ReifyWindows(DoFn):
33283333
def process(
33293334
self, element, window=DoFn.WindowParam, timestamp=DoFn.TimestampParam):
@@ -3354,6 +3359,16 @@ def get_windowing(self, inputs):
33543359
environment_id=windowing.environment_id)
33553360

33563361
def expand(self, pcoll):
3362+
replace_with_gbek_secret = (
3363+
pcoll.pipeline._options.view_as(SetupOptions).gbek)
3364+
if replace_with_gbek_secret is not None and not self._inside_gbek:
3365+
self._replaced_by_gbek = True
3366+
from apache_beam.transforms.util import GroupByEncryptedKey
3367+
from apache_beam.transforms.util import Secret
3368+
3369+
secret = Secret.parse_secret_option(replace_with_gbek_secret)
3370+
return (pcoll | "Group by encrypted key" >> GroupByEncryptedKey(secret))
3371+
33573372
from apache_beam.transforms.trigger import DataLossReason
33583373
from apache_beam.transforms.trigger import DefaultTrigger
33593374
windowing = pcoll.windowing
@@ -3400,7 +3415,11 @@ def infer_output_type(self, input_type):
34003415
return typehints.KV[key_type, typehints.Iterable[value_type]]
34013416

34023417
def to_runner_api_parameter(self, unused_context):
3403-
# type: (PipelineContext) -> typing.Tuple[str, None]
3418+
# type: (PipelineContext) -> tuple[str, typing.Optional[typing.Union[message.Message, bytes, str]]]
3419+
# if we're containing a GroupByEncryptedKey, don't allow runners to
3420+
# recognize this transform as a GBK so that it doesn't get replaced.
3421+
if self._replaced_by_gbek:
3422+
return super().to_runner_api_parameter(unused_context)
34043423
return common_urns.primitives.GROUP_BY_KEY.urn, None
34053424

34063425
@staticmethod

sdks/python/apache_beam/transforms/util.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,44 @@ def generate_secret_bytes() -> bytes:
341341
"""Generates a new secret key."""
342342
return Fernet.generate_key()
343343

344+
@staticmethod
def parse_secret_option(secret) -> 'Secret':
  """Parses a secret string and returns the appropriate secret type.

  The secret string should be formatted like:
  'type:<secret_type>;<secret_param>:<value>'

  For example, 'type:GcpSecret;version_name:my_secret/versions/latest'
  would return a GcpSecret initialized with 'my_secret/versions/latest'.

  Args:
    secret: The option string to parse. Parameters are separated by ';' and
      each parameter is a '<name>:<value>' pair. Only the first ':' of a
      pair is treated as the separator, so values may contain ':'. Empty
      segments (for example from a trailing ';') are ignored.

  Returns:
    An instance of the secret class selected by the 'type' parameter,
    constructed with the remaining parameters as keyword arguments.

  Raises:
    ValueError: If a segment is not a '<name>:<value>' pair, if no 'type'
      parameter is present, if the type is not supported, or if a parameter
      is not supported by the selected secret type.
  """
  param_map = {}
  for param in secret.split(';'):
    if not param:
      # Tolerate empty segments such as the one produced by a trailing ';'.
      continue
    # Split on the first ':' only so the value is kept intact even when it
    # contains further ':' characters.
    name, separator, value = param.partition(':')
    if not separator:
      raise ValueError(
          f'Invalid secret parameter {param!r}, expected <name>:<value>')
    param_map[name] = value

  if 'type' not in param_map:
    raise ValueError('Secret string must contain a valid type parameter')

  # The type is matched case-insensitively and is not passed to the
  # constructor.
  secret_type = param_map.pop('type').lower()
  if secret_type == 'gcpsecret':
    secret_class = GcpSecret
    secret_params = ['version_name']
  else:
    raise ValueError(
        f'Invalid secret type {secret_type}, currently only '
        'GcpSecret is supported')

  for param_name in param_map:
    if param_name not in secret_params:
      raise ValueError(
          f'Invalid secret parameter {param_name}, '
          f'{secret_type} only supports the following '
          f'parameters: {secret_params}')
  # A missing required parameter is reported by the constructor itself.
  return secret_class(**param_map)
381+
344382

345383
class GcpSecret(Secret):
346384
"""A secret manager implementation that retrieves secrets from Google Cloud
@@ -367,7 +405,12 @@ def get_secret_bytes(self) -> bytes:
367405
secret = response.payload.data
368406
return secret
369407
except Exception as e:
370-
raise RuntimeError(f'Failed to retrieve secret bytes with excetion {e}')
408+
raise RuntimeError(
409+
'Failed to retrieve secret bytes for secret '
410+
f'{self._version_name} with exception {e}')
411+
412+
def __eq__(self, secret):
  """Returns True when `secret` carries the same version name.

  An object without a `_version_name` attribute is compared as if its
  version name were None.
  """
  return self._version_name == getattr(secret, '_version_name', None)

def __hash__(self):
  """Hashes on the version name so that equal secrets hash equally.

  Defining __eq__ without __hash__ would implicitly set __hash__ to None
  and make instances unhashable.
  """
  return hash(self._version_name)
371414

372415

373416
class _EncryptMessage(DoFn):
@@ -499,7 +542,9 @@ def __init__(self, hmac_key: Secret):
499542
self._hmac_key = hmac_key
500543

501544
def expand(self, pcoll):
502-
kv_type_hint = pcoll.element_type
545+
key_type, value_type = (typehints.typehints.coerce_to_kv_type(
546+
pcoll.element_type).tuple_types)
547+
kv_type_hint = typehints.KV[key_type, value_type]
503548
if kv_type_hint and kv_type_hint != typehints.Any:
504549
coder = coders.registry.get_coder(kv_type_hint).as_deterministic_coder(
505550
f'GroupByEncryptedKey {self.label}'
@@ -518,10 +563,13 @@ def expand(self, pcoll):
518563
key_coder = coders.registry.get_coder(typehints.Any)
519564
value_coder = key_coder
520565

566+
gbk = beam.GroupByKey()
567+
gbk._inside_gbek = True
568+
521569
return (
522570
pcoll
523571
| beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder))
524-
| beam.GroupByKey()
572+
| gbk
525573
| beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder)))
526574

527575

sdks/python/apache_beam/transforms/util_test.py

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from apache_beam.coders import coders
5151
from apache_beam.metrics import MetricsFilter
5252
from apache_beam.options.pipeline_options import PipelineOptions
53+
from apache_beam.options.pipeline_options import SetupOptions
5354
from apache_beam.options.pipeline_options import StandardOptions
5455
from apache_beam.options.pipeline_options import TypeOptions
5556
from apache_beam.portability import common_urns
@@ -252,7 +253,7 @@ def test_co_group_by_key_on_unpickled(self):
252253

253254

254255
class FakeSecret(beam.Secret):
255-
def __init__(self, should_throw=False):
256+
def __init__(self, version_name=None, should_throw=False):
256257
self._secret = b'aKwI2PmqYFt2p5tNKCyBS5qYmHhHsGZcyZrnZQiQ-uE='
257258
self._should_throw = should_throw
258259

@@ -273,6 +274,12 @@ def __init__(self, hmac_key_secret, key_coder, value_coder):
273274
super().__init__(hmac_key_secret, key_coder, value_coder)
274275

275276
def process(self, element):
277+
final_elements = list(super().process(element))
278+
# Check if we're looking at the actual elements being encoded/decoded
279+
# There is also a gbk in assert_that, which uses None as the key type.
280+
final_element_keys = [e for e in final_elements if e[0] in ['a', 'b', 'c']]
281+
if len(final_element_keys) == 0:
282+
return final_elements
276283
hmac_key, actual_elements = element
277284
if hmac_key not in self.known_hmacs:
278285
raise ValueError(f'GBK produced unencrypted value {hmac_key}')
@@ -286,7 +293,38 @@ def process(self, element):
286293
except InvalidToken:
287294
raise ValueError(f'GBK produced unencrypted value {e[1]}')
288295

289-
return super().process(element)
296+
return final_elements
297+
298+
299+
class SecretTest(unittest.TestCase):
300+
@parameterized.expand([
301+
param(
302+
secret_string='type:GcpSecret;version_name:my_secret/versions/latest',
303+
secret=GcpSecret('my_secret/versions/latest')),
304+
param(
305+
secret_string='type:GcpSecret;version_name:foo',
306+
secret=GcpSecret('foo')),
307+
param(
308+
secret_string='type:gcpsecreT;version_name:my_secret/versions/latest',
309+
secret=GcpSecret('my_secret/versions/latest')),
310+
])
311+
def test_secret_manager_parses_correctly(self, secret_string, secret):
312+
self.assertEqual(secret, Secret.parse_secret_option(secret_string))
313+
314+
@parameterized.expand([
315+
param(
316+
secret_string='version_name:foo',
317+
exception_str='must contain a valid type parameter'),
318+
param(
319+
secret_string='type:gcpsecreT',
320+
exception_str='missing 1 required positional argument'),
321+
param(
322+
secret_string='type:gcpsecreT;version_name:foo;extra:val',
323+
exception_str='Invalid secret parameter extra'),
324+
])
325+
def test_secret_manager_throws_on_invalid(self, secret_string, exception_str):
326+
with self.assertRaisesRegex(Exception, exception_str):
327+
Secret.parse_secret_option(secret_string)
290328

291329

292330
class GroupByEncryptedKeyTest(unittest.TestCase):
@@ -318,7 +356,9 @@ def setUp(self):
318356
'data': Secret.generate_secret_bytes()
319357
}
320358
})
321-
self.gcp_secret = GcpSecret(f'{self.secret_path}/versions/latest')
359+
version_name = f'{self.secret_path}/versions/latest'
360+
self.gcp_secret = GcpSecret(version_name)
361+
self.secret_option = f'type:GcpSecret;version_name:{version_name}'
322362

323363
def tearDown(self):
324364
if secretmanager is not None:
@@ -334,6 +374,20 @@ def test_gbek_fake_secret_manager_roundtrips(self):
334374
assert_that(
335375
result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
336376

377+
@unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
378+
def test_gbk_with_gbek_option_fake_secret_manager_roundtrips(self):
379+
options = PipelineOptions()
380+
options.view_as(SetupOptions).gbek = self.secret_option
381+
382+
with beam.Pipeline(options=options) as pipeline:
383+
pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
384+
('b', 3), ('c', 4)])
385+
result = (pcoll_1) | beam.GroupByKey()
386+
sorted_result = result | beam.Map(lambda x: (x[0], sorted(x[1])))
387+
assert_that(
388+
sorted_result,
389+
equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
390+
337391
@mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt)
338392
def test_gbek_fake_secret_manager_actually_does_encryption(self):
339393
fakeSecret = FakeSecret()
@@ -345,8 +399,23 @@ def test_gbek_fake_secret_manager_actually_does_encryption(self):
345399
assert_that(
346400
result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
347401

402+
@mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt)
403+
@mock.patch('apache_beam.transforms.util.GcpSecret', FakeSecret)
404+
def test_gbk_actually_does_encryption(self):
405+
options = PipelineOptions()
406+
# Version of GcpSecret doesn't matter since it is replaced by FakeSecret
407+
options.view_as(SetupOptions).gbek = 'type:GcpSecret;version_name:Foo'
408+
409+
with TestPipeline('FnApiRunner', options=options) as pipeline:
410+
pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
411+
('b', 3), ('c', 4)],
412+
reshuffle=False)
413+
result = pcoll_1 | beam.GroupByKey()
414+
assert_that(
415+
result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
416+
348417
def test_gbek_fake_secret_manager_throws(self):
349-
fakeSecret = FakeSecret(True)
418+
fakeSecret = FakeSecret(None, True)
350419

351420
with self.assertRaisesRegex(RuntimeError, r'Exception retrieving secret'):
352421
with TestPipeline() as pipeline:

0 commit comments

Comments
 (0)