44from textwrap import dedent
55from typing import Iterator
66
7- from . import _cur , _introspect
7+ from . import _cur , _introspect , _partition
88from . import _psycopg as psycopg
99
1010
@@ -16,11 +16,13 @@ def __init__(
1616 cur : _cur .Cursor ,
1717 introspector : _introspect .Introspector ,
1818 schema : str ,
19+ partition_config : _partition .PartitionConfig | None = None ,
1920 ) -> None :
2021 self .conn = conn
2122 self .cur = cur
2223 self .introspector = introspector
2324 self .schema = schema
25+ self .partition_config = partition_config
2426
2527 def drop_constraint (self , * , table : str , constraint : str ) -> None :
2628 self .cur .execute (
@@ -49,6 +51,17 @@ def drop_table_if_exists(self, *, table: str) -> None:
4951 )
5052
5153 def create_copy_table (self , * , base_table : str , copy_table : str ) -> None :
54+ if self .partition_config :
55+ return self ._create_partitioned_table (
56+ base_table = base_table , copy_table = copy_table
57+ )
58+
59+ # Create a non-partitioned table (default behavior)
60+ self ._create_non_partitioned_table (base_table = base_table , copy_table = copy_table )
61+
62+ def _create_non_partitioned_table (
63+ self , * , base_table : str , copy_table : str
64+ ) -> None :
5265 self .cur .execute (
5366 psycopg .sql .SQL (
5467 dedent ("""
@@ -64,6 +77,185 @@ def create_copy_table(self, *, base_table: str, copy_table: str) -> None:
6477 .as_string (self .conn )
6578 )
6679
80+ def _create_partitioned_table (self , * , base_table : str , copy_table : str ) -> None :
81+ assert self .partition_config is not None
82+ assert isinstance (self .partition_config .strategy , _partition .DateRangeStrategy )
83+
84+ # Create the parent partitioned table
85+ self .cur .execute (
86+ psycopg .sql .SQL (
87+ dedent ("""
88+ CREATE TABLE {schema}.{copy_table}
89+ (LIKE {schema}.{table} INCLUDING DEFAULTS)
90+ PARTITION BY RANGE ({partition_column});
91+ """ )
92+ )
93+ .format (
94+ table = psycopg .sql .Identifier (base_table ),
95+ copy_table = psycopg .sql .Identifier (copy_table ),
96+ schema = psycopg .sql .Identifier (self .schema ),
97+ partition_column = psycopg .sql .Identifier (self .partition_config .column ),
98+ )
99+ .as_string (self .conn )
100+ )
101+
102+ # Create partitions ahead of time
103+ self ._create_partitions (base_table = base_table , copy_table = copy_table )
104+
105+ def _create_partitions (self , * , base_table : str , copy_table : str ) -> None :
106+ assert self .partition_config is not None
107+ strategy = self .partition_config .strategy
108+ assert isinstance (strategy , _partition .DateRangeStrategy )
109+
110+ num_of_extra_partitions = self .partition_config .num_of_extra_partitions_ahead
111+
112+ min_value = self .introspector .get_min_partition_date_value (
113+ table = base_table , column = self .partition_config .column
114+ )
115+ max_value = self .introspector .get_max_partition_date_value (
116+ table = base_table , column = self .partition_config .column
117+ )
118+ partition_start = self ._get_first_partition_start_date (
119+ min_value = min_value , strategy = strategy
120+ )
121+ partition_end = self ._get_last_partition_end_date (
122+ max_value = max_value ,
123+ strategy = strategy ,
124+ num_of_extra_partitions = num_of_extra_partitions ,
125+ )
126+
127+ # Create partitions from partition_start to partition_end
128+ current_partition_start = partition_start
129+
130+ while current_partition_start < partition_end :
131+ partition_suffix = self ._get_partition_suffix (
132+ current_partition_start = current_partition_start , strategy = strategy
133+ )
134+
135+ current_partition_end = self ._get_partition_end_boundary (
136+ current_partition_start = current_partition_start , strategy = strategy
137+ )
138+ self ._create_datetime_partition (
139+ base_table = base_table ,
140+ copy_table = copy_table ,
141+ partition_suffix = partition_suffix ,
142+ start = current_partition_start ,
143+ end = current_partition_end ,
144+ )
145+ current_partition_start = current_partition_end
146+
147+ def _get_first_partition_start_date (
148+ self , * , min_value : datetime .date , strategy : _partition .DateRangeStrategy
149+ ) -> datetime .date :
150+ """
151+ Align the minimum value to partition boundaries.
152+ For DAY: uses the exact min_value
153+ For MONTH: aligns to the first day of the month
154+ """
155+ if strategy .partition_by == _partition .PartitionInterval .DAY :
156+ return min_value
157+ elif strategy .partition_by == _partition .PartitionInterval .MONTH :
158+ # Align to start of month
159+ return min_value .replace (day = 1 )
160+ else :
161+ raise ValueError (f"Unsupported partition_by: { strategy .partition_by } " )
162+
163+ def _get_last_partition_end_date (
164+ self ,
165+ * ,
166+ max_value : datetime .date ,
167+ strategy : _partition .DateRangeStrategy ,
168+ num_of_extra_partitions : int ,
169+ ) -> datetime .date :
170+ """
171+ Calculate the end date for partitioning: max_value + num_of_extra_partitions.
172+ For DAY: adds the specified number of days
173+ For MONTH: adds the specified number of months
174+ """
175+ if strategy .partition_by == _partition .PartitionInterval .DAY :
176+ return max_value + datetime .timedelta (days = num_of_extra_partitions )
177+ elif strategy .partition_by == _partition .PartitionInterval .MONTH :
178+ # Add months by advancing to first of month and adding 32*months,
179+ # then normalising. This is because timedelta doesn't deal accept
180+ # "months" as argument.
181+ temp_date = max_value .replace (day = 1 )
182+ for _ in range (num_of_extra_partitions ):
183+ temp_date = (temp_date + datetime .timedelta (days = 32 )).replace (day = 1 )
184+ return temp_date
185+ else :
186+ raise ValueError (f"Unsupported partition_by: { strategy .partition_by } " )
187+
188+ def _get_partition_end_boundary (
189+ self ,
190+ * ,
191+ current_partition_start : datetime .date ,
192+ strategy : _partition .DateRangeStrategy ,
193+ ) -> datetime .date :
194+ """
195+ Calculate the end boundary for a single partition.
196+ For DAY: adds 1 day
197+ For MONTH: advances to the first day of the next month
198+ """
199+ if strategy .partition_by == _partition .PartitionInterval .DAY :
200+ return current_partition_start + datetime .timedelta (days = 1 )
201+ elif strategy .partition_by == _partition .PartitionInterval .MONTH :
202+ # Next month boundary
203+ return (current_partition_start + datetime .timedelta (days = 32 )).replace (
204+ day = 1
205+ )
206+ else :
207+ raise ValueError (f"Unsupported partition_by: { strategy .partition_by } " )
208+
209+ def _get_partition_suffix (
210+ self ,
211+ * ,
212+ current_partition_start : datetime .date ,
213+ strategy : _partition .DateRangeStrategy ,
214+ ) -> str :
215+ """
216+ Generate a date-based partition suffix.
217+ For DAY: returns p20250101 (YYYYMMDD format)
218+ For MONTH: returns p202501 (YYYYMM format)
219+ """
220+ if strategy .partition_by == _partition .PartitionInterval .DAY :
221+ # Format: p20250101 (YYYYMMDD)
222+ return f"p{ current_partition_start .strftime ('%Y%m%d' )} "
223+ elif strategy .partition_by == _partition .PartitionInterval .MONTH :
224+ # Format: p202501 (YYYYMM)
225+ return f"p{ current_partition_start .strftime ('%Y%m' )} "
226+ else :
227+ raise ValueError (f"Unsupported partition_by: { strategy .partition_by } " )
228+
229+ def _create_datetime_partition (
230+ self ,
231+ * ,
232+ base_table : str ,
233+ copy_table : str ,
234+ partition_suffix : str ,
235+ start : datetime .date ,
236+ end : datetime .date ,
237+ ) -> None :
238+ """Create a single datetime range partition."""
239+ self .cur .execute (
240+ psycopg .sql .SQL (
241+ dedent ("""
242+ CREATE TABLE {schema}.{partition_name}
243+ PARTITION OF {schema}.{copy_table}
244+ FOR VALUES FROM ({start}) TO ({end});
245+ """ )
246+ )
247+ .format (
248+ schema = psycopg .sql .Identifier (self .schema ),
249+ partition_name = psycopg .sql .Identifier (
250+ f"{ base_table } _{ partition_suffix } "
251+ ),
252+ copy_table = psycopg .sql .Identifier (copy_table ),
253+ start = psycopg .sql .Literal (start ),
254+ end = psycopg .sql .Literal (end ),
255+ )
256+ .as_string (self .conn )
257+ )
258+
67259 def drop_sequence_if_exists (self , * , seq : str ) -> None :
68260 self .cur .execute (
69261 psycopg .sql .SQL ("DROP SEQUENCE IF EXISTS {schema}.{seq};" )
@@ -109,17 +301,37 @@ def set_table_id_seq(self, *, table: str, seq: str, pk_column: str) -> None:
109301 )
110302
111303 def add_pk (self , * , table : str , pk_column : str ) -> None :
112- self .cur .execute (
113- psycopg .sql .SQL (
114- "ALTER TABLE {schema}.{table} ADD PRIMARY KEY ({pk_column});"
304+ # For partitioned tables, the PK must include all partitioning columns
305+ if self .partition_config :
306+ pk_columns = psycopg .sql .SQL (", " ).join (
307+ [
308+ psycopg .sql .Identifier (pk_column ),
309+ psycopg .sql .Identifier (self .partition_config .column ),
310+ ]
115311 )
116- .format (
117- table = psycopg .sql .Identifier (table ),
118- pk_column = psycopg .sql .Identifier (pk_column ),
119- schema = psycopg .sql .Identifier (self .schema ),
312+ self .cur .execute (
313+ psycopg .sql .SQL (
314+ "ALTER TABLE {schema}.{table} ADD PRIMARY KEY ({pk_columns});"
315+ )
316+ .format (
317+ table = psycopg .sql .Identifier (table ),
318+ pk_columns = pk_columns ,
319+ schema = psycopg .sql .Identifier (self .schema ),
320+ )
321+ .as_string (self .conn )
322+ )
323+ else :
324+ self .cur .execute (
325+ psycopg .sql .SQL (
326+ "ALTER TABLE {schema}.{table} ADD PRIMARY KEY ({pk_column});"
327+ )
328+ .format (
329+ table = psycopg .sql .Identifier (table ),
330+ pk_column = psycopg .sql .Identifier (pk_column ),
331+ schema = psycopg .sql .Identifier (self .schema ),
332+ )
333+ .as_string (self .conn )
120334 )
121- .as_string (self .conn )
122- )
123335
124336 def create_copy_function (
125337 self ,
@@ -511,12 +723,24 @@ def create_unique_constraint_using_idx(
511723 def create_not_valid_constraint_from_def (
512724 self , * , table : str , constraint : str , definition : str , is_validated : bool
513725 ) -> None :
726+ # For partitioned tables, we can't use NOT VALID on foreign keys
727+ # So we need to remove it from the definition
728+ is_fk = "FOREIGN KEY" in definition .upper ()
729+ if self .partition_config and is_fk and not is_validated :
730+ # Remove NOT VALID from the definition for partitioned tables
731+ definition = definition .replace (" NOT VALID" , "" ).replace ("NOT VALID" , "" )
732+
514733 add_constraint_sql = dedent ("""
515734 ALTER TABLE {schema}.{table}
516735 ADD CONSTRAINT {constraint}
517736 {definition}
518737 """ )
519- if is_validated :
738+ # Only add NOT VALID if:
739+ # 1. The constraint is validated (so we make it NOT VALID temporarily)
740+ # 2. AND it's not a FK on a partitioned table (which doesn't support NOT VALID)
741+ should_add_not_valid = is_validated and not (self .partition_config and is_fk )
742+
743+ if should_add_not_valid :
520744 # If the definition is for a valid constraint, alter it to be not
521745 # valid manually so that it can be created ONLINE.
522746 add_constraint_sql += " NOT VALID"
0 commit comments