Skip to content

Use sqlglot to create partition by expressions in BigQuery #2695

Open
@hsm207

Description

@hsm207

Feature description

Add support for generating BigQuery PARTITION BY expressions using sqlglot

Are you a dlt user?

Yes, I'm already a dlt user.

Use case

Improve type safety and prevent SQL injection attacks when generating partition expressions, by building them as sqlglot expression trees rather than as raw SQL strings.

Proposed solution

# Minimal working example: bigquery_partition using sqlglot
# NOTE: The PartitionTransformation class is now obsolete.
from typing import Literal, Protocol, List, TypeVar
from dataclasses import dataclass
from sqlglot import exp


######################################################################
# Partition Spec Classes (DB-dependent)
#
# NOTE: We use @dataclass for partition specs instead of TypedDict because:
#
# - Dataclasses work well with type checkers and IDE auto-complete, and can validate their fields in __post_init__.
# - They support default values, immutability (frozen=True), and methods.
######################################################################
@dataclass(frozen=True)
class BigQueryRangeBucketPartition:
    """Integer-range partition spec, rendered as ``RANGE_BUCKET(col, GENERATE_ARRAY(...))``.

    Fields are validated in ``__post_init__`` because type annotations are not
    enforced at runtime, and the numeric bounds are emitted as unquoted number
    literals in the generated SQL: a non-int value must never get that far.

    Raises:
        TypeError: if ``column_name`` is not a str, or a bound is not an int
            (``bool`` is rejected even though it subclasses ``int``).
        ValueError: if ``column_name`` is empty, ``interval`` is not positive,
            or ``start`` is not less than ``end``.
    """

    column_name: str  # column to bucket on
    start: int  # lower bound of the bucketed range
    end: int  # upper bound; must be greater than ``start``
    interval: int = 1  # bucket width; must be positive

    def __post_init__(self) -> None:
        if not isinstance(self.column_name, str):
            raise TypeError(
                f"column_name must be a str, got {type(self.column_name).__name__}"
            )
        if not self.column_name:
            raise ValueError("column_name must be a non-empty string")
        for field_name in ("start", "end", "interval"):
            value = getattr(self, field_name)
            # bool is an int subclass but is never a meaningful bucket bound.
            if isinstance(value, bool) or not isinstance(value, int):
                raise TypeError(
                    f"{field_name} must be an int, got {type(value).__name__}"
                )
        if self.interval <= 0:
            raise ValueError("interval must be a positive integer")
        if self.start >= self.end:
            raise ValueError("start must be less than end (exclusive)")


@dataclass(frozen=True)
class BigQueryDateTruncPartition:
    """Time-unit partition spec, rendered as ``DATE_TRUNC(col, granularity)``.

    The ``Literal`` annotation is not enforced at runtime, so the granularity
    and column name are checked in ``__post_init__``.

    Raises:
        TypeError: if ``column_name`` is not a str.
        ValueError: if ``column_name`` is empty or ``granularity`` is not
            exactly "MONTH" or "YEAR" (case-sensitive).
    """

    column_name: str  # column to truncate
    granularity: Literal["MONTH", "YEAR"]  # re-checked at runtime below

    def __post_init__(self) -> None:
        if not isinstance(self.column_name, str):
            raise TypeError(
                f"column_name must be a str, got {type(self.column_name).__name__}"
            )
        if not self.column_name:
            raise ValueError("column_name must be a non-empty string")
        allowed = ("MONTH", "YEAR")
        if self.granularity not in allowed:
            raise ValueError(f"granularity must be one of {allowed}, got {self.granularity!r}")


# BigQuery-specific union of supported partition specs; BigQueryPartitionRenderer
# dispatches on these two concrete types. The `X | Y` union of classes is
# evaluated at import time, so this line requires Python 3.10+ (PEP 604).
BigQueryPartitionSpec = BigQueryRangeBucketPartition | BigQueryDateTruncPartition

# --- Partition API using sqlglot ---


# --- Partition Renderer Protocol (Generic) ---
# NOTE: Generic and can be used for any dialect's partition spec.
T = TypeVar("T")


class PartitionRenderer(Protocol[T]):
    """Structural interface for dialect-specific PARTITION BY renderers.

    ``T`` is the dialect's partition spec type (for BigQuery, the union of
    its partition spec dataclasses).
    """

    @staticmethod
    def render_sql(partitions: List[T]) -> str:
        """Return the complete PARTITION BY clause for ``partitions``."""


# --- BigQuery Partition Renderer ---


