@@ -105,25 +105,61 @@ def parse_args(args: Optional[Sequence[str]] = None) -> Dict[str, Any]:
105105 required = False ,
106106 help = 'The maximum number of partitions to be used for parallelism in table writing'
107107 )
108+ parser .add_argument (
109+ '--gcs.to.jdbc.password.secret.id' ,
110+ dest = 'gcs.to.jdbc.password.secret.id' ,
111+ required = False ,
112+ help = 'Secret Manager secret ID for the JDBC password. Must be in the format projects/PROJECT_ID/secrets/SECRET_ID.'
113+ )
114+ parser .add_argument (
115+ '--gcs.to.jdbc.password.secret.version' ,
116+ dest = 'gcs.to.jdbc.password.secret.version' ,
117+ required = False ,
118+ default = 'latest' ,
119+ help = 'Secret Manager secret version for the JDBC password. Defaults to "latest".'
120+ )
108121
109122 known_args : argparse .Namespace
110123 known_args , _ = parser .parse_known_args (args )
111124
112125 return vars (known_args )
113126
def get_secret(project_id, secret_id, version_id, client=None):
    """Retrieve a secret payload from Secret Manager as text.

    Args:
        project_id: Project that owns the secret. Only used when
            ``secret_id`` is a bare secret ID.
        secret_id: Either a bare secret ID (``my-secret``) or a full secret
            resource path (``projects/PROJECT_ID/secrets/SECRET_ID``), which
            is the format the ``--gcs.to.jdbc.password.secret.id`` help text
            asks for. A path that already ends in ``/versions/VERSION`` is
            used unchanged.
        version_id: Secret version to read (for example ``latest``). Ignored
            when ``secret_id`` already names a version.
        client: Optional object exposing ``access_secret_version(name=...)``.
            When omitted, a new ``SecretManagerServiceClient`` is created.

    Returns:
        The secret payload decoded as UTF-8.
    """
    secret_path = secret_id.strip("/")

    # Only prepend the project prefix for bare IDs. Prepending it to a full
    # resource path would yield
    # "projects/P/secrets/projects/P/secrets/S/...", which is not a valid name.
    if not secret_path.startswith("projects/"):
        secret_path = f"projects/{project_id}/secrets/{secret_path}"

    if "/versions/" in secret_path:
        name = secret_path
    else:
        name = f"{secret_path}/versions/{version_id}"

    if client is None:
        client = secretmanager.SecretManagerServiceClient()

    response = client.access_secret_version(name=name)
    return response.payload.data.decode("UTF-8")
133+
114134 def run (self , spark : SparkSession , args : Dict [str , Any ]) -> None :
115- logger : Logger = self .get_logger (spark = spark )
135+ logger : Logger = self .get_logger (spark = spark )
116136
117137 # Arguments
118138 input_location : str = args [constants .GCS_JDBC_INPUT_LOCATION ]
119139 input_format : str = args [constants .GCS_JDBC_INPUT_FORMAT ]
120- jdbc_url : str = args [constants .GCS_JDBC_OUTPUT_URL ]
140+ # jdbc_url: str = args[constants.GCS_JDBC_OUTPUT_URL]
121141 jdbc_table : str = args [constants .GCS_JDBC_OUTPUT_TABLE ]
122142 output_mode : str = args [constants .GCS_JDBC_OUTPUT_MODE ]
123143 output_driver : str = args [constants .GCS_JDBC_OUTPUT_DRIVER ]
124144 batch_size : int = args [constants .GCS_JDBC_BATCH_SIZE ]
125145 jdbc_numpartitions : int = args [constants .GCS_JDBC_NUMPARTITIONS ]
126146
147+ if hasattr (properties , 'gcs.to.jdbc.password.secret.id' ) and properties ['gcs.to.jdbc.password.secret.id' ]:
148+ password = get_secret (
149+ properties ['gcp.project.id' ],
150+ properties ['gcs.to.jdbc.password.secret.id' ],
151+ properties ['gcs.to.jdbc.password.secret.version' ]
152+ )
153+
154+ jdbc_url : str = args [constants .GCS_JDBC_OUTPUT_URL ]
155+ if 'password' not in jdbc_url :
156+ # This assumes the password is provided via a property in the URL like ';password=...'
157+ # A more robust solution is needed depending on the database driver
158+ if '?' in jdbc_url :
159+ jdbc_url += f"&password={ password } "
160+ else :
161+ jdbc_url += f"?password={ password } "
162+
127163 ignore_keys = {constants .GCS_JDBC_OUTPUT_URL }
128164 filtered_args = {key :val for key ,val in args .items () if key not in ignore_keys }
129165 logger .info (
0 commit comments