Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3728,6 +3728,19 @@ def timestamp_add(unit: str, quantity: "ColumnOrName", ts: "ColumnOrName") -> Co
timestamp_add.__doc__ = pysparkfuncs.timestamp_add.__doc__


def time_bucket(
    bucket_size: "ColumnOrName",
    ts: "ColumnOrName",
    origin: Optional["ColumnOrName"] = None,
) -> Column:
    # Forward to the server-side time_bucket function, appending the optional
    # origin anchor only when the caller supplied one.
    cols = [bucket_size, ts] if origin is None else [bucket_size, ts, origin]
    return _invoke_function_over_columns("time_bucket", *cols)


time_bucket.__doc__ = pysparkfuncs.time_bucket.__doc__


def window(
timeColumn: "ColumnOrName",
windowDuration: str,
Expand Down
1 change: 1 addition & 0 deletions python/pyspark/sql/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@
"timestamp_micros",
"timestamp_millis",
"timestamp_seconds",
"time_bucket",
"time_diff",
"time_from_micros",
"time_from_millis",
Expand Down
68 changes: 68 additions & 0 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13125,6 +13125,74 @@ def timestamp_add(unit: str, quantity: "ColumnOrName", ts: "ColumnOrName") -> Co
)


@_try_remote_functions
def time_bucket(
    bucket_size: "ColumnOrName",
    ts: "ColumnOrName",
    origin: Optional["ColumnOrName"] = None,
) -> Column:
    """
    Aligns a timestamp to the start of a fixed-size interval bucket.

    Returns the start of the bucket that ``ts`` falls into, where buckets are defined by
    the given ``bucket_size`` interval aligned to ``origin``. All bucketing is performed on
    UTC micros; the session time zone does not affect bucket alignment. For local wall-clock
    alignment in a DST zone, cast the TIMESTAMP to TIMESTAMP_NTZ.

    .. versionadded:: 4.2.0

    Parameters
    ----------
    bucket_size : :class:`~pyspark.sql.Column` or column name
        A day-time or year-month interval defining the bucket size. Must be positive
        and foldable.
    ts : :class:`~pyspark.sql.Column` or column name
        A TIMESTAMP or TIMESTAMP_NTZ value to bucket.
    origin : :class:`~pyspark.sql.Column` or column name, optional
        Alignment anchor. Defaults to 1970-01-01 00:00:00 (UTC for TIMESTAMP). Must be
        the same type as ``ts`` and must be foldable.

    Returns
    -------
    :class:`~pyspark.sql.Column`
        The start of the bucket containing ``ts``, as the same type as ``ts``.

    Examples
    --------
    >>> spark.conf.set("spark.sql.session.timeZone", "UTC")
    >>> import datetime
    >>> from pyspark.sql import functions as sf
    >>> df = spark.createDataFrame(
    ...     [(datetime.datetime(2024, 1, 1, 11, 27, 0),)], ['ts'])
    >>> df.select(
    ...     sf.time_bucket(sf.expr("INTERVAL '15' MINUTE"), 'ts').alias("bucket")
    ... ).collect()
    [Row(bucket=datetime.datetime(2024, 1, 1, 11, 15))]

    Shift the grid with an explicit origin: buckets run at :05, :20, :35, :50:

    >>> df.select(
    ...     sf.time_bucket(
    ...         sf.expr("INTERVAL '15' MINUTE"),
    ...         'ts',
    ...         sf.expr("TIMESTAMP '1970-01-01 00:05:00'")
    ...     ).alias("bucket")
    ... ).collect()
    [Row(bucket=datetime.datetime(2024, 1, 1, 11, 20))]
    >>> spark.conf.unset("spark.sql.session.timeZone")
    """
    from pyspark.sql.classic.column import _to_java_column

    # Build the JVM column arguments, appending the optional origin anchor only
    # when the caller supplied one.
    jcols = [_to_java_column(bucket_size), _to_java_column(ts)]
    if origin is not None:
        jcols.append(_to_java_column(origin))
    return _invoke_function("time_bucket", *jcols)


