Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 65 additions & 6 deletions python/pyspark/pandas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4790,12 +4790,17 @@ def insert(
psdf = psdf[columns]
self._update_internal_frame(psdf._internal)

# TODO(SPARK-46160): add freq parameter
def shift(
    self,
    periods: int = 1,
    fill_value: Optional[Any] = None,
    axis: Axis = 0,
) -> "DataFrame":
    """
    Shift DataFrame by desired number of periods.

    .. note:: When axis=0, the current implementation of shift uses Spark's Window without
        specifying partition specification. This leads to moving all data into
        a single partition in a single machine and could cause serious
        performance degradation. Avoid this method with very large datasets.

    Parameters
    ----------
    periods : int
        Number of periods to shift. Can be positive or negative.
    fill_value : object, optional
        The scalar value to use for newly introduced missing values.
        The default depends on the dtype of self. For numeric data, np.nan is used.
    axis : {0 or 'index', 1 or 'columns'}, default 0
        Axis along which to shift:

        * 0 or 'index': shift each column independently (down/up rows)
        * 1 or 'columns': shift each row independently (across columns)

        .. versionchanged:: 4.2.0
            Added support for ``axis=1`` / ``'columns'``.

    Returns
    -------
    DataFrame
        Copy of the input DataFrame, shifted.

    Examples
    --------
    >>> df = ps.DataFrame({'Col1': [10, 20, 15, 30, 45],
    ...                    'Col2': [13, 23, 18, 33, 48],
    ...                    'Col3': [17, 27, 22, 37, 52]},
    ...                   columns=['Col1', 'Col2', 'Col3'])

    >>> df.shift(periods=3)
       Col1  Col2  Col3
    0   NaN   NaN   NaN
    1   NaN   NaN   NaN
    2   NaN   NaN   NaN
    3  10.0  13.0  17.0
    4  20.0  23.0  27.0

    >>> df.shift(periods=3, fill_value=0)
       Col1  Col2  Col3
    0     0     0     0
    1     0     0     0
    2     0     0     0
    3    10    13    17
    4    20    23    27

    Shift across columns with axis=1:

    >>> df = ps.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6], 'C': [7, 8, 9]},
    ...                   columns=['A', 'B', 'C'])
    >>> df.shift(periods=1, axis=1).sort_index()
         A    B    C
    0  NaN  1.0  4.0
    1  NaN  2.0  5.0
    2  NaN  3.0  6.0
    """
    if not isinstance(periods, int):
        raise TypeError("periods should be an int; however, got [%s]" % type(periods).__name__)

    axis = validate_axis(axis)

    if axis == 0:
        # Row-wise shift: delegate to the per-column, window-based implementation.
        return self._apply_series_op(
            lambda psser: psser._shift(periods, fill_value), should_resolve=True
        )
    else:
        # Column-wise shift. Infer the result schema from a small sample,
        # following the same pattern as apply() (shortcut_limit).
        limit = get_option("compute.shortcut_limit")
        pdf = self.head(limit + 1)._to_internal_pandas()
        pdf_shifted = pdf.shift(periods=periods, fill_value=fill_value, axis=1)
        if len(pdf) <= limit:
            # The whole frame fit in the sample, so the pandas result is exact.
            return DataFrame(InternalFrame.from_pandas(pdf_shifted))

        # Use the shifted sample to infer return types so that the UDF
        # path produces consistent dtypes with the fast path.
        psdf_shifted = DataFrame(InternalFrame.from_pandas(pdf_shifted))
        data_fields = [
            field.normalize_spark_type() for field in psdf_shifted._internal.data_fields
        ]
        return_schema = StructType([field.struct_field for field in data_fields])

        column_label_strings = [
            name_like_string(label) for label in self._internal.column_labels
        ]

        @pandas_udf(returnType=return_schema)  # type: ignore[call-overload]
        def shift_axis_1(*cols: pd.Series) -> pd.DataFrame:
            # Rebuild a pandas frame for the batch and shift across its columns.
            pdf_row = pd.concat(cols, axis=1, keys=column_label_strings)
            return pdf_row.shift(periods=periods, fill_value=fill_value, axis=1)

        shifted_struct_col = shift_axis_1(*self._internal.data_spark_columns)
        new_data_columns = [
            shifted_struct_col[col_name].alias(col_name) for col_name in column_label_strings
        ]
        internal = self._internal.with_new_columns(new_data_columns, data_fields=data_fields)
        return DataFrame(internal)

# TODO(SPARK-46161): axis should support 1 or 'columns' either at this moment
def diff(self, periods: int = 1, axis: Axis = 0) -> "DataFrame":
Expand Down
93 changes: 93 additions & 0 deletions python/pyspark/pandas/tests/frame/test_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,99 @@ def test_shift(self):
self.assert_eq(pdf.shift().shift(-1), psdf.shift().shift(-1))
self.assert_eq(pdf.shift(0), psdf.shift(0))

def test_shift_axis(self):
    # SPARK-46160: DataFrame.shift with the ``axis`` parameter.
    pdf = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    psdf = ps.from_pandas(pdf)

    # axis=0 (explicit) matches the default row-wise behavior; axis=1 shifts
    # across columns; the string aliases behave like their numeric forms.
    for axis in [0, 1, "index", "columns"]:
        self.assert_eq(
            pdf.shift(axis=axis).sort_index(), psdf.shift(axis=axis).sort_index()
        )

    # Several shift distances along columns, including zero and negatives.
    for periods in [1, -1, 2, -2, 0]:
        self.assert_eq(
            pdf.shift(periods=periods, axis=1).sort_index(),
            psdf.shift(periods=periods, axis=1).sort_index(),
        )

    # fill_value replaces the newly introduced missing values.
    self.assert_eq(
        pdf.shift(periods=1, fill_value=0, axis=1).sort_index(),
        psdf.shift(periods=1, fill_value=0, axis=1).sort_index(),
    )

    # A one-column frame: every value is shifted out.
    pdf_one = pd.DataFrame({"A": [1, 2, 3]})
    self.assert_eq(
        pdf_one.shift(axis=1).sort_index(),
        ps.from_pandas(pdf_one).shift(axis=1).sort_index(),
    )

    # Pre-existing NaNs travel along with the shift.
    pdf_nan = pd.DataFrame({"A": [1, np.nan, 3], "B": [4, 3, np.nan]})
    self.assert_eq(
        pdf_nan.shift(axis=1).sort_index(),
        ps.from_pandas(pdf_nan).shift(axis=1).sort_index(),
    )

    # MultiIndex column labels are preserved through the shift.
    midx = pd.MultiIndex.from_tuples([("x", "A"), ("x", "B"), ("y", "C")])
    pdf.columns = midx
    psdf.columns = midx
    self.assert_eq(pdf.shift(axis=1).sort_index(), psdf.shift(axis=1).sort_index())

    # >1000 rows exceeds compute.shortcut_limit, exercising the UDF path.
    rng = np.random.RandomState(42)
    pdf_big = pd.DataFrame(
        {"A": rng.rand(1500), "B": rng.rand(1500), "C": rng.rand(1500)}
    )
    psdf_big = ps.from_pandas(pdf_big)
    self.assert_eq(
        pdf_big.shift(axis=1).sort_index(),
        psdf_big.shift(axis=1).sort_index(),
    )

    # fill_value on the UDF path as well.
    self.assert_eq(
        pdf_big.shift(periods=1, fill_value=0, axis=1).sort_index(),
        psdf_big.shift(periods=1, fill_value=0, axis=1).sort_index(),
    )

    # Shifting further than the column count yields all-missing values.
    self.assert_eq(
        pdf.shift(periods=5, axis=1).sort_index(),
        psdf.shift(periods=5, axis=1).sort_index(),
    )

    # Integer columns mixed with float columns.
    pdf_mixed = pd.DataFrame({"A": [1, 2, 3], "B": [4.0, 5.0, 6.0], "C": [7, 8, 9]})
    self.assert_eq(
        pdf_mixed.shift(axis=1).sort_index(),
        ps.from_pandas(pdf_mixed).shift(axis=1).sort_index(),
    )

    # A frame with no rows still round-trips.
    pdf_empty = pd.DataFrame({"A": pd.Series([], dtype="float64")})
    self.assert_eq(
        pdf_empty.shift(axis=1).sort_index(),
        ps.from_pandas(pdf_empty).shift(axis=1).sort_index(),
    )

    # Anything other than 0/1 (or their string names) is rejected.
    with self.assertRaisesRegex(ValueError, "No axis named"):
        psdf.shift(axis=2)

def test_first_valid_index(self):
pdf = pd.DataFrame(
{"a": [None, 2, 3, 2], "b": [None, 2.0, 3.0, 1.0], "c": [None, 200, 400, 200]},
Expand Down
12 changes: 6 additions & 6 deletions python/pyspark/sql/connect/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:

@property
def observations(self) -> Dict[str, "Observation"]:
return dict(**super().observations, **self.right.observations)
return {**super().observations, **self.right.observations}

def print(self, indent: int = 0) -> str:
i = " " * indent
Expand Down Expand Up @@ -1213,7 +1213,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:

@property
def observations(self) -> Dict[str, "Observation"]:
return dict(**super().observations, **self.right.observations)
return {**super().observations, **self.right.observations}

def print(self, indent: int = 0) -> str:
assert self.left is not None
Expand Down Expand Up @@ -1288,7 +1288,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:

@property
def observations(self) -> Dict[str, "Observation"]:
return dict(**super().observations, **self.right.observations)
return {**super().observations, **self.right.observations}

def print(self, indent: int = 0) -> str:
i = " " * indent
Expand Down Expand Up @@ -1354,10 +1354,10 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:

@property
def observations(self) -> Dict[str, "Observation"]:
return dict(
return {
**super().observations,
**(self.other.observations if self.other is not None else {}),
)
}

def print(self, indent: int = 0) -> str:
assert self._child is not None
Expand Down Expand Up @@ -1664,7 +1664,7 @@ def observations(self) -> Dict[str, "Observation"]:
observations = {str(self._observation._name): self._observation}
else:
observations = {}
return dict(**super().observations, **observations)
return {**super().observations, **observations}


class NAFill(LogicalPlan):
Expand Down
95 changes: 94 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,21 @@
)
from pyspark.errors import PySparkValueError

from unittest.mock import MagicMock

if should_test_connect:
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import WriteOperation, Read
from pyspark.sql.connect.plan import (
WriteOperation,
Read,
Join,
SetOperation,
CollectMetrics,
LogicalPlan,
)
from pyspark.sql.connect.observation import Observation
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.expressions import LiteralExpression
from pyspark.sql.connect.functions import col, lit, max, min, sum
Expand Down Expand Up @@ -1131,6 +1141,89 @@ def test_literal_to_any_conversion(self):
LiteralExpression._to_value(proto_lit, DoubleType)


if should_test_connect:

    class _StubPlan(LogicalPlan):
        """A do-nothing LogicalPlan whose ``observations`` is a fixed, injected dict."""

        def __init__(self, observations=None):
            super().__init__(None)
            # Fall back to a fresh empty mapping when nothing was supplied.
            self._canned = observations or {}

        def plan(self, session):
            # A stub is never serialized into a proto relation.
            raise NotImplementedError

        def print(self, indent=0):
            return ""

        @property
        def observations(self):
            return self._canned


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class TestObservationMerging(unittest.TestCase):
    """Verify that observations are deduplicated when plan branches share the same key."""

    @staticmethod
    def _bare_join(left, right):
        # Build a Join without invoking __init__ (which needs a client session).
        join = Join.__new__(Join)
        join._child = left
        join.right = right
        return join

    def test_join_with_duplicate_observation_names(self):
        observation = MagicMock()
        observation._name = "shared"
        shared_map = {"shared": observation}

        join = self._bare_join(
            _StubPlan(observations=shared_map), _StubPlan(observations=shared_map)
        )
        # Duplicate keys from both branches collapse to a single entry.
        self.assertEqual(join.observations, {"shared": observation})

    def test_join_with_distinct_observations(self):
        obs_left = MagicMock()
        obs_left._name = "a"
        obs_right = MagicMock()
        obs_right._name = "b"

        join = self._bare_join(
            _StubPlan(observations={"a": obs_left}),
            _StubPlan(observations={"b": obs_right}),
        )
        # Distinct keys from both branches are all kept.
        self.assertEqual(join.observations, {"a": obs_left, "b": obs_right})

    def test_set_operation_with_duplicate_observation_names(self):
        observation = MagicMock()
        observation._name = "shared"
        shared_map = {"shared": observation}

        set_op = SetOperation.__new__(SetOperation)
        set_op._child = _StubPlan(observations=shared_map)
        set_op.other = _StubPlan(observations=shared_map)

        self.assertEqual(set_op.observations, {"shared": observation})

    def test_collect_metrics_with_duplicate_observation_name(self):
        observation = Observation("my_metric")

        cm = CollectMetrics.__new__(CollectMetrics)
        cm._child = _StubPlan(observations={"my_metric": observation})
        cm._observation = observation
        cm._exprs = []

        # The metric collected here and the one inherited from the child merge.
        self.assertEqual(cm.observations, {"my_metric": observation})


if __name__ == "__main__":
from pyspark.testing import main

Expand Down
Loading