Skip to content

Commit 12efe47

Browse files
authored
ingest(snowflake): Emulate tag inheritance in-memory to eliminate N+1 queries (datahub-project#16400)
1 parent fe5a9c8 commit 12efe47

6 files changed

Lines changed: 736 additions & 203 deletions

File tree

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -389,19 +389,6 @@ def get_all_tags():
389389
ORDER BY TAG_DATABASE, TAG_SCHEMA, TAG_NAME ASC;
390390
"""
391391

392-
@staticmethod
393-
def get_all_tags_on_object_with_propagation(
394-
db_name: str, quoted_identifier: str, domain: str
395-
) -> str:
396-
# https://docs.snowflake.com/en/sql-reference/functions/tag_references.html
397-
return f"""
398-
SELECT tag_database as "TAG_DATABASE",
399-
tag_schema AS "TAG_SCHEMA",
400-
tag_name AS "TAG_NAME",
401-
tag_value AS "TAG_VALUE"
402-
FROM table("{db_name}".information_schema.tag_references('{quoted_identifier}', '{domain}'));
403-
"""
404-
405392
@staticmethod
406393
def get_all_tags_in_database_without_propagation(db_name: str) -> str:
407394
allowed_object_domains = (
@@ -430,20 +417,6 @@ def get_all_tags_in_database_without_propagation(db_name: str) -> str:
430417
AND object_deleted IS NULL;
431418
"""
432419

433-
@staticmethod
434-
def get_tags_on_columns_with_propagation(
435-
db_name: str, quoted_table_identifier: str
436-
) -> str:
437-
# https://docs.snowflake.com/en/sql-reference/functions/tag_references_all_columns.html
438-
return f"""
439-
SELECT tag_database as "TAG_DATABASE",
440-
tag_schema AS "TAG_SCHEMA",
441-
tag_name AS "TAG_NAME",
442-
tag_value AS "TAG_VALUE",
443-
column_name AS "COLUMN_NAME"
444-
FROM table("{db_name}".information_schema.tag_references_all_columns('{quoted_table_identifier}', '{SnowflakeObjectDomain.TABLE}'));
445-
"""
446-
447420
@staticmethod
448421
def show_views_for_database(
449422
db_name: str,

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,6 @@ class SnowflakeV2Report(
134134
num_get_tables_for_schema_queries: int = 0
135135
num_get_views_for_schema_queries: int = 0
136136

137-
# these will be non-zero if the user chooses to enable the extract_tags = "with_lineage" option, which requires
138-
# individual queries per object (database, schema, table) and an extra query per table to get the tags on the columns.
139-
num_get_tags_for_object_queries: int = 0
140-
num_get_tags_on_columns_for_table_queries: int = 0
141-
142137
num_get_streams_for_schema_queries: int = 0
143138

144139
rows_zero_objects_modified: int = 0

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py

Lines changed: 160 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import logging
33
from collections import defaultdict
4-
from dataclasses import dataclass, field
4+
from dataclasses import dataclass, field, replace
55
from datetime import datetime
66
from typing import (
77
Any,
@@ -85,6 +85,15 @@ class SnowflakeTag:
8585
schema: str
8686
name: str
8787
value: str
88+
inherited_from: Optional[SnowflakeObjectDomain] = None
89+
90+
@property
91+
def is_inherited(self) -> bool:
92+
return self.inherited_from is not None
93+
94+
def as_inherited(self, level: SnowflakeObjectDomain) -> "SnowflakeTag":
95+
"""Return a copy marked as inherited from a parent object *level*."""
96+
return replace(self, inherited_from=level)
8897

8998
def tag_display_name(self) -> str:
9099
return f"{self.name}: {self.value}"
@@ -401,7 +410,7 @@ def add_database_tag(self, db_name: str, tag: SnowflakeTag) -> None:
401410
self._database_tags[db_name].append(tag)
402411

403412
def get_database_tags(self, db_name: str) -> List[SnowflakeTag]:
404-
return self._database_tags[db_name]
413+
return self._database_tags.get(db_name, [])
405414

406415
def add_schema_tag(self, schema_name: str, db_name: str, tag: SnowflakeTag) -> None:
407416
self._schema_tags[db_name][schema_name].append(tag)
@@ -417,7 +426,9 @@ def add_table_tag(
417426
def get_table_tags(
418427
self, table_name: str, schema_name: str, db_name: str
419428
) -> List[SnowflakeTag]:
420-
return self._table_tags[db_name][schema_name][table_name]
429+
return (
430+
self._table_tags.get(db_name, {}).get(schema_name, {}).get(table_name, [])
431+
)
421432

422433
def add_column_tag(
423434
self,
@@ -436,6 +447,99 @@ def get_column_tags_for_table(
436447
self._column_tags.get(db_name, {}).get(schema_name, {}).get(table_name, {})
437448
)
438449

450+
# --- Inheritance-aware methods ---
451+
# These emulate Snowflake's tag inheritance: a tag set on a database is
452+
# inherited by its schemas, tables, and columns. A more-specific
453+
# assignment (e.g. directly on the table) overrides the inherited value.
454+
455+
@staticmethod
456+
def _deduplicate_tags(tags: List[SnowflakeTag]) -> List[SnowflakeTag]:
457+
"""Deduplicate tags by (database, schema, name), preferring direct over inherited."""
458+
best: Dict[tuple, SnowflakeTag] = {}
459+
for tag in tags:
460+
key = (tag.database, tag.schema, tag.name)
461+
existing = best.get(key)
462+
if existing is None or (existing.is_inherited and not tag.is_inherited):
463+
best[key] = tag
464+
return list(best.values())
465+
466+
@staticmethod
467+
def _mark_inherited(
468+
tags: List[SnowflakeTag], level: SnowflakeObjectDomain
469+
) -> List[SnowflakeTag]:
470+
return [t.as_inherited(level) for t in tags]
471+
472+
def get_schema_tags_with_inheritance(
473+
self, schema_name: str, db_name: str
474+
) -> List[SnowflakeTag]:
475+
direct = self.get_schema_tags(schema_name, db_name)
476+
inherited = self._mark_inherited(
477+
self.get_database_tags(db_name), SnowflakeObjectDomain.DATABASE
478+
)
479+
return self._deduplicate_tags(direct + inherited)
480+
481+
def get_table_tags_with_inheritance(
482+
self, table_name: str, schema_name: str, db_name: str
483+
) -> List[SnowflakeTag]:
484+
direct = self.get_table_tags(table_name, schema_name, db_name)
485+
schema_inherited = self._mark_inherited(
486+
self.get_schema_tags(schema_name, db_name),
487+
SnowflakeObjectDomain.SCHEMA,
488+
)
489+
db_inherited = self._mark_inherited(
490+
self.get_database_tags(db_name), SnowflakeObjectDomain.DATABASE
491+
)
492+
return self._deduplicate_tags(direct + schema_inherited + db_inherited)
493+
494+
def get_column_tags_for_table_with_inheritance(
495+
self,
496+
table_name: str,
497+
schema_name: str,
498+
db_name: str,
499+
column_names: Optional[List[str]] = None,
500+
) -> Dict[str, List[SnowflakeTag]]:
501+
"""Return column tags with inheritance from table, schema, and database levels.
502+
503+
Args:
504+
column_names: All column names in the table. When provided,
505+
inherited parent tags are applied to every column, not just
506+
those with direct column tags.
507+
"""
508+
direct_column_tags = self.get_column_tags_for_table(
509+
table_name, schema_name, db_name
510+
)
511+
512+
# Tags inherited by every column in this table
513+
parent_tags = (
514+
self._mark_inherited(
515+
self.get_table_tags(table_name, schema_name, db_name),
516+
SnowflakeObjectDomain.TABLE,
517+
)
518+
+ self._mark_inherited(
519+
self.get_schema_tags(schema_name, db_name),
520+
SnowflakeObjectDomain.SCHEMA,
521+
)
522+
+ self._mark_inherited(
523+
self.get_database_tags(db_name),
524+
SnowflakeObjectDomain.DATABASE,
525+
)
526+
)
527+
528+
if not parent_tags:
529+
return dict(direct_column_tags)
530+
531+
# Apply parent tags to all known columns, merging with direct tags
532+
all_columns = (
533+
column_names
534+
if column_names is not None
535+
else list(direct_column_tags.keys())
536+
)
537+
result: Dict[str, List[SnowflakeTag]] = {}
538+
for col_name in all_columns:
539+
col_tags = list(direct_column_tags.get(col_name, []))
540+
result[col_name] = self._deduplicate_tags(col_tags + parent_tags)
541+
return result
542+
439543

440544
class SnowflakeDataDictionary(SupportsAsObj):
441545
def __init__(
@@ -1493,11 +1597,23 @@ def get_tags_for_database_without_propagation(
14931597
tags = _SnowflakeTagCache()
14941598

14951599
for tag in cur:
1600+
tag_db = tag["TAG_DATABASE"]
1601+
tag_schema = tag["TAG_SCHEMA"]
1602+
tag_name = tag["TAG_NAME"]
1603+
tag_value = tag["TAG_VALUE"]
1604+
if tag_db is None or tag_schema is None or tag_name is None:
1605+
logger.warning(
1606+
f"Skipping tag with null definition fields: "
1607+
f"TAG_DATABASE={tag_db}, TAG_SCHEMA={tag_schema}, "
1608+
f"TAG_NAME={tag_name}"
1609+
)
1610+
continue
1611+
14961612
snowflake_tag = SnowflakeTag(
1497-
database=tag["TAG_DATABASE"],
1498-
schema=tag["TAG_SCHEMA"],
1499-
name=tag["TAG_NAME"],
1500-
value=tag["TAG_VALUE"],
1613+
database=tag_db,
1614+
schema=tag_schema,
1615+
name=tag_name,
1616+
value=tag_value or "",
15011617
)
15021618

15031619
# This is the name of the object, unless the object is a column, in which
@@ -1508,17 +1624,47 @@ def get_tags_for_database_without_propagation(
15081624
# This will be null if the object is a database
15091625
object_database = tag["OBJECT_DATABASE"]
15101626

1511-
domain = tag["DOMAIN"].lower()
1627+
raw_domain = tag["DOMAIN"]
1628+
if raw_domain is None:
1629+
logger.warning(
1630+
f"Skipping tag with null DOMAIN: "
1631+
f"tag={tag_name}, object_name={object_name}"
1632+
)
1633+
continue
1634+
domain = raw_domain.lower()
15121635
if domain == SnowflakeObjectDomain.DATABASE:
15131636
tags.add_database_tag(object_name, snowflake_tag)
15141637
elif domain == SnowflakeObjectDomain.SCHEMA:
1638+
if object_database is None:
1639+
logger.warning(
1640+
f"Skipping schema tag with null OBJECT_DATABASE: "
1641+
f"tag={snowflake_tag.name}, object_name={object_name}"
1642+
)
1643+
continue
15151644
tags.add_schema_tag(object_name, object_database, snowflake_tag)
15161645
elif domain == SnowflakeObjectDomain.TABLE: # including views
1646+
if object_schema is None or object_database is None:
1647+
logger.warning(
1648+
f"Skipping table tag with null OBJECT_SCHEMA/OBJECT_DATABASE: "
1649+
f"tag={snowflake_tag.name}, object_name={object_name}"
1650+
)
1651+
continue
15171652
tags.add_table_tag(
15181653
object_name, object_schema, object_database, snowflake_tag
15191654
)
15201655
elif domain == SnowflakeObjectDomain.COLUMN:
15211656
column_name = tag["COLUMN_NAME"]
1657+
if (
1658+
column_name is None
1659+
or object_schema is None
1660+
or object_database is None
1661+
):
1662+
logger.warning(
1663+
f"Skipping column tag with null fields: "
1664+
f"tag={snowflake_tag.name}, object_name={object_name}, "
1665+
f"column_name={column_name}"
1666+
)
1667+
continue
15221668
tags.add_column_tag(
15231669
column_name,
15241670
object_name,
@@ -1527,56 +1673,13 @@ def get_tags_for_database_without_propagation(
15271673
snowflake_tag,
15281674
)
15291675
else:
1530-
# This should never happen.
1531-
logger.error(f"Encountered an unexpected domain: {domain}")
1532-
continue
1533-
1534-
return tags
1535-
1536-
def get_tags_for_object_with_propagation(
1537-
self,
1538-
domain: str,
1539-
quoted_identifier: str,
1540-
db_name: str,
1541-
) -> List[SnowflakeTag]:
1542-
tags: List[SnowflakeTag] = []
1543-
1544-
cur = self.connection.query(
1545-
SnowflakeQuery.get_all_tags_on_object_with_propagation(
1546-
db_name, quoted_identifier, domain
1547-
),
1548-
)
1549-
1550-
for tag in cur:
1551-
tags.append(
1552-
SnowflakeTag(
1553-
database=tag["TAG_DATABASE"],
1554-
schema=tag["TAG_SCHEMA"],
1555-
name=tag["TAG_NAME"],
1556-
value=tag["TAG_VALUE"],
1676+
self.report.warning(
1677+
title="Unexpected tag domain encountered",
1678+
message=f"Tag '{snowflake_tag.name}' has domain '{domain}' which is not "
1679+
"recognized. This tag will be skipped.",
1680+
context=f"database={db_name}, object={object_name}",
15571681
)
1558-
)
1559-
return tags
1560-
1561-
def get_tags_on_columns_for_table(
1562-
self, quoted_table_name: str, db_name: str
1563-
) -> Dict[str, List[SnowflakeTag]]:
1564-
tags: Dict[str, List[SnowflakeTag]] = defaultdict(list)
1565-
cur = self.connection.query(
1566-
SnowflakeQuery.get_tags_on_columns_with_propagation(
1567-
db_name, quoted_table_name
1568-
),
1569-
)
1570-
1571-
for tag in cur:
1572-
column_name = tag["COLUMN_NAME"]
1573-
snowflake_tag = SnowflakeTag(
1574-
database=tag["TAG_DATABASE"],
1575-
schema=tag["TAG_SCHEMA"],
1576-
name=tag["TAG_NAME"],
1577-
value=tag["TAG_VALUE"],
1578-
)
1579-
tags[column_name].append(snowflake_tag)
1682+
continue
15801683

15811684
return tags
15821685

0 commit comments

Comments
 (0)