|
22 | 22 | ) |
23 | 23 | from ibis.backends.base.sql.alchemy.registry import _gen_string_find |
24 | 24 | from ibis.backends.base.sql.alchemy.registry import _literal as base_literal |
| 25 | +from ibis.common.enums import DateUnit, IntervalUnit |
25 | 26 |
|
26 | 27 | operation_registry = sqlalchemy_operation_registry.copy() |
27 | 28 | operation_registry.update(sqlalchemy_window_functions_registry) |
@@ -93,18 +94,22 @@ def _extract_quarter(t, op): |
93 | 94 |
|
94 | 95 |
|
95 | 96 | _truncate_modifiers = { |
96 | | - 'Y': 'start of year', |
97 | | - 'M': 'start of month', |
98 | | - 'D': 'start of day', |
99 | | - 'W': 'weekday 1', |
| 97 | + DateUnit.DAY: 'start of day', |
| 98 | + DateUnit.WEEK: 'weekday 1', |
| 99 | + DateUnit.MONTH: 'start of month', |
| 100 | + DateUnit.YEAR: 'start of year', |
| 101 | + IntervalUnit.DAY: 'start of day', |
| 102 | + IntervalUnit.WEEK: 'weekday 1', |
| 103 | + IntervalUnit.MONTH: 'start of month', |
| 104 | + IntervalUnit.YEAR: 'start of year', |
100 | 105 | } |
101 | 106 |
|
102 | 107 |
|
103 | 108 | def _truncate(func): |
104 | 109 | def translator(t, op): |
105 | 110 | sa_arg = t.translate(op.arg) |
106 | 111 | try: |
107 | | - modifier = _truncate_modifiers[op.unit.short] |
| 112 | + modifier = _truncate_modifiers[op.unit] |
108 | 113 | except KeyError: |
109 | 114 | raise com.UnsupportedOperationError( |
110 | 115 | f'Unsupported truncate unit {op.unit!r}' |
@@ -208,6 +213,36 @@ def _arbitrary(t, op): |
208 | 213 | return reduction(getattr(sa.func, f"_ibis_sqlite_arbitrary_{how}"))(t, op) |
209 | 214 |
|
210 | 215 |
|
| 216 | +_INTERVAL_DATE_UNITS = frozenset( |
| 217 | + (IntervalUnit.YEAR, IntervalUnit.MONTH, IntervalUnit.DAY) |
| 218 | +) |
| 219 | + |
| 220 | + |
| 221 | +def _timestamp_op(func, sign, units): |
| 222 | + def _formatter(translator, op): |
| 223 | + arg, offset = op.args |
| 224 | + |
| 225 | + unit = offset.output_dtype.unit |
| 226 | + if unit not in units: |
| 227 | + raise com.UnsupportedOperationError( |
| 228 | + "SQLite does not allow binary operation " |
| 229 | + f"{func} with INTERVAL offset {unit}" |
| 230 | + ) |
| 231 | + offset = translator.translate(offset) |
| 232 | + result = getattr(sa.func, func)( |
| 233 | + translator.translate(arg), |
| 234 | + f"{sign}{offset.value} {unit.plural}", |
| 235 | + ) |
| 236 | + return result |
| 237 | + |
| 238 | + return _formatter |
| 239 | + |
| 240 | + |
| 241 | +def _date_diff(t, op): |
| 242 | + left, right = map(t.translate, op.args) |
| 243 | + return sa.func.julianday(left) - sa.func.julianday(right) |
| 244 | + |
| 245 | + |
211 | 246 | operation_registry.update( |
212 | 247 | { |
213 | 248 | # TODO(kszucs): don't dispatch on op.arg since that should be always an |
@@ -242,6 +277,9 @@ def _arbitrary(t, op): |
242 | 277 | ), |
243 | 278 | ops.DateTruncate: _truncate(sa.func.date), |
244 | 279 | ops.Date: unary(sa.func.date), |
| 280 | + ops.DateAdd: _timestamp_op("DATE", "+", _INTERVAL_DATE_UNITS), |
| 281 | + ops.DateSub: _timestamp_op("DATE", "-", _INTERVAL_DATE_UNITS), |
| 282 | + ops.DateDiff: _date_diff, |
245 | 283 | ops.Time: unary(sa.func.time), |
246 | 284 | ops.TimestampTruncate: _truncate(sa.func.datetime), |
247 | 285 | ops.Strftime: fixed_arity( |
|
0 commit comments