@_try_remote_functions
def window(
timeColumn: "ColumnOrName",
Expand Down
20 changes: 20 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8486,6 +8486,26 @@ object functions {
def timestamp_add(unit: String, quantity: Column, ts: Column): Column =
Column.internalFn("timestampadd", lit(unit), quantity, ts)

/**
 * Returns the start of the fixed-size bucket of `bucketSize` that contains `ts`, with buckets
 * aligned to the epoch (1970-01-01 00:00:00). All computation is in UTC.
 *
 * @param bucketSize a day-time or year-month interval giving the bucket width; must be
 *                   positive and foldable
 * @param ts the TIMESTAMP or TIMESTAMP_NTZ value to bucket; the result has the same type
 * @group datetime_funcs
 * @since 4.2.0
 */
def time_bucket(bucketSize: Column, ts: Column): Column =
  Column.fn("time_bucket", bucketSize, ts)

/**
 * Returns the start of the fixed-size bucket of `bucketSize` that contains `ts`, with buckets
 * aligned to `origin`. All computation is in UTC.
 *
 * @param bucketSize a day-time or year-month interval giving the bucket width; must be
 *                   positive and foldable
 * @param ts the TIMESTAMP or TIMESTAMP_NTZ value to bucket; the result has the same type
 * @param origin foldable alignment anchor; must have the same type as `ts`
 * @group datetime_funcs
 * @since 4.2.0
 */
def time_bucket(bucketSize: Column, ts: Column, origin: Column): Column =
  Column.fn("time_bucket", bucketSize, ts, origin)

/**
* Returns the difference between two times, measured in specified units. Throws a
* SparkIllegalArgumentException, in case the specified unit is not supported.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ object FunctionRegistry {
expression[UnixMillis]("unix_millis"),
expression[UnixMicros]("unix_micros"),
expression[ConvertTimezone]("convert_timezone"),
expressionBuilder("time_bucket", TimeBucketExpressionBuilder),

// collection functions
expression[CreateArray]("array"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ import java.util.Locale

import org.apache.commons.text.StringEscapeUtils

import org.apache.spark.{SparkDateTimeException, SparkIllegalArgumentException}
import org.apache.spark.{SparkDateTimeException, SparkException, SparkIllegalArgumentException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry}
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.Cast.{ordinalNumber, toSQLExpr, toSQLId, toSQLType, toSQLValue}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
Expand Down Expand Up @@ -3897,3 +3899,178 @@ case class TimestampDiff(
copy(startTimestamp = newLeft, endTimestamp = newRight)
}
}

/**
 * Aligns a timestamp to the start of a fixed-size interval bucket.
 *
 * Returns the start of the half-open bucket [start, start + bucketSize) containing ts.
 * All computation is performed on UTC values.
 *
 * @param bucketSize positive, foldable day-time or year-month interval — the bucket width
 * @param ts TIMESTAMP or TIMESTAMP_NTZ value to bucket; the result has the same type
 * @param originTs foldable alignment anchor; must have the same type as `ts`
 */
case class TimeBucket(
    bucketSize: Expression,
    ts: Expression,
    originTs: Expression)
  extends TernaryExpression with ExpectsInputTypes {

  // A null in any input propagates to a null result (nullSafeEval is skipped).
  override def nullIntolerant: Boolean = true

  override def first: Expression = bucketSize
  override def second: Expression = ts
  override def third: Expression = originTs

  // Broad type shape; finer constraints (foldability, positivity, ts/origin type
  // equality) are enforced in checkInputDataTypes below.
  override def inputTypes: Seq[AbstractDataType] = Seq(
    TypeCollection(DayTimeIntervalType, YearMonthIntervalType),
    AnyTimestampType,
    AnyTimestampType)

  // The result keeps the timestamp flavor (TIMESTAMP vs TIMESTAMP_NTZ) of the input.
  override def dataType: DataType = ts.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    // Let ExpectsInputTypes validate the broad input types first.
    val defaultCheck = super.checkInputDataTypes()
    if (defaultCheck.isFailure) return defaultCheck

    // bucketSize must be a constant so its sign can be validated at analysis time.
    if (!bucketSize.foldable) {
      return DataTypeMismatch(
        errorSubClass = "NON_FOLDABLE_INPUT",
        messageParameters = Map(
          "inputName" -> toSQLId("bucketSize"),
          "inputType" -> toSQLType(bucketSize.dataType),
          "inputExpr" -> toSQLExpr(bucketSize)))
    }

    // Reject zero or negative bucket widths here; a null bucketSize is allowed and
    // simply yields a null result at runtime (nullIntolerant above).
    val bucketSizeValue = bucketSize.eval()
    if (bucketSizeValue != null) {
      val isNonPositive = bucketSize.dataType match {
        case _: DayTimeIntervalType => bucketSizeValue.asInstanceOf[Long] <= 0
        case _: YearMonthIntervalType => bucketSizeValue.asInstanceOf[Int] <= 0
        case other => throw SparkException.internalError(
          s"Unexpected bucketSize type: $other")
      }
      if (isNonPositive) {
        return DataTypeMismatch(
          errorSubClass = "VALUE_OUT_OF_RANGE",
          messageParameters = Map(
            "exprName" -> "time_bucket",
            "valueRange" -> "(0, inf)",
            "currentValue" -> toSQLValue(bucketSizeValue, bucketSize.dataType)))
      }
    }

    // The alignment anchor must also be a constant.
    if (!originTs.foldable) {
      return DataTypeMismatch(
        errorSubClass = "NON_FOLDABLE_INPUT",
        messageParameters = Map(
          "inputName" -> toSQLId("origin"),
          "inputType" -> toSQLType(originTs.dataType),
          "inputExpr" -> toSQLExpr(originTs)))
    }

    // ts and origin must be the exact same timestamp flavor; AnyTimestampType alone
    // would admit a TIMESTAMP ts with a TIMESTAMP_NTZ origin (or vice versa).
    if (ts.dataType != originTs.dataType) {
      return DataTypeMismatch(
        errorSubClass = "UNEXPECTED_INPUT_TYPE",
        messageParameters = Map(
          "paramIndex" -> ordinalNumber(2),
          "requiredType" -> toSQLType(ts.dataType),
          "inputSql" -> toSQLExpr(originTs),
          "inputType" -> toSQLType(originTs.dataType)))
    }

    TypeCheckSuccess
  }

  // Dispatch on the interval type: day-time intervals are fixed micros, year-month
  // intervals need calendar-aware month arithmetic.
  override def nullSafeEval(bucketSizeVal: Any, tsVal: Any, originVal: Any): Any = {
    first.dataType match {
      case _: DayTimeIntervalType =>
        DateTimeUtils.timeBucketDTInterval(
          bucketSizeVal.asInstanceOf[Long], tsVal.asInstanceOf[Long],
          originVal.asInstanceOf[Long])
      case _: YearMonthIntervalType =>
        DateTimeUtils.timeBucketYMInterval(
          bucketSizeVal.asInstanceOf[Int], tsVal.asInstanceOf[Long],
          originVal.asInstanceOf[Long])
      case other => throw SparkException.internalError(
        s"Unexpected bucketSize type: $other")
    }
  }

  // Codegen mirrors nullSafeEval: a single static call into DateTimeUtils.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
    first.dataType match {
      case _: DayTimeIntervalType =>
        defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
          s"$dtu.timeBucketDTInterval($bucketSizeCode, $tsCode, $originCode)")
      case _: YearMonthIntervalType =>
        defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
          s"$dtu.timeBucketYMInterval($bucketSizeCode, $tsCode, $originCode)")
      case other => throw SparkException.internalError(
        s"Unexpected bucketSize type: $other")
    }
  }

  override def prettyName: String = "time_bucket"

  override protected def withNewChildrenInternal(
      newFirst: Expression, newSecond: Expression, newThird: Expression): TimeBucket =
    copy(bucketSize = newFirst, ts = newSecond, originTs = newThird)
}

// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
    _FUNC_(bucketSize, ts[, origin]) - Returns the start of the bucket that `ts` falls into,
    where buckets are defined by the given `bucketSize` interval aligned to `origin`. All
    bucketing is performed on UTC micros, the session time zone does not affect bucket
    alignment. For local wall-clock alignment in a DST zone, cast the TIMESTAMP to
    TIMESTAMP_NTZ.
  """,
  arguments = """
    Arguments:
      * bucketSize - A day-time or year-month interval defining the bucket size. Must be positive and foldable.
      * ts - A TIMESTAMP or TIMESTAMP_NTZ value to bucket.
      * origin - Optional TIMESTAMP or TIMESTAMP_NTZ alignment anchor. Defaults to 1970-01-01 00:00:00 (UTC for TIMESTAMP). Must be the same type as ts and must be foldable.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(INTERVAL '15' MINUTE, TIMESTAMP '2024-01-01 11:27:00', TIMESTAMP '1970-01-01 00:00:00');
       2024-01-01 11:15:00
      > SELECT _FUNC_(INTERVAL '1' HOUR, TIMESTAMP '2024-01-01 11:27:00');
       2024-01-01 11:00:00
      > SELECT _FUNC_(INTERVAL '1' MONTH, TIMESTAMP '2024-07-20 14:30:00', TIMESTAMP '2024-06-15 09:00:00');
       2024-07-15 09:00:00
  """,
  since = "4.2.0",
  group = "datetime_funcs")
// scalastyle:on line.size.limit
object TimeBucketExpressionBuilder extends ExpressionBuilder {
  // Re-type an untyped NULL literal to the expected type so TimeBucket's strict
  // ts/origin type-equality check does not reject `time_bucket(..., NULL)`.
  private def retypeNull(e: Expression, dt: DataType): Expression = e match {
    case Literal(null, NullType) => Literal(null, dt)
    case _ => e
  }

  override def build(funcName: String, expressions: Seq[Expression]): Expression = {
    expressions match {
      case Seq(rawBucketSize, rawTs) =>
        val bucketSize = retypeNull(rawBucketSize, DayTimeIntervalType())
        // Fall back to TimestampType for bad ts types; ExpectsInputTypes will report it.
        val tsType = rawTs.dataType match {
          case t if AnyTimestampType.acceptsType(t) => t
          case _ => TimestampType
        }
        val ts = retypeNull(rawTs, tsType)
        // Default origin: micros-since-epoch 0, i.e. 1970-01-01 00:00:00, typed like ts.
        TimeBucket(bucketSize, ts, Literal(0L, tsType))
      case Seq(rawBucketSize, rawTs, rawOrigin) =>
        val bucketSize = retypeNull(rawBucketSize, DayTimeIntervalType())
        // If ts is an untyped NULL, borrow origin's timestamp type (or default to
        // TimestampType); otherwise ts's own type drives both operands.
        val tsType = (rawTs.dataType, rawOrigin.dataType) match {
          case (NullType, t) if AnyTimestampType.acceptsType(t) => t
          case (NullType, _) => TimestampType
          case (t, _) => t
        }
        val ts = retypeNull(rawTs, tsType)
        val originTs = retypeNull(rawOrigin, tsType)
        TimeBucket(bucketSize, ts, originTs)
      case _ =>
        throw QueryCompilationErrors.wrongNumArgsError(
          funcName, Seq(2, 3), expressions.length)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1059,4 +1059,59 @@ object DateTimeUtils extends SparkDateTimeUtils {
time, timePrecision, interval, intervalEndField)
}
}

/**
 * DayTimeInterval bucketing: microsecond floor division against `originMicros`.
 * Returns `originMicros + floorDiv(tsMicros - originMicros, bucketMicros) * bucketMicros`.
 *
 * `bucketMicros` must be positive; `TimeBucket.checkInputDataTypes` enforces
 * this at analysis time.
 *
 * @param bucketMicros bucket size in microseconds.
 * @param tsMicros timestamp to bucket, in microseconds since the epoch (UTC).
 * @param originMicros grid alignment anchor, in microseconds since the epoch (UTC).
 */
def timeBucketDTInterval(bucketMicros: Long, tsMicros: Long, originMicros: Long): Long = {
  // Exact arithmetic throughout so Long overflow surfaces as ArithmeticException
  // rather than silently wrapping.
  val elapsed = Math.subtractExact(tsMicros, originMicros)
  val wholeBuckets = Math.floorDiv(elapsed, bucketMicros)
  Math.addExact(originMicros, Math.multiplyExact(wholeBuckets, bucketMicros))
}

/**
 * YearMonthInterval bucketing: month arithmetic with end-of-month capping and step-back.
 * The origin's day-of-month and time-of-day determine the bucket boundaries.
 *
 * `bucketMonths` must be positive; `TimeBucket.checkInputDataTypes` enforces
 * this at analysis time.
 *
 * @param bucketMonths bucket size in months.
 * @param tsMicros timestamp to bucket, in microseconds since the epoch (UTC).
 * @param originMicros grid alignment anchor, in microseconds since the epoch (UTC).
 */
def timeBucketYMInterval(bucketMonths: Int, tsMicros: Long, originMicros: Long): Long = {
  val originDays = microsToDays(originMicros, ZoneOffset.UTC)
  // Time-of-day carried by the origin; re-applied at every bucket boundary.
  val originTodMicros =
    Math.subtractExact(originMicros, daysToMicros(originDays, ZoneOffset.UTC))

  // Start of the bucket that lies `k` whole buckets away from the origin.
  def bucketStart(k: Long): Long = {
    val days = dateAddMonths(originDays,
      Math.toIntExact(Math.multiplyExact(k, bucketMonths.toLong)))
    Math.addExact(daysToMicros(days, ZoneOffset.UTC), originTodMicros)
  }

  val tsDate = daysToLocalDate(microsToDays(tsMicros, ZoneOffset.UTC))
  val originDate = daysToLocalDate(originDays)
  val rawMonthDiff = (tsDate.getYear.toLong * 12 + tsDate.getMonthValue) -
    (originDate.getYear.toLong * 12 + originDate.getMonthValue)

  val candidate = bucketStart(Math.floorDiv(rawMonthDiff, bucketMonths.toLong))
  // End-of-month capping in dateAddMonths can overshoot ts; the previous bucket is
  // then guaranteed to start no later than ts, so one step back suffices.
  if (candidate > tsMicros) {
    bucketStart(Math.floorDiv(rawMonthDiff, bucketMonths.toLong) - 1)
  } else {
    candidate
  }
}
}
Loading