class BigQueryPartitionRenderer(PartitionRenderer[BigQueryPartitionSpec]):
    """BigQuery partition expression generator and renderer using sqlglot.

    Every value placed in the output goes through a sqlglot node: column
    names become identifiers (quoted by sqlglot when needed), bounds become
    number literals, and granularities become bare keywords.
    """

    # Granularities that may be emitted as unquoted DATE_TRUNC date parts.
    _DATE_TRUNC_GRANULARITIES = ("MONTH", "YEAR")

    # Registry for dispatching partition spec types to renderers. The lambdas
    # defer lookup of the static methods, which are defined further down.
    _DISPATCH = {
        BigQueryRangeBucketPartition: (
            lambda partition: BigQueryPartitionRenderer._render_range_bucket_expr(partition)
        ),
        BigQueryDateTruncPartition: (
            lambda partition: BigQueryPartitionRenderer._render_date_trunc_expr(partition)
        ),
    }

    @staticmethod
    def render_sql(partitions: List[BigQueryPartitionSpec]) -> str:
        """Return the full PARTITION BY clause for BigQuery.

        e.g. ``PARTITION BY RANGE_BUCKET(...)`` or ``PARTITION BY DATE_TRUNC(...)``.

        Raises:
            ValueError: if ``partitions`` does not hold exactly one spec.
            NotImplementedError: if the spec's type has no registered renderer.
        """
        if len(partitions) != 1:
            raise ValueError("BigQuery only supports partitioning by a single column.")
        partition = partitions[0]
        # Exact-type lookup: subclasses of the spec classes are not matched.
        handler = BigQueryPartitionRenderer._DISPATCH.get(type(partition))
        if handler is None:
            raise NotImplementedError(f"Unknown partition type: {type(partition)}")
        return f"PARTITION BY {handler(partition)}"

    @staticmethod
    def _int_literal(name: str, value: int) -> exp.Literal:
        """Build a sqlglot number literal, rejecting anything but a plain int.

        Number literals are emitted unquoted, so this guards against a
        non-int value (e.g. a string) reaching the SQL text even if the spec
        object bypassed its own validation.
        """
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
        return exp.Literal.number(value)

    @staticmethod
    def _render_range_bucket_expr(partition: BigQueryRangeBucketPartition) -> str:
        """Render ``RANGE_BUCKET(col, GENERATE_ARRAY(start, end, interval))``."""
        literal = BigQueryPartitionRenderer._int_literal
        expr = exp.Anonymous(
            this="RANGE_BUCKET",
            expressions=[
                exp.to_identifier(partition.column_name),
                exp.Anonymous(
                    this="GENERATE_ARRAY",
                    expressions=[
                        literal("start", partition.start),
                        literal("end", partition.end),
                        literal("interval", partition.interval),
                    ],
                ),
            ],
        )
        return expr.sql(dialect="bigquery")

    @staticmethod
    def _render_date_trunc_expr(partition: BigQueryDateTruncPartition) -> str:
        """Render ``DATE_TRUNC(col, MONTH)`` / ``DATE_TRUNC(col, YEAR)``.

        BigQuery's partition DDL takes the granularity as a bare date-part
        keyword; a quoted string such as ``'MONTH'`` is rejected. So the
        granularity is emitted as an ``exp.Var`` rather than a string
        literal. Because a Var is written out as-is, it is re-checked against
        the allow-list first.

        NOTE(review): DATE_TRUNC applies to DATE columns; TIMESTAMP/DATETIME
        columns use TIMESTAMP_TRUNC/DATETIME_TRUNC in BigQuery. Confirm the
        column type expected here.
        """
        granularity = partition.granularity
        allowed = BigQueryPartitionRenderer._DATE_TRUNC_GRANULARITIES
        if granularity not in allowed:
            raise ValueError(f"granularity must be one of {allowed}, got {granularity!r}")
        expr = exp.Anonymous(
            this="DATE_TRUNC",
            expressions=[
                exp.to_identifier(partition.column_name),
                exp.Var(this=granularity),
            ],
        )
        return expr.sql(dialect="bigquery")


# --- Example usage ---


if __name__ == "__main__":
    # User creates a range bucket partition spec for BigQuery
    part1 = BigQueryRangeBucketPartition("user_id", 0, 100, 10)
    sql1 = BigQueryPartitionRenderer.render_sql([part1])
    expected_sql1 = "PARTITION BY RANGE_BUCKET(user_id, GENERATE_ARRAY(0, 100, 10))"
    assert sql1 == expected_sql1, f"SQL does not match expected: {sql1} != {expected_sql1}"

    # User creates a date_trunc partition spec for BigQuery. The granularity
    # is rendered as a bare keyword, as BigQuery's DDL requires.
    part2 = BigQueryDateTruncPartition("created_at", "MONTH")
    sql2 = BigQueryPartitionRenderer.render_sql([part2])
    expected_sql2 = "PARTITION BY DATE_TRUNC(created_at, MONTH)"
    assert sql2 == expected_sql2, f"SQL does not match expected: {sql2} != {expected_sql2}"

    # User tries to render more than 1 partition (must fail for BigQuery)
    try:
        BigQueryPartitionRenderer.render_sql([part1, part2])
    except ValueError as e:
        print(f"Caught expected error for multiple partitions: {e}")
    else:
        raise AssertionError("Expected ValueError for multiple partitions")

    # User tries to use an invalid granularity (must fail)
    try:
        BigQueryDateTruncPartition("created_at", "DAY")  # type: ignore[arg-type]
    except ValueError as e:
        print(f"Caught expected error for invalid granularity: {e}")
    else:
        raise AssertionError("Expected ValueError for invalid granularity")

Related issues

#2696

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    Status

    Todo

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions