Skip to content

Commit 9ed29f6

Browse files
committed
format file
1 parent 0eed57b commit 9ed29f6

File tree

2 files changed

+108
-32
lines changed

2 files changed

+108
-32
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,4 @@ fabric.properties
116116
.ionide
117117

118118
# End of https://www.toptal.com/developers/gitignore/api/pycharm+all,visualstudiocode
119+
.venv/

bin/tflocal

Lines changed: 107 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,34 @@ DRY_RUN = str(os.environ.get("DRY_RUN")).strip().lower() in ["1", "true"]
3131
DEFAULT_REGION = "us-east-1"
3232
DEFAULT_ACCESS_KEY = "test"
3333
AWS_ENDPOINT_URL = os.environ.get("AWS_ENDPOINT_URL")
34-
CUSTOMIZE_ACCESS_KEY = str(os.environ.get("CUSTOMIZE_ACCESS_KEY")).strip().lower() in ["1", "true"]
34+
CUSTOMIZE_ACCESS_KEY = str(os.environ.get("CUSTOMIZE_ACCESS_KEY")).strip().lower() in [
35+
"1",
36+
"true",
37+
]
3538
LOCALHOST_HOSTNAME = "localhost.localstack.cloud"
3639
S3_HOSTNAME = os.environ.get("S3_HOSTNAME") or f"s3.{LOCALHOST_HOSTNAME}"
3740
USE_EXEC = str(os.environ.get("USE_EXEC")).strip().lower() in ["1", "true"]
3841
TF_CMD = os.environ.get("TF_CMD") or "terraform"
39-
ADDITIONAL_TF_OVERRIDE_LOCATIONS = (x for x in os.environ.get("ADDITIONAL_TF_OVERRIDE_LOCATIONS", default="").split(sep=",") if x and x != "")
40-
TF_UNPROXIED_CMDS = os.environ.get("TF_UNPROXIED_CMDS").split(sep=",") if os.environ.get("TF_UNPROXIED_CMDS") else ("fmt", "validate", "version")
41-
LS_PROVIDERS_FILE = os.environ.get("LS_PROVIDERS_FILE") or "localstack_providers_override.tf"
42-
LOCALSTACK_HOSTNAME = urlparse(AWS_ENDPOINT_URL).hostname or os.environ.get("LOCALSTACK_HOSTNAME") or "localhost"
42+
ADDITIONAL_TF_OVERRIDE_LOCATIONS = (
43+
x
44+
for x in os.environ.get("ADDITIONAL_TF_OVERRIDE_LOCATIONS", default="").split(
45+
sep=","
46+
)
47+
if x and x != ""
48+
)
49+
TF_UNPROXIED_CMDS = (
50+
os.environ.get("TF_UNPROXIED_CMDS").split(sep=",")
51+
if os.environ.get("TF_UNPROXIED_CMDS")
52+
else ("fmt", "validate", "version")
53+
)
54+
LS_PROVIDERS_FILE = (
55+
os.environ.get("LS_PROVIDERS_FILE") or "localstack_providers_override.tf"
56+
)
57+
LOCALSTACK_HOSTNAME = (
58+
urlparse(AWS_ENDPOINT_URL).hostname
59+
or os.environ.get("LOCALSTACK_HOSTNAME")
60+
or "localhost"
61+
)
4362
EDGE_PORT = int(urlparse(AWS_ENDPOINT_URL).port or os.environ.get("EDGE_PORT") or 4566)
4463
TF_VERSION: Optional[version.Version] = None
4564
TF_PROVIDER_CONFIG = """
@@ -134,11 +153,19 @@ SERVICE_REPLACEMENTS = {
134153
# CONFIG GENERATION UTILS
135154
# ---
136155

156+
137157
def create_provider_config_file(provider_file_path: str, provider_aliases=None) -> None:
138158
provider_aliases = provider_aliases or []
139159

140160
# Force service alias replacements
141-
SERVICE_REPLACEMENTS.update({alias: alias_pairs[0] for alias_pairs in SERVICE_ALIASES for alias in alias_pairs if alias != alias_pairs[0]})
161+
SERVICE_REPLACEMENTS.update(
162+
{
163+
alias: alias_pairs[0]
164+
for alias_pairs in SERVICE_ALIASES
165+
for alias in alias_pairs
166+
if alias != alias_pairs[0]
167+
}
168+
)
142169

143170
# create list of service names
144171
services = list(config.get_service_ports())
@@ -163,9 +190,11 @@ def create_provider_config_file(provider_file_path: str, provider_aliases=None)
163190
for provider in provider_aliases:
164191
provider_config = TF_PROVIDER_CONFIG.replace(
165192
"<access_key>",
166-
get_access_key(provider) if CUSTOMIZE_ACCESS_KEY else DEFAULT_ACCESS_KEY
193+
get_access_key(provider) if CUSTOMIZE_ACCESS_KEY else DEFAULT_ACCESS_KEY,
194+
)
195+
endpoints = "\n".join(
196+
[f' {s} = "{get_service_endpoint(s)}"' for s in services]
167197
)
168-
endpoints = "\n".join([f' {s} = "{get_service_endpoint(s)}"' for s in services])
169198
provider_config = provider_config.replace("<endpoints>", endpoints)
170199
additional_configs = []
171200
if use_s3_path_style():
@@ -179,7 +208,9 @@ def create_provider_config_file(provider_file_path: str, provider_aliases=None)
179208
if isinstance(region, list):
180209
region = region[0]
181210
additional_configs += [f'region = "{region}"']
182-
provider_config = provider_config.replace("<configs>", "\n".join(additional_configs))
211+
provider_config = provider_config.replace(
212+
"<configs>", "\n".join(additional_configs)
213+
)
183214
provider_configs.append(provider_config)
184215

185216
# construct final config file content
@@ -197,6 +228,7 @@ def write_provider_config_file(providers_file, tf_config):
197228
with open(providers_file, mode="w") as fp:
198229
fp.write(tf_config)
199230

231+
200232
def get_default_provider_folder_path() -> str:
201233
"""Determine the folder under which the providers override file should be stored"""
202234
chdir = [arg for arg in sys.argv if arg.startswith("-chdir=")]
@@ -206,6 +238,7 @@ def get_default_provider_folder_path() -> str:
206238

207239
return os.path.abspath(base_dir)
208240

241+
209242
def get_providers_file_path(base_dir) -> str:
210243
"""Retrieve the path under which the providers override file should be stored"""
211244
return os.path.join(base_dir, LS_PROVIDERS_FILE)
@@ -219,7 +252,11 @@ def determine_provider_aliases() -> list:
219252
for _file, obj in tf_files.items():
220253
try:
221254
providers = ensure_list(obj.get("provider", []))
222-
aws_providers = [prov["aws"] for prov in providers if prov.get("aws") and prov.get("aws").get("alias") not in skipped]
255+
aws_providers = [
256+
prov["aws"]
257+
for prov in providers
258+
if prov.get("aws") and prov.get("aws").get("alias") not in skipped
259+
]
223260
result.extend(aws_providers)
224261
except Exception as e:
225262
print(f"Warning: Unable to extract providers from {_file}:", e)
@@ -260,7 +297,6 @@ def generate_s3_backend_config() -> str:
260297
"skip_credentials_validation": True,
261298
"skip_metadata_api_check": True,
262299
"secret_key": "test",
263-
264300
"endpoints": {
265301
"s3": get_service_endpoint("s3"),
266302
"iam": get_service_endpoint("iam"),
@@ -271,23 +307,37 @@ def generate_s3_backend_config() -> str:
271307
}
272308
# Merge in legacy endpoint configs if not existing already
273309
if is_tf_legacy and backend_config.get("endpoints"):
274-
print("Warning: Unsupported backend option(s) detected (`endpoints`). Please make sure you always use the corresponding options to your Terraform version.")
310+
print(
311+
"Warning: Unsupported backend option(s) detected (`endpoints`). Please make sure you always use the corresponding options to your Terraform version."
312+
)
275313
exit(1)
276314
for legacy_endpoint, endpoint in legacy_endpoint_mappings.items():
277-
if legacy_endpoint in backend_config and backend_config.get("endpoints") and endpoint in backend_config["endpoints"]:
315+
if (
316+
legacy_endpoint in backend_config
317+
and backend_config.get("endpoints")
318+
and endpoint in backend_config["endpoints"]
319+
):
278320
del backend_config[legacy_endpoint]
279321
continue
280-
if legacy_endpoint in backend_config and (not backend_config.get("endpoints") or endpoint not in backend_config["endpoints"]):
322+
if legacy_endpoint in backend_config and (
323+
not backend_config.get("endpoints")
324+
or endpoint not in backend_config["endpoints"]
325+
):
281326
if not backend_config.get("endpoints"):
282327
backend_config["endpoints"] = {}
283-
backend_config["endpoints"].update({endpoint: backend_config[legacy_endpoint]})
328+
backend_config["endpoints"].update(
329+
{endpoint: backend_config[legacy_endpoint]}
330+
)
284331
del backend_config[legacy_endpoint]
285332
# Add any missing default endpoints
286333
if backend_config.get("endpoints"):
287334
backend_config["endpoints"] = {
288335
k: backend_config["endpoints"].get(k) or v
289-
for k, v in configs["endpoints"].items()}
290-
backend_config["access_key"] = get_access_key(backend_config) if CUSTOMIZE_ACCESS_KEY else DEFAULT_ACCESS_KEY
336+
for k, v in configs["endpoints"].items()
337+
}
338+
backend_config["access_key"] = (
339+
get_access_key(backend_config) if CUSTOMIZE_ACCESS_KEY else DEFAULT_ACCESS_KEY
340+
)
291341
configs.update(backend_config)
292342
if not DRY_RUN:
293343
get_or_create_bucket(configs["bucket"])
@@ -300,22 +350,27 @@ def generate_s3_backend_config() -> str:
300350
elif isinstance(value, dict):
301351
if key == "endpoints" and is_tf_legacy:
302352
for legacy_endpoint, endpoint in legacy_endpoint_mappings.items():
303-
config_options += f'\n {legacy_endpoint} = "{configs[key][endpoint]}"'
353+
config_options += (
354+
f'\n {legacy_endpoint} = "{configs[key][endpoint]}"'
355+
)
304356
continue
305357
else:
306358
value = textwrap.indent(
307-
text=f"{key} = {{\n" + "\n".join([f' {k} = "{v}"' for k, v in value.items()]) + "\n}",
308-
prefix=" " * 4)
359+
text=f"{key} = {{\n"
360+
+ "\n".join([f' {k} = "{v}"' for k, v in value.items()])
361+
+ "\n}",
362+
prefix=" " * 4,
363+
)
309364
config_options += f"\n{value}"
310365
continue
311366
elif isinstance(value, list):
312367
# TODO this will break if it's a list of dicts or other complex object
313368
# this serialization logic should probably be moved to a separate recursive function
314369
as_string = [f'"{item}"' for item in value]
315-
value = f'[{", ".join(as_string)}]'
370+
value = f"[{', '.join(as_string)}]"
316371
else:
317372
value = f'"{str(value)}"'
318-
config_options += f'\n {key} = {value}'
373+
config_options += f"\n {key} = {value}"
319374
result = result.replace("<configs>", config_options)
320375
return result
321376

@@ -339,6 +394,7 @@ def check_override_file(providers_file: str) -> None:
339394
# AWS CLIENT UTILS
340395
# ---
341396

397+
342398
def use_s3_path_style() -> bool:
343399
"""
344400
Whether to use S3 path addressing (depending on the configured S3 endpoint)
@@ -363,6 +419,7 @@ def get_region() -> str:
363419
# Note that boto3 is currently not included in the dependencies, to
364420
# keep the library lightweight.
365421
import boto3
422+
366423
region = boto3.session.Session().region_name
367424
except Exception:
368425
pass
@@ -371,7 +428,9 @@ def get_region() -> str:
371428

