Skip to content

Commit 69a83d5

Browse files
committed
fix: better support for non-aws providers
We only have an import ID generator for the AWS provider, but for other providers we can still create the import blocks and link, on a best-effort basis, to the provider documentation for each specific resource, which includes instructions on how to construct the import ID.
1 parent b95d6ff commit 69a83d5

File tree

2 files changed

+143
-63
lines changed

2 files changed

+143
-63
lines changed

src/tfblocks/main.py

Lines changed: 70 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,7 @@ def get_aws_resource_import_id_generators() -> Dict[str, type]:
1717
}
1818

1919

20-
def file_exists(file_path: str) -> bool:
21-
"""Check if a file exists."""
22-
return os.path.exists(file_path)
23-
24-
25-
def extract_resource_addresses_from_content(content: str) -> List[str]:
20+
def extract_addresses_from_content(content: str) -> List[str]:
2621
"""Extract resource and module addresses from Terraform content."""
2722
addresses = []
2823

@@ -41,20 +36,19 @@ def extract_resource_addresses_from_content(content: str) -> List[str]:
4136
return addresses
4237

4338

44-
def extract_resource_addresses_from_file(file_path: str) -> List[str]:
39+
def extract_addresses_from_file(file_path: str) -> List[str]:
4540
"""Extract resource and module addresses from a Terraform file."""
4641
addresses = []
4742

48-
# First check if file exists
49-
if not file_exists(file_path):
43+
if not os.path.exists(file_path):
5044
print(f"Error: File {file_path} does not exist", file=sys.stderr)
5145
sys.exit(1)
5246

5347
try:
5448
with open(file_path, "r") as f:
5549
content = f.read()
5650

57-
addresses = extract_resource_addresses_from_content(content)
51+
addresses = extract_addresses_from_content(content)
5852

5953
except Exception as e:
6054
print(f"Warning: Could not process file {file_path}: {str(e)}", file=sys.stderr)
@@ -158,12 +152,12 @@ def matches_address_list(addr: str, addr_list: List[str]) -> bool:
158152
def filter_resources(
159153
state: Dict[str, Any], addresses: List[str] = [], files: List[str] = []
160154
) -> List[Dict[str, Any]]:
161-
"""Extract matching AWS resources from Terraform state."""
155+
"""Extract matching resources from Terraform state."""
162156
# Extract addresses from files if provided
163157
file_addresses = []
164158
if files:
165159
for file_path in files:
166-
extracted = extract_resource_addresses_from_file(file_path)
160+
extracted = extract_addresses_from_file(file_path)
167161
file_addresses.extend(extracted)
168162

