Skip to content

Commit 7974241

Browse files
authored
gcs to jdbc changes
1 parent 887b07d commit 7974241

File tree

1 file changed

+38
-2
lines changed

1 file changed

+38
-2
lines changed

python/dataproc_templates/gcs/gcs_to_jdbc.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,25 +105,61 @@ def parse_args(args: Optional[Sequence[str]] = None) -> Dict[str, Any]:
105105
required=False,
106106
help='The maximum number of partitions to be used for parallelism in table writing'
107107
)
108+
parser.add_argument(
109+
'--gcs.to.jdbc.password.secret.id',
110+
dest='gcs.to.jdbc.password.secret.id',
111+
required=False,
112+
help='Secret Manager secret ID for the JDBC password. Must be in the format projects/PROJECT_ID/secrets/SECRET_ID.'
113+
)
114+
parser.add_argument(
115+
'--gcs.to.jdbc.password.secret.version',
116+
dest='gcs.to.jdbc.password.secret.version',
117+
required=False,
118+
default='latest',
119+
help='Secret Manager secret version for the JDBC password. Defaults to "latest".'
120+
)
108121

109122
known_args: argparse.Namespace
110123
known_args, _ = parser.parse_known_args(args)
111124

112125
return vars(known_args)
113126

127+
def _secret_version_name(project_id, secret_id, version_id):
    """Build the Secret Manager resource name of a secret version.

    Accepts either a short secret ID (``my-secret``) or a fully-qualified
    secret resource name (``projects/P/secrets/S``), optionally already
    carrying a ``/versions/V`` suffix.

    Raises:
        ValueError: if ``secret_id`` is empty, or if a short secret ID is
            given without a ``project_id`` to qualify it.
    """
    if not secret_id:
        raise ValueError("secret_id must be a non-empty string")

    # A missing/empty version falls back to the most recent one.
    version = version_id or "latest"
    secret_path = str(secret_id).strip("/")

    if secret_path.startswith("projects/"):
        # Already fully qualified: do not prepend the project prefix again,
        # otherwise the name becomes "projects/P/secrets/projects/P/secrets/S".
        if "/versions/" in secret_path:
            return secret_path
        return f"{secret_path}/versions/{version}"

    if not project_id:
        raise ValueError(
            "project_id is required when secret_id is not a fully-qualified "
            "resource name (projects/PROJECT_ID/secrets/SECRET_ID)"
        )
    return f"projects/{project_id}/secrets/{secret_path}/versions/{version}"


def get_secret(project_id, secret_id, version_id="latest"):
    """Retrieve a secret payload from Secret Manager as a UTF-8 string.

    Args:
        project_id: Project that owns the secret. Only used when
            ``secret_id`` is a short ID rather than a full resource name.
        secret_id: Either the short secret ID or the fully-qualified name
            ``projects/PROJECT_ID/secrets/SECRET_ID`` (the format the
            ``--gcs.to.jdbc.password.secret.id`` option documents).
        version_id: Secret version to read; defaults to ``"latest"``.

    Returns:
        The secret payload decoded as UTF-8.

    Raises:
        ValueError: if the arguments cannot form a valid resource name.
    """
    # Validate and build the name first so bad input fails before any
    # client is created.
    name = _secret_version_name(project_id, secret_id, version_id)

    # NOTE(review): `secretmanager` is not imported anywhere in the visible
    # change; confirm the module imports it (presumably
    # `from google.cloud import secretmanager`), otherwise this raises
    # NameError at call time.
    client = secretmanager.SecretManagerServiceClient()
    response = client.access_secret_version(name=name)
    return response.payload.data.decode("UTF-8")
133+
114134
def run(self, spark: SparkSession, args: Dict[str, Any]) -> None:
115-
logger: Logger = self.get_logger(spark=spark)
135+
logger: Logger = self.get_logger(spark=spark)
116136

117137
# Arguments
118138
input_location: str = args[constants.GCS_JDBC_INPUT_LOCATION]
119139
input_format: str = args[constants.GCS_JDBC_INPUT_FORMAT]
120-
jdbc_url: str = args[constants.GCS_JDBC_OUTPUT_URL]
140+
#jdbc_url: str = args[constants.GCS_JDBC_OUTPUT_URL]
121141
jdbc_table: str = args[constants.GCS_JDBC_OUTPUT_TABLE]
122142
output_mode: str = args[constants.GCS_JDBC_OUTPUT_MODE]
123143
output_driver: str = args[constants.GCS_JDBC_OUTPUT_DRIVER]
124144
batch_size: int = args[constants.GCS_JDBC_BATCH_SIZE]
125145
jdbc_numpartitions: int = args[constants.GCS_JDBC_NUMPARTITIONS]
126146

147+
if hasattr(properties, 'gcs.to.jdbc.password.secret.id') and properties['gcs.to.jdbc.password.secret.id']:
148+
password = get_secret(
149+
properties['gcp.project.id'],
150+
properties['gcs.to.jdbc.password.secret.id'],
151+
properties['gcs.to.jdbc.password.secret.version']
152+
)
153+
154+
jdbc_url: str = args[constants.GCS_JDBC_OUTPUT_URL]
155+
if 'password' not in jdbc_url:
156+
# This assumes the password is provided via a property in the URL like ';password=...'
157+
# A more robust solution is needed depending on the database driver
158+
if '?' in jdbc_url:
159+
jdbc_url += f"&password={password}"
160+
else:
161+
jdbc_url += f"?password={password}"
162+
127163
ignore_keys = {constants.GCS_JDBC_OUTPUT_URL}
128164
filtered_args = {key:val for key,val in args.items() if key not in ignore_keys}
129165
logger.info(

0 commit comments

Comments
 (0)