372429

373430
def get_access_key(provider: dict) -> str:
374-
access_key = str(os.environ.get("AWS_ACCESS_KEY_ID") or provider.get("access_key", "")).strip()
431+
access_key = str(
432+
os.environ.get("AWS_ACCESS_KEY_ID") or provider.get("access_key", "")
433+
).strip()
375434
if access_key and access_key != DEFAULT_ACCESS_KEY:
376435
# Change live access key to mocked one
377436
return deactivate_access_key(access_key)
@@ -380,6 +439,7 @@ def get_access_key(provider: dict) -> str:
380439
# Note that boto3 is currently not included in the dependencies, to
381440
# keep the library lightweight.
382441
import boto3
442+
383443
access_key = boto3.session.Session().get_credentials().access_key
384444
except Exception:
385445
pass
@@ -389,7 +449,7 @@ def get_access_key(provider: dict) -> str:
389449

390450
def deactivate_access_key(access_key: str) -> str:
391451
"""Safe guarding user from accidental live credential usage by deactivating access key IDs.
392-
See more: https://docs.localstack.cloud/references/credentials/"""
452+
See more: https://docs.localstack.cloud/references/credentials/"""
393453
return "L" + access_key[1:] if access_key[0] == "A" else access_key
394454

395455

@@ -415,10 +475,14 @@ def get_service_endpoint(service: str) -> str:
415475

416476
def connect_to_service(service: str, region: str = None):
417477
import boto3
478+
418479
region = region or get_region()
419480
return boto3.client(
420-
service, endpoint_url=get_service_endpoint(service), region_name=region,
421-
aws_access_key_id="test", aws_secret_access_key="test",
481+
service,
482+
endpoint_url=get_service_endpoint(service),
483+
region_name=region,
484+
aws_access_key_id="test",
485+
aws_secret_access_key="test",
422486
)
423487

424488

@@ -442,9 +506,10 @@ def get_or_create_ddb_table(table_name: str, region: str = None):
442506
return ddb_client.describe_table(TableName=table_name)
443507
except Exception:
444508
return ddb_client.create_table(
445-
TableName=table_name, BillingMode="PAY_PER_REQUEST",
509+
TableName=table_name,
510+
BillingMode="PAY_PER_REQUEST",
446511
KeySchema=[{"AttributeName": "LockID", "KeyType": "HASH"}],
447-
AttributeDefinitions=[{"AttributeName": "LockID", "AttributeType": "S"}]
512+
AttributeDefinitions=[{"AttributeName": "LockID", "AttributeType": "S"}],
448513
)
449514

450515

@@ -471,13 +536,15 @@ def parse_tf_files() -> dict:
471536

472537
def get_tf_version(env):
473538
global TF_VERSION
474-
output = subprocess.run([f"{TF_CMD}", "version", "-json"], env=env, check=True, capture_output=True).stdout.decode("utf-8")
539+
output = subprocess.run(
540+
[f"{TF_CMD}", "version", "-json"], env=env, check=True, capture_output=True
541+
).stdout.decode("utf-8")
475542
TF_VERSION = version.parse(json.loads(output)["terraform_version"])
476543