169163
if not file_addresses:
@@ -185,10 +179,8 @@ def filter_resources(
185179

186180
# Process resources in current module
187181
for resource in module.get("resources", []):
188-
if (
189-
resource.get("type", "").startswith("aws_")
190-
and resource.get("mode") == "managed"
191-
and is_resource_match(resource["address"], addresses, file_addresses)
182+
if resource.get("mode") == "managed" and is_resource_match(
183+
resource["address"], addresses, file_addresses
192184
):
193185
resources.append(resource)
194186

@@ -202,20 +194,48 @@ def filter_resources(
202194

203195

204196
def generate_import_block(
205-
resource: Dict[str, Any], schema_classes: Dict[str, type]
206-
) -> str:
197+
resource: Dict[str, Any],
198+
schema_classes: Dict[str, type],
199+
supported_providers_only: bool = False,
200+
) -> str | None:
207201
"""Generate Terraform import block for a resource."""
208-
matching_class = schema_classes.get(resource["type"])
209-
documentation = f"https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/{resource['type'].replace('aws_', '')}#import"
210-
import_id = f'"" # TODO: {documentation}'
211-
212-
if matching_class:
213-
try:
214-
instance = matching_class(resource["address"], resource["values"])
215-
if instance.import_id is not None:
216-
import_id = f'"{instance.import_id}"'
217-
except Exception:
218-
pass
202+
provider_name = resource["type"].split("_")[0] if "_" in resource["type"] else ""
203+
204+
if provider_name == "aws":
205+
# For AWS resources, we have import ID generators
206+
matching_class = schema_classes.get(resource["type"])
207+
documentation = f"https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/{resource['type'].replace('aws_', '')}#import"
208+
import_id = f'"" # TODO: {documentation}'
209+
210+
if matching_class:
211+
try:
212+
instance = matching_class(resource["address"], resource["values"])
213+
if instance.import_id is not None:
214+
import_id = f'"{instance.import_id}"'
215+
except Exception:
216+
pass
217+
elif not supported_providers_only:
218+
# For providers we don't have import ID generators for
219+
# we create a block with a link to the resource documentation
220+
resource_type = resource["type"]
221+
222+
# Use provider_name field when available to get the right documentation URL
223+
if resource.get("provider_name", "").startswith("registry.terraform.io/"):
224+
parts = resource["provider_name"].split("/")
225+
if len(parts) >= 3:
226+
org = parts[1]
227+
provider = parts[2]
228+
# Best-effort attempt at linking to the right place
229+
docs_url = f"https://registry.terraform.io/providers/{org}/{provider}/latest/docs/resources/{resource_type.removeprefix(f'{provider_name}_')}#import"
230+
import_id = f'"" # TODO: {docs_url}'
231+
else:
232+
import_id = '"" # TODO'
233+
else:
234+
# Fallback to generic message if provider is not on the Terraform registry
235+
import_id = '"" # TODO'
236+
else:
237+
# For unsupported providers when restricting to supported providers only
238+
return None # Skip this resource
219239

220240
return f"""import {{
221241
to = {resource["address"]}
@@ -239,7 +259,10 @@ def generate_removed_block(resource_addr: str, destroy: bool = False) -> str:
239259

240260

241261
def generate_blocks_for_command(
242-
resources: List[Dict[str, Any]], command: str, destroy: bool = False
262+
resources: List[Dict[str, Any]],
263+
command: str,
264+
destroy: bool = False,
265+
supported_providers_only: bool = False,
243266
) -> List[str]:
244267
"""Generate Terraform code blocks based on command."""
245268
blocks = []
@@ -261,7 +284,14 @@ def generate_blocks_for_command(
261284
elif command == "import":
262285
# For import blocks, we need the full resource data
263286
schema_classes = get_aws_resource_import_id_generators()
264-
blocks = [generate_import_block(r, schema_classes) for r in resources]
287+
blocks = [
288+
block
289+
for block in [
290+
generate_import_block(r, schema_classes, supported_providers_only)
291+
for r in resources
292+
]
293+
if block is not None
294+
]
265295
else:
266296
raise ValueError(f"Invalid command '{command}'")
267297
return blocks
@@ -295,6 +325,11 @@ def add_filter_args(cmd_parser: argparse.ArgumentParser):
295325

296326
import_parser = subparsers.add_parser("import", help="Generate import blocks")
297327
add_filter_args(import_parser)
328+
import_parser.add_argument(
329+
"--supported-providers-only",
330+
action="store_true",
331+
help="Only generate import IDs for supported providers (currently only AWS)",
332+
)
298333

299334
remove_parser = subparsers.add_parser("remove", help="Generate removed blocks")
300335
add_filter_args(remove_parser)
@@ -330,7 +365,10 @@ def main():
330365
return
331366

332367
blocks = generate_blocks_for_command(
333-
resources, args.command, getattr(args, "destroy", False)
368+
resources,
369+
args.command,
370+
getattr(args, "destroy", False),
371+
getattr(args, "supported_providers_only", False),
334372
)
335373

336374
if args.no_color:

tests/test_main.py

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import tempfile
21
import unittest
3-
from unittest.mock import mock_open, patch
2+
from unittest.mock import patch
43

54
from tfblocks import main
65

@@ -42,7 +41,7 @@ def test_module_match(self):
4241
"module.my_module[0].aws_s3_bucket.test", ["module.my_module"], []
4342
)
4443
)
45-
44+
4645
def test_resource_type_name_match(self):
4746
"""Test matching resources by type and name across different module paths"""
4847
# Resource in a module should match the same resource type and name from a file
@@ -51,14 +50,14 @@ def test_resource_type_name_match(self):
5150
"module.my_module.aws_s3_bucket.test", [], ["aws_s3_bucket.test"]
5251
)
5352
)
54-
53+
5554
# Different resource name should not match
5655
self.assertFalse(
5756
main.is_resource_match(
5857
"module.my_module.aws_s3_bucket.test", [], ["aws_s3_bucket.other"]
5958
)
6059
)
61-
60+
6261
# Different resource type should not match
6362
self.assertFalse(
6463
main.is_resource_match(
@@ -78,18 +77,24 @@ def test_wildcard_matching(self):
7877
self.assertFalse(
7978
main.is_resource_match("aws_lambda_function.test", ["aws_s3_bucket.*"], [])
8079
)
81-
80+
8281
# Match resources using wildcards in module paths
8382
self.assertTrue(
84-
main.is_resource_match("module.my_module.aws_s3_bucket.test", ["*.aws_s3_bucket.test"], [])
83+
main.is_resource_match(
84+
"module.my_module.aws_s3_bucket.test", ["*.aws_s3_bucket.test"], []
85+
)
8586
)
86-
87+
8788
# Match more complex patterns
8889
self.assertTrue(
89-
main.is_resource_match("module.my_module.aws_s3_bucket.test", ["module.*.aws_s3_bucket.*"], [])
90+
main.is_resource_match(
91+
"module.my_module.aws_s3_bucket.test", ["module.*.aws_s3_bucket.*"], []
92+
)
9093
)
9194
self.assertFalse(
92-
main.is_resource_match("other.my_module.aws_s3_bucket.test", ["module.*.aws_s3_bucket.*"], [])
95+
main.is_resource_match(
96+
"other.my_module.aws_s3_bucket.test", ["module.*.aws_s3_bucket.*"], []
97+
)
9398
)
9499

95100
def test_intersection_filter(self):
@@ -107,7 +112,7 @@ def test_intersection_filter(self):
107112

108113

109114
class TestFileProcessing(unittest.TestCase):
110-
def test_extract_resource_addresses_from_content(self):
115+
def test_extract_addresses_from_content(self):
111116
"""Test extracting resource addresses from content"""
112117
terraform_content = """
113118
resource "aws_s3_bucket" "bucket" {
@@ -123,28 +128,12 @@ def test_extract_resource_addresses_from_content(self):
123128
}
124129
"""
125130

126-
addresses = main.extract_resource_addresses_from_content(terraform_content)
131+
addresses = main.extract_addresses_from_content(terraform_content)
127132

128133
self.assertEqual(len(addresses), 3)
129134
self.assertIn("aws_s3_bucket.bucket", addresses)
130135
self.assertIn("aws_dynamodb_table.table", addresses)
131136
self.assertIn("module.vpc", addresses)
132-
133-
def test_file_exists(self):
134-
"""Test file existence check"""
135-
# Test with tempfile to avoid dependencies on filesystem
136-
with tempfile.NamedTemporaryFile() as temp_file:
137-
self.assertTrue(main.file_exists(temp_file.name))
138-
139-
# This path should not exist
140-
self.assertFalse(main.file_exists("/path/that/does/not/exist/file.tf"))
141-
142-
def test_extract_resource_addresses_nonexistent_file(self):
143-
"""Test behavior with nonexistent file"""
144-
with patch("tfblocks.main.file_exists", return_value=False), \
145-
patch("sys.exit") as mock_exit:
146-
main.extract_resource_addresses_from_file("nonexistent.tf")
147-
mock_exit.assert_called_once_with(1)
148137

149138
def test_filter_resources_basic(self):
150139
"""Test basic resource filtering without files"""
@@ -201,7 +190,7 @@ def test_filter_resources_basic(self):
201190
resources = main.filter_resources(test_state, ["module.test"])
202191
self.assertEqual(len(resources), 1)
203192
self.assertEqual(resources[0]["address"], "module.test.aws_s3_bucket.nested")
204-
193+
205194
def test_filter_resources_with_file_filters(self):
206195
"""Test resource filtering with file filters"""
207196
# Create a test state
@@ -221,13 +210,13 @@ def test_filter_resources_with_file_filters(self):
221210
}
222211

223212
# Mock the file handling part without mocking implementation details
224-
with patch("tfblocks.main.extract_resource_addresses_from_file") as mock_extract:
213+
with patch("tfblocks.main.extract_addresses_from_file") as mock_extract:
225214
# Case 1: File contains matching resource
226215
mock_extract.return_value = ["aws_s3_bucket.test"]
227216
resources = main.filter_resources(test_state, [], ["matching_file.tf"])
228217
self.assertEqual(len(resources), 1)
229218
self.assertEqual(resources[0]["address"], "aws_s3_bucket.test")
230-
219+
231220
# Case 2: File contains non-matching resource
232221
mock_extract.return_value = ["aws_lambda_function.not_in_state"]
233222
resources = main.filter_resources(test_state, [], ["non_matching_file.tf"])
@@ -250,6 +239,29 @@ def test_generate_import_block(self):
250239
self.assertIn("to = aws_s3_bucket.test", block)
251240
self.assertIn("id =", block)
252241

242+
def test_generate_import_block_for_unsupported_provider(self):
243+
"""Test generating an import block for an unsupported provider resource"""
244+
resource = {
245+
"address": "google_storage_bucket.test",
246+
"type": "google_storage_bucket",
247+
"values": {"name": "test-bucket"},
248+
}
249+
250+
schema_classes = {}
251+
252+
# By default, we should generate blocks for all providers
253+
block = main.generate_import_block(resource, schema_classes)
254+
self.assertIn("to = google_storage_bucket.test", block)
255+
self.assertIn("TODO", block)
256+
self.assertIn("google", block)
257+
self.assertIn("storage_bucket", block)
258+
259+
# When supported_providers_only=True, we should not generate a block
260+
block = main.generate_import_block(
261+
resource, schema_classes, supported_providers_only=True
262+
)
263+
self.assertIsNone(block)
264+
253265
def test_generate_removed_block(self):
254266
"""Test generating a removed block for a resource"""
255267
# Test with destroy=False (default)
@@ -303,6 +315,36 @@ def test_generate_blocks_for_command(self):
303315
# Should have 2 blocks, not 3, due to deduplication
304316
self.assertEqual(len(removed_blocks), 2)
305317

318+
def test_generate_blocks_with_mixed_providers(self):
319+
"""Test that supported_providers_only flag filters non-AWS resources"""
320+
resources = [
321+
{
322+
"address": "aws_s3_bucket.test",
323+
"type": "aws_s3_bucket",
324+
"values": {"bucket": "test-bucket"},
325+
},
326+
{
327+
"address": "google_storage_bucket.test",
328+
"type": "google_storage_bucket",
329+
"values": {"name": "test-bucket"},
330+
},
331+
{
332+
"address": "azurerm_storage_account.test",
333+
"type": "azurerm_storage_account",
334+
"values": {"name": "teststorage"},
335+
},
336+
]
337+
338+
# Test with all providers (default behavior)
339+
import_blocks = main.generate_blocks_for_command(resources, "import")
340+
self.assertEqual(len(import_blocks), 3) # All resources should be included
341+
342+
# Test with supported_providers_only=True
343+
import_blocks = main.generate_blocks_for_command(
344+
resources, "import", supported_providers_only=True
345+
)
346+
self.assertEqual(len(import_blocks), 1) # Only AWS resource should be included
347+
306348

307349
if __name__ == "__main__":
308350
unittest.main()

0 commit comments

Comments (0)