11import json
22import logging
33from collections import defaultdict
4- from dataclasses import dataclass , field
4+ from dataclasses import dataclass , field , replace
55from datetime import datetime
66from typing import (
77 Any ,
@@ -85,6 +85,15 @@ class SnowflakeTag:
8585 schema : str
8686 name : str
8787 value : str
88+ inherited_from : Optional [SnowflakeObjectDomain ] = None
89+
@property
def is_inherited(self) -> bool:
    """Report whether ``inherited_from`` is set on this tag.

    Returns:
        ``True`` when ``inherited_from`` holds a value, ``False`` when it
        is ``None``.
    """
    origin = self.inherited_from
    return origin is not None
93+
def as_inherited(self, level: SnowflakeObjectDomain) -> "SnowflakeTag":
    """Build a copy of this tag with ``inherited_from`` set to *level*.

    The receiver is not modified; ``dataclasses.replace`` constructs a new
    instance carrying over every other field.

    Args:
        level: The parent object domain the tag is considered inherited from.

    Returns:
        A new tag whose ``inherited_from`` equals *level*.
    """
    inherited_copy = replace(self, inherited_from=level)
    return inherited_copy
8897
8998 def tag_display_name (self ) -> str :
9099 return f"{ self .name } : { self .value } "
@@ -401,7 +410,7 @@ def add_database_tag(self, db_name: str, tag: SnowflakeTag) -> None:
401410 self ._database_tags [db_name ].append (tag )
402411
def get_database_tags(self, db_name: str) -> List[SnowflakeTag]:
    """Return the tags recorded directly on database *db_name*.

    Args:
        db_name: Database name used as the key in ``_database_tags``.

    Returns:
        The stored list for *db_name*, or a new empty list when the database
        has no entry.
    """
    tags_by_database = self._database_tags
    # Membership test instead of subscripting: a subscript on a
    # default-creating mapping would insert an entry for the missing key.
    if db_name in tags_by_database:
        return tags_by_database[db_name]
    return []
405414
406415 def add_schema_tag (self , schema_name : str , db_name : str , tag : SnowflakeTag ) -> None :
407416 self ._schema_tags [db_name ][schema_name ].append (tag )
@@ -417,7 +426,9 @@ def add_table_tag(
def get_table_tags(
    self, table_name: str, schema_name: str, db_name: str
) -> List[SnowflakeTag]:
    """Return the tags recorded directly on a table.

    Walks ``_table_tags`` by database, then schema, then table using
    non-inserting lookups, so a missing level yields an empty result rather
    than creating an entry.

    Args:
        table_name: Table key at the innermost level.
        schema_name: Schema key at the middle level.
        db_name: Database key at the outermost level.

    Returns:
        The stored list for the table, or a new empty list when any level is
        absent.
    """
    level = self._table_tags
    for key in (db_name, schema_name):
        level = level.get(key, {})
    return level.get(table_name, [])
421432
422433 def add_column_tag (
423434 self ,
@@ -436,6 +447,99 @@ def get_column_tags_for_table(
436447 self ._column_tags .get (db_name , {}).get (schema_name , {}).get (table_name , {})
437448 )
438449
450+ # --- Inheritance-aware methods ---
451+ # These emulate Snowflake's tag inheritance: a tag set on a database is
452+ # inherited by its schemas, tables, and columns. A more-specific
453+ # assignment (e.g. directly on the table) overrides the inherited value.
454+
455+ @staticmethod
456+ def _deduplicate_tags (tags : List [SnowflakeTag ]) -> List [SnowflakeTag ]:
457+ """Deduplicate tags by (database, schema, name), preferring direct over inherited."""
458+ best : Dict [tuple , SnowflakeTag ] = {}
459+ for tag in tags :
460+ key = (tag .database , tag .schema , tag .name )
461+ existing = best .get (key )
462+ if existing is None or (existing .is_inherited and not tag .is_inherited ):
463+ best [key ] = tag
464+ return list (best .values ())
465+
466+ @staticmethod
467+ def _mark_inherited (
468+ tags : List [SnowflakeTag ], level : SnowflakeObjectDomain
469+ ) -> List [SnowflakeTag ]:
470+ return [t .as_inherited (level ) for t in tags ]
471+
def get_schema_tags_with_inheritance(
    self, schema_name: str, db_name: str
) -> List[SnowflakeTag]:
    """Return a schema's own tags merged with tags from its database.

    Database tags are marked as inherited from the DATABASE domain and
    placed after the schema's direct tags before deduplication, so a direct
    schema tag wins over a database tag with the same identity.

    Args:
        schema_name: Schema whose tags are requested.
        db_name: Database that contains the schema.

    Returns:
        The deduplicated list of direct and inherited tags.
    """
    own_tags = self.get_schema_tags(schema_name, db_name)
    database_tags = self.get_database_tags(db_name)
    from_database = self._mark_inherited(
        database_tags, SnowflakeObjectDomain.DATABASE
    )
    candidates = own_tags + from_database
    return self._deduplicate_tags(candidates)
480+
def get_table_tags_with_inheritance(
    self, table_name: str, schema_name: str, db_name: str
) -> List[SnowflakeTag]:
    """Return a table's own tags merged with schema and database tags.

    Candidates are ordered direct, then schema-inherited, then
    database-inherited before deduplication, so for tags with the same
    identity the most specific level wins.

    Args:
        table_name: Table whose tags are requested.
        schema_name: Schema that contains the table.
        db_name: Database that contains the schema.

    Returns:
        The deduplicated list of direct and inherited tags.
    """
    own_tags = self.get_table_tags(table_name, schema_name, db_name)

    from_schema = self._mark_inherited(
        self.get_schema_tags(schema_name, db_name),
        SnowflakeObjectDomain.SCHEMA,
    )
    from_database = self._mark_inherited(
        self.get_database_tags(db_name),
        SnowflakeObjectDomain.DATABASE,
    )

    # Order matters: earlier entries take precedence during deduplication.
    candidates = own_tags + from_schema + from_database
    return self._deduplicate_tags(candidates)
493+
def get_column_tags_for_table_with_inheritance(
    self,
    table_name: str,
    schema_name: str,
    db_name: str,
    column_names: Optional[List[str]] = None,
) -> Dict[str, List[SnowflakeTag]]:
    """Return per-column tags merged with table, schema, and database tags.

    Args:
        table_name: Table whose columns are requested.
        schema_name: Schema that contains the table.
        db_name: Database that contains the schema.
        column_names: All column names in the table. When provided,
            inherited parent tags are applied to each of these columns,
            not only to columns that carry direct tags. When omitted, only
            columns with direct tags appear in the result.

    Returns:
        A mapping of column name to its deduplicated tags. When no parent
        tags exist, a shallow copy of the direct column-tag mapping is
        returned unchanged.
    """
    own_by_column = self.get_column_tags_for_table(table_name, schema_name, db_name)

    # Collect tags every column inherits, most specific level first so that
    # deduplication keeps the closest ancestor's value.
    from_ancestors = self._mark_inherited(
        self.get_table_tags(table_name, schema_name, db_name),
        SnowflakeObjectDomain.TABLE,
    )
    from_ancestors = from_ancestors + self._mark_inherited(
        self.get_schema_tags(schema_name, db_name),
        SnowflakeObjectDomain.SCHEMA,
    )
    from_ancestors = from_ancestors + self._mark_inherited(
        self.get_database_tags(db_name),
        SnowflakeObjectDomain.DATABASE,
    )

    # Nothing to inherit: hand back the direct tags as they are.
    if not from_ancestors:
        return dict(own_by_column)

    if column_names is None:
        target_columns = list(own_by_column.keys())
    else:
        # NOTE(review): on this path a column that has direct tags but is
        # missing from column_names is left out of the result, whereas the
        # early return above keeps it -- confirm this is intended.
        target_columns = column_names

    # Direct tags come first so they override inherited ones.
    return {
        column: self._deduplicate_tags(
            list(own_by_column.get(column, [])) + from_ancestors
        )
        for column in target_columns
    }
542+
439543
440544class SnowflakeDataDictionary (SupportsAsObj ):
441545 def __init__ (
@@ -1493,11 +1597,23 @@ def get_tags_for_database_without_propagation(
14931597 tags = _SnowflakeTagCache ()
14941598
14951599 for tag in cur :
1600+ tag_db = tag ["TAG_DATABASE" ]
1601+ tag_schema = tag ["TAG_SCHEMA" ]
1602+ tag_name = tag ["TAG_NAME" ]
1603+ tag_value = tag ["TAG_VALUE" ]
1604+ if tag_db is None or tag_schema is None or tag_name is None :
1605+ logger .warning (
1606+ f"Skipping tag with null definition fields: "
1607+ f"TAG_DATABASE={ tag_db } , TAG_SCHEMA={ tag_schema } , "
1608+ f"TAG_NAME={ tag_name } "
1609+ )
1610+ continue
1611+
14961612 snowflake_tag = SnowflakeTag (
1497- database = tag [ "TAG_DATABASE" ] ,
1498- schema = tag [ "TAG_SCHEMA" ] ,
1499- name = tag [ "TAG_NAME" ] ,
1500- value = tag [ "TAG_VALUE" ] ,
1613+ database = tag_db ,
1614+ schema = tag_schema ,
1615+ name = tag_name ,
1616+ value = tag_value or "" ,
15011617 )
15021618
15031619 # This is the name of the object, unless the object is a column, in which
@@ -1508,17 +1624,47 @@ def get_tags_for_database_without_propagation(
15081624 # This will be null if the object is a database
15091625 object_database = tag ["OBJECT_DATABASE" ]
15101626
1511- domain = tag ["DOMAIN" ].lower ()
1627+ raw_domain = tag ["DOMAIN" ]
1628+ if raw_domain is None :
1629+ logger .warning (
1630+ f"Skipping tag with null DOMAIN: "
1631+ f"tag={ tag_name } , object_name={ object_name } "
1632+ )
1633+ continue
1634+ domain = raw_domain .lower ()
15121635 if domain == SnowflakeObjectDomain .DATABASE :
15131636 tags .add_database_tag (object_name , snowflake_tag )
15141637 elif domain == SnowflakeObjectDomain .SCHEMA :
1638+ if object_database is None :
1639+ logger .warning (
1640+ f"Skipping schema tag with null OBJECT_DATABASE: "
1641+ f"tag={ snowflake_tag .name } , object_name={ object_name } "
1642+ )
1643+ continue
15151644 tags .add_schema_tag (object_name , object_database , snowflake_tag )
15161645 elif domain == SnowflakeObjectDomain .TABLE : # including views
1646+ if object_schema is None or object_database is None :
1647+ logger .warning (
1648+ f"Skipping table tag with null OBJECT_SCHEMA/OBJECT_DATABASE: "
1649+ f"tag={ snowflake_tag .name } , object_name={ object_name } "
1650+ )
1651+ continue
15171652 tags .add_table_tag (
15181653 object_name , object_schema , object_database , snowflake_tag
15191654 )
15201655 elif domain == SnowflakeObjectDomain .COLUMN :
15211656 column_name = tag ["COLUMN_NAME" ]
1657+ if (
1658+ column_name is None
1659+ or object_schema is None
1660+ or object_database is None
1661+ ):
1662+ logger .warning (
1663+ f"Skipping column tag with null fields: "
1664+ f"tag={ snowflake_tag .name } , object_name={ object_name } , "
1665+ f"column_name={ column_name } "
1666+ )
1667+ continue
15221668 tags .add_column_tag (
15231669 column_name ,
15241670 object_name ,
@@ -1527,56 +1673,13 @@ def get_tags_for_database_without_propagation(
15271673 snowflake_tag ,
15281674 )
15291675 else :
1530- # This should never happen.
1531- logger .error (f"Encountered an unexpected domain: { domain } " )
1532- continue
1533-
1534- return tags
1535-
1536- def get_tags_for_object_with_propagation (
1537- self ,
1538- domain : str ,
1539- quoted_identifier : str ,
1540- db_name : str ,
1541- ) -> List [SnowflakeTag ]:
1542- tags : List [SnowflakeTag ] = []
1543-
1544- cur = self .connection .query (
1545- SnowflakeQuery .get_all_tags_on_object_with_propagation (
1546- db_name , quoted_identifier , domain
1547- ),
1548- )
1549-
1550- for tag in cur :
1551- tags .append (
1552- SnowflakeTag (
1553- database = tag ["TAG_DATABASE" ],
1554- schema = tag ["TAG_SCHEMA" ],
1555- name = tag ["TAG_NAME" ],
1556- value = tag ["TAG_VALUE" ],
1676+ self .report .warning (
1677+ title = "Unexpected tag domain encountered" ,
1678+ message = f"Tag '{ snowflake_tag .name } ' has domain '{ domain } ' which is not "
1679+ "recognized. This tag will be skipped." ,
1680+ context = f"database={ db_name } , object={ object_name } " ,
15571681 )
1558- )
1559- return tags
1560-
1561- def get_tags_on_columns_for_table (
1562- self , quoted_table_name : str , db_name : str
1563- ) -> Dict [str , List [SnowflakeTag ]]:
1564- tags : Dict [str , List [SnowflakeTag ]] = defaultdict (list )
1565- cur = self .connection .query (
1566- SnowflakeQuery .get_tags_on_columns_with_propagation (
1567- db_name , quoted_table_name
1568- ),
1569- )
1570-
1571- for tag in cur :
1572- column_name = tag ["COLUMN_NAME" ]
1573- snowflake_tag = SnowflakeTag (
1574- database = tag ["TAG_DATABASE" ],
1575- schema = tag ["TAG_SCHEMA" ],
1576- name = tag ["TAG_NAME" ],
1577- value = tag ["TAG_VALUE" ],
1578- )
1579- tags [column_name ].append (snowflake_tag )
1682+ continue
15801683
15811684 return tags
15821685
0 commit comments