477544

478545
def run_tf_exec(cmd, env):
479546
"""Run terraform using os.exec - can be useful as it does not require any I/O
480-
handling for stdin/out/err. Does *not* allow us to perform any cleanup logic."""
547+
handling for stdin/out/err. Does *not* allow us to perform any cleanup logic."""
481548
os.execvpe(cmd[0], cmd, env=env)
482549

483550

@@ -487,19 +554,25 @@ def run_tf_subprocess(cmd, env):
487554

488555
# register signal handlers
489556
import signal
557+
490558
signal.signal(signal.SIGINT, signal_handler)
491559

492560
PROCESS = subprocess.Popen(
493-
cmd, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stdout)
561+
cmd, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stdout
562+
)
494563
PROCESS.communicate()
495564
sys.exit(PROCESS.returncode)
496565

566+
497567
def cleanup_override_files(override_files: Iterable[str]):
498568
for file_path in override_files:
499569
try:
500570
os.remove(file_path)
501571
except Exception:
502-
print(f"Count not clean up '{file_path}'. This is not normally a problem but you can delete this file manually.")
572+
print(
573+
f"Count not clean up '{file_path}'. This is not normally a problem but you can delete this file manually."
574+
)
575+
503576

504577
def get_folder_paths_that_require_an_override_file() -> Iterable[str]:
505578
if not is_override_needed(sys.argv[1:]):
@@ -514,6 +587,7 @@ def get_folder_paths_that_require_an_override_file() -> Iterable[str]:
514587
# UTIL FUNCTIONS
515588
# ---
516589

590+
517591
def signal_handler(sig, frame):
518592
PROCESS.send_signal(sig)
519593

@@ -534,6 +608,7 @@ def to_str(obj) -> bytes:
534608
# MAIN ENTRYPOINT
535609
# ---
536610

611+
537612
def main():
538613
env = dict(os.environ)
539614
cmd = [TF_CMD] + sys.argv[1:]

0 commit comments

Comments
 (0)