Skip to content

Commit c867b6f

Browse files
committed
fix: apply ruff formatting and remove unused imports in new test files
1 parent: 535e9a8 · commit: c867b6f

5 files changed

Lines changed: 37 additions & 101 deletions

File tree

tests/test_case_expressions.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -18,11 +18,8 @@
1818
Total: 16 test cases
1919
"""
2020

21-
import pytest
22-
2321
from clgraph import RecursiveLineageBuilder, SQLColumnTracer
2422

25-
2623
# ============================================================================
2724
# Test Group 1: Simple CASE WHEN
2825
# ============================================================================
@@ -154,10 +151,10 @@ def test_case_references_multiple_tables(self):
154151
ref_edges = [e for e in graph.edges if e.to_node.full_name == "output.reference_code"]
155152
source_tables = {e.from_node.table_name for e in ref_edges}
156153

157-
assert "orders" in source_tables or "o" in source_tables, \
158-
"Orders table should be a source"
159-
assert "shipments" in source_tables or "s" in source_tables, \
154+
assert "orders" in source_tables or "o" in source_tables, "Orders table should be a source"
155+
assert "shipments" in source_tables or "s" in source_tables, (
160156
"Shipments table should be a source"
157+
)
161158

162159
def test_case_multi_table_backward_lineage(self):
163160
"""Backward lineage should identify required inputs from all joined tables."""
@@ -182,8 +179,9 @@ def test_case_multi_table_backward_lineage(self):
182179
all_cols.update(cols)
183180

184181
assert "role" in all_cols, "Condition column from users should be required"
185-
assert "full_access_level" in all_cols or "basic_access_level" in all_cols, \
182+
assert "full_access_level" in all_cols or "basic_access_level" in all_cols, (
186183
"Result columns from permissions should be required"
184+
)
187185

188186

189187
# ============================================================================
@@ -294,8 +292,9 @@ def test_count_case_when(self):
294292
count_edges = [e for e in graph.edges if e.to_node.full_name == "output.active_count"]
295293
source_columns = {e.from_node.column_name for e in count_edges}
296294

297-
assert "is_active" in source_columns or "employee_id" in source_columns, \
295+
assert "is_active" in source_columns or "employee_id" in source_columns, (
298296
"At least the condition or result column should be traced"
297+
)
299298

300299

301300
# ============================================================================

tests/test_dialect_coverage.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111

1212
import pytest
1313

14-
from clgraph import Pipeline, RecursiveLineageBuilder, SQLColumnTracer
14+
from clgraph import Pipeline, RecursiveLineageBuilder
1515

1616
# ============================================================================
1717
# Helper utilities
@@ -25,11 +25,7 @@ def _output_column_names(graph):
2525

2626
def _source_columns_for(graph, output_col):
2727
"""Return source column names that feed into *output_col*."""
28-
return {
29-
e.from_node.column_name
30-
for e in graph.edges
31-
if e.to_node.column_name == output_col
32-
}
28+
return {e.from_node.column_name for e in graph.edges if e.to_node.column_name == output_col}
3329

3430

3531
# ============================================================================
@@ -474,8 +470,7 @@ def test_cross_dialect_output_columns_match(self):
474470
baseline_cols = business_cols
475471
else:
476472
assert business_cols == baseline_cols, (
477-
f"Dialect {dialect} produced {business_cols}, "
478-
f"expected {baseline_cols}"
473+
f"Dialect {dialect} produced {business_cols}, expected {baseline_cols}"
479474
)
480475

481476

tests/test_join_types.py

Lines changed: 11 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,6 @@
3232

3333
from clgraph import RecursiveLineageBuilder, SQLColumnTracer
3434

35-
3635
# ============================================================================
3736
# Test Group 1: INNER JOIN
3837
# ============================================================================
@@ -56,10 +55,7 @@ def test_inner_join_columns_from_both_tables(self):
5655
assert "output.order_id" in graph.nodes
5756
assert "output.amount" in graph.nodes
5857

59-
edges_dict = {
60-
(e.from_node.full_name, e.to_node.full_name): e
61-
for e in graph.edges
62-
}
58+
edges_dict = {(e.from_node.full_name, e.to_node.full_name): e for e in graph.edges}
6359
assert ("users.id", "output.id") in edges_dict
6460
assert ("users.name", "output.name") in edges_dict
6561
assert ("orders.order_id", "output.order_id") in edges_dict
@@ -103,10 +99,7 @@ def test_left_join_nullable_side_columns_tracked(self):
10399
assert "output.name" in graph.nodes
104100
assert "output.bio" in graph.nodes
105101

106-
edges_dict = {
107-
(e.from_node.full_name, e.to_node.full_name): e
108-
for e in graph.edges
109-
}
102+
edges_dict = {(e.from_node.full_name, e.to_node.full_name): e for e in graph.edges}
110103
assert ("users.id", "output.id") in edges_dict
111104
assert ("profiles.bio", "output.bio") in edges_dict
112105

@@ -145,10 +138,7 @@ def test_right_join_columns_tracked(self):
145138
assert "output.order_id" in graph.nodes
146139
assert "output.amount" in graph.nodes
147140

148-
edges_dict = {
149-
(e.from_node.full_name, e.to_node.full_name): e
150-
for e in graph.edges
151-
}
141+
edges_dict = {(e.from_node.full_name, e.to_node.full_name): e for e in graph.edges}
152142
assert ("users.name", "output.name") in edges_dict
153143
assert ("orders.order_id", "output.order_id") in edges_dict
154144

@@ -188,10 +178,7 @@ def test_full_outer_join_both_sides_tracked(self):
188178
assert "output.order_id" in graph.nodes
189179
assert "output.amount" in graph.nodes
190180

191-
edges_dict = {
192-
(e.from_node.full_name, e.to_node.full_name): e
193-
for e in graph.edges
194-
}
181+
edges_dict = {(e.from_node.full_name, e.to_node.full_name): e for e in graph.edges}
195182
assert ("users.id", "output.id") in edges_dict
196183
assert ("orders.amount", "output.amount") in edges_dict
197184

@@ -232,10 +219,7 @@ def test_cross_join_no_on_clause(self):
232219
assert "output.name" in graph.nodes
233220
assert "output.color_name" in graph.nodes
234221

235-
edges_dict = {
236-
(e.from_node.full_name, e.to_node.full_name): e
237-
for e in graph.edges
238-
}
222+
edges_dict = {(e.from_node.full_name, e.to_node.full_name): e for e in graph.edges}
239223
assert ("users.name", "output.name") in edges_dict
240224
assert ("colors.color_name", "output.color_name") in edges_dict
241225

@@ -262,12 +246,8 @@ def test_self_join_with_aliases(self):
262246
assert "output.manager" in graph.nodes
263247

264248
# Both output columns trace back to employees.name
265-
employee_edges = [
266-
e for e in graph.edges if e.to_node.full_name == "output.employee"
267-
]
268-
manager_edges = [
269-
e for e in graph.edges if e.to_node.full_name == "output.manager"
270-
]
249+
employee_edges = [e for e in graph.edges if e.to_node.full_name == "output.employee"]
250+
manager_edges = [e for e in graph.edges if e.to_node.full_name == "output.manager"]
271251

272252
assert len(employee_edges) > 0
273253
assert len(manager_edges) > 0
@@ -311,10 +291,7 @@ def test_three_table_join_chain(self):
311291
assert "output.order_id" in graph.nodes
312292
assert "output.product_name" in graph.nodes
313293

314-
edges_dict = {
315-
(e.from_node.full_name, e.to_node.full_name): e
316-
for e in graph.edges
317-
}
294+
edges_dict = {(e.from_node.full_name, e.to_node.full_name): e for e in graph.edges}
318295
assert ("users.name", "output.name") in edges_dict
319296
assert ("orders.order_id", "output.order_id") in edges_dict
320297
assert ("products.product_name", "output.product_name") in edges_dict
@@ -362,9 +339,7 @@ def test_join_against_subquery(self):
362339
assert "output.total_amount" in graph.nodes
363340

364341
# total_amount should trace back through the subquery to orders.amount
365-
total_edges = [
366-
e for e in graph.edges if e.to_node.full_name == "output.total_amount"
367-
]
342+
total_edges = [e for e in graph.edges if e.to_node.full_name == "output.total_amount"]
368343
assert len(total_edges) > 0
369344

370345
def test_join_subquery_backward_lineage(self):
@@ -406,12 +381,7 @@ def test_on_clause_columns_in_graph(self):
406381
assert "output.name" in graph.nodes
407382
assert "output.amount" in graph.nodes
408383

409-
# The join-key columns should appear as source nodes in the graph
410-
all_source_names = {
411-
e.from_node.full_name for e in graph.edges
412-
}
413-
# users.id and orders.user_id are join keys; at minimum the selected
414-
# columns (users.name, orders.amount) must have edges
384+
# At minimum the selected columns (users.name, orders.amount) must have edges
415385
assert ("users.name", "output.name") in {
416386
(e.from_node.full_name, e.to_node.full_name) for e in graph.edges
417387
}
@@ -447,10 +417,7 @@ def test_inner_join_lineage_per_dialect(self, dialect):
447417
assert "output.order_id" in graph.nodes
448418
assert "output.amount" in graph.nodes
449419

450-
edges_dict = {
451-
(e.from_node.full_name, e.to_node.full_name): e
452-
for e in graph.edges
453-
}
420+
edges_dict = {(e.from_node.full_name, e.to_node.full_name): e for e in graph.edges}
454421
assert ("users.id", "output.id") in edges_dict
455422
assert ("orders.amount", "output.amount") in edges_dict
456423

tests/test_regex_functions.py

Lines changed: 13 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,7 @@
1414

1515
import pytest
1616

17-
from clgraph import RecursiveLineageBuilder, SQLColumnTracer
18-
17+
from clgraph import RecursiveLineageBuilder
1918

2019
# ============================================================================
2120
# Test Group 1: REGEXP_CONTAINS (BigQuery)
@@ -36,9 +35,7 @@ def test_regexp_contains_in_select(self):
3635
builder = RecursiveLineageBuilder(sql, dialect="bigquery")
3736
graph = builder.build()
3837

39-
is_example_edges = [
40-
e for e in graph.edges if e.to_node.column_name == "is_example_domain"
41-
]
38+
is_example_edges = [e for e in graph.edges if e.to_node.column_name == "is_example_domain"]
4239
assert len(is_example_edges) > 0
4340

4441
source_columns = {e.from_node.column_name for e in is_example_edges}
@@ -81,9 +78,7 @@ def test_snowflake_regexp_substr(self):
8178
builder = RecursiveLineageBuilder(sql, dialect="snowflake")
8279
graph = builder.build()
8380

84-
area_code_edges = [
85-
e for e in graph.edges if e.to_node.column_name == "area_code"
86-
]
81+
area_code_edges = [e for e in graph.edges if e.to_node.column_name == "area_code"]
8782
assert len(area_code_edges) > 0
8883

8984
source_columns = {e.from_node.column_name for e in area_code_edges}
@@ -108,9 +103,7 @@ def test_bigquery_regexp_replace(self):
108103
builder = RecursiveLineageBuilder(sql, dialect="bigquery")
109104
graph = builder.build()
110105

111-
clean_phone_edges = [
112-
e for e in graph.edges if e.to_node.column_name == "clean_phone"
113-
]
106+
clean_phone_edges = [e for e in graph.edges if e.to_node.column_name == "clean_phone"]
114107
assert len(clean_phone_edges) > 0
115108

116109
source_columns = {e.from_node.column_name for e in clean_phone_edges}
@@ -126,9 +119,7 @@ def test_snowflake_regexp_replace(self):
126119
builder = RecursiveLineageBuilder(sql, dialect="snowflake")
127120
graph = builder.build()
128121

129-
sanitized_edges = [
130-
e for e in graph.edges if e.to_node.column_name == "sanitized_address"
131-
]
122+
sanitized_edges = [e for e in graph.edges if e.to_node.column_name == "sanitized_address"]
132123
assert len(sanitized_edges) > 0
133124

134125
source_columns = {e.from_node.column_name for e in sanitized_edges}
@@ -178,9 +169,7 @@ def test_tilde_operator_in_where(self):
178169
builder = RecursiveLineageBuilder(sql, dialect="postgres")
179170
graph = builder.build()
180171

181-
desc_edges = [
182-
e for e in graph.edges if e.to_node.column_name == "description"
183-
]
172+
desc_edges = [e for e in graph.edges if e.to_node.column_name == "description"]
184173
assert len(desc_edges) > 0
185174

186175
source_columns = {e.from_node.column_name for e in desc_edges}
@@ -209,9 +198,7 @@ def test_regexp_extract_in_case_when(self):
209198
builder = RecursiveLineageBuilder(sql, dialect="bigquery")
210199
graph = builder.build()
211200

212-
protocol_edges = [
213-
e for e in graph.edges if e.to_node.column_name == "protocol_type"
214-
]
201+
protocol_edges = [e for e in graph.edges if e.to_node.column_name == "protocol_type"]
215202
assert len(protocol_edges) > 0
216203

217204
source_columns = {e.from_node.column_name for e in protocol_edges}
@@ -231,9 +218,7 @@ def test_regexp_replace_in_case_value(self):
231218
builder = RecursiveLineageBuilder(sql, dialect="bigquery")
232219
graph = builder.build()
233220

234-
phone_edges = [
235-
e for e in graph.edges if e.to_node.column_name == "normalized_phone"
236-
]
221+
phone_edges = [e for e in graph.edges if e.to_node.column_name == "normalized_phone"]
237222
assert len(phone_edges) > 0
238223

239224
source_columns = {e.from_node.column_name for e in phone_edges}
@@ -262,15 +247,11 @@ def test_regexp_where_preserves_select_lineage(self):
262247
graph = builder.build()
263248

264249
# user_id should still trace to source
265-
user_id_edges = [
266-
e for e in graph.edges if e.to_node.column_name == "user_id"
267-
]
250+
user_id_edges = [e for e in graph.edges if e.to_node.column_name == "user_id"]
268251
assert len(user_id_edges) > 0
269252

270253
# email_domain should trace back to email
271-
domain_edges = [
272-
e for e in graph.edges if e.to_node.column_name == "email_domain"
273-
]
254+
domain_edges = [e for e in graph.edges if e.to_node.column_name == "email_domain"]
274255
assert len(domain_edges) > 0
275256

276257
source_columns = {e.from_node.column_name for e in domain_edges}
@@ -295,9 +276,7 @@ def test_regexp_replace_bigquery(self):
295276
builder = RecursiveLineageBuilder(sql, dialect="bigquery")
296277
graph = builder.build()
297278

298-
clean_edges = [
299-
e for e in graph.edges if e.to_node.column_name == "clean_name"
300-
]
279+
clean_edges = [e for e in graph.edges if e.to_node.column_name == "clean_name"]
301280
assert len(clean_edges) > 0
302281

303282
source_columns = {e.from_node.column_name for e in clean_edges}
@@ -313,9 +292,7 @@ def test_regexp_replace_postgres(self):
313292
builder = RecursiveLineageBuilder(sql, dialect="postgres")
314293
graph = builder.build()
315294

316-
clean_edges = [
317-
e for e in graph.edges if e.to_node.column_name == "clean_name"
318-
]
295+
clean_edges = [e for e in graph.edges if e.to_node.column_name == "clean_name"]
319296
assert len(clean_edges) > 0
320297

321298
source_columns = {e.from_node.column_name for e in clean_edges}
@@ -331,9 +308,7 @@ def test_regexp_replace_snowflake(self):
331308
builder = RecursiveLineageBuilder(sql, dialect="snowflake")
332309
graph = builder.build()
333310

334-
clean_edges = [
335-
e for e in graph.edges if e.to_node.column_name == "clean_name"
336-
]
311+
clean_edges = [e for e in graph.edges if e.to_node.column_name == "clean_name"]
337312
assert len(clean_edges) > 0
338313

339314
source_columns = {e.from_node.column_name for e in clean_edges}

tests/test_type_casting.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -267,9 +267,9 @@ def test_cast_lineage_across_dialects(self, dialect):
267267

268268
edges = [e for e in graph.edges if e.to_node.column_name == "amount_int"]
269269
assert len(edges) > 0, f"No edges for amount_int in dialect={dialect}"
270-
assert any(
271-
e.from_node.column_name == "amount" for e in edges
272-
), f"Source column 'amount' not found in dialect={dialect}"
270+
assert any(e.from_node.column_name == "amount" for e in edges), (
271+
f"Source column 'amount' not found in dialect={dialect}"
272+
)
273273

274274
@pytest.mark.parametrize("dialect", ["bigquery", "postgres", "snowflake"])
275275
def test_cast_backward_lineage_across_dialects(self, dialect):

0 commit comments

Comments
 (0)