Skip to content

Commit 50c0a0f

Browse files
feat: add replace and replace_strict (#1327)
--------- Co-authored-by: Edoardo Abati <[email protected]>
1 parent 8efc161 commit 50c0a0f

File tree

17 files changed

+377
-40
lines changed

17 files changed

+377
-40
lines changed

docs/api-reference/expr.md

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
- over
3737
- pipe
3838
- quantile
39+
- replace_strict
3940
- round
4041
- sample
4142
- shift

docs/api-reference/series.md

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
- pipe
4545
- quantile
4646
- rename
47+
- replace_strict
4748
- round
4849
- sample
4950
- scatter

narwhals/_arrow/expr.py

+8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55
from typing import Callable
66
from typing import Literal
7+
from typing import Sequence
78

89
from narwhals._expression_parsing import reuse_series_implementation
910
from narwhals._expression_parsing import reuse_series_namespace_implementation
@@ -320,6 +321,13 @@ def is_last_distinct(self: Self) -> Self:
320321
def unique(self: Self, *, maintain_order: bool = False) -> Self:
    """Delegate ``unique`` to the series-level method of the same name.

    ``maintain_order`` is forwarded unchanged, as a keyword argument, through
    ``reuse_series_implementation``.
    """
    series_kwargs = {"maintain_order": maintain_order}
    return reuse_series_implementation(self, "unique", **series_kwargs)
322323

324+
def replace_strict(
    self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> Self:
    """Expression-level ``replace_strict``.

    ``old`` and ``new`` are forwarded positionally and ``return_dtype`` by
    keyword to the series method named ``replace_strict`` through
    ``reuse_series_implementation``.
    """
    positional = (old, new)
    return reuse_series_implementation(
        self, "replace_strict", *positional, return_dtype=return_dtype
    )
330+
323331
def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> Self:
324332
return reuse_series_implementation(
325333
self, "sort", descending=descending, nulls_last=nulls_last

narwhals/_arrow/namespace.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _create_expr_from_series(self, series: ArrowSeries) -> ArrowExpr:
6565
def _create_series_from_scalar(self, value: Any, series: ArrowSeries) -> ArrowSeries:
6666
from narwhals._arrow.series import ArrowSeries
6767

68-
if self._backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover
68+
if self._backend_version < (13,) and hasattr(value, "as_py"):
6969
value = value.as_py()
7070
return ArrowSeries._from_iterable(
7171
[value],

narwhals/_arrow/series.py

+20
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,26 @@ def unique(self: Self, *, maintain_order: bool = False) -> ArrowSeries:
655655

656656
return self._from_native_series(pc.unique(self._native_series))
657657

658+
def replace_strict(
    self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> ArrowSeries:
    """Swap each value found in ``old`` for the entry of ``new`` at the same position.

    The replaced values are cast to the native counterpart of ``return_dtype``.

    Raises:
        ValueError: when the result's null count differs from the input's
            null count; the message lists the distinct non-null inputs that
            ended up null.
    """
    import pyarrow as pa  # ignore-banned-import
    import pyarrow.compute as pc  # ignore-banned-import

    # https://stackoverflow.com/a/79111029/4451315
    # Step 1: position of every input value inside `old`.
    positions = pc.index_in(self._native_series, pa.array(old))
    # Step 2: pick the entry of `new` stored at that position.
    picked = pc.take(pa.array(new), positions)
    target_dtype = narwhals_to_native_dtype(return_dtype, self._dtypes)
    replaced = self._from_native_series(picked.cast(target_dtype))

    if replaced.is_null().sum() == self.is_null().sum():
        return replaced

    leftover = self.filter(~self.is_null() & replaced.is_null()).unique().to_list()
    msg = (
        "replace_strict did not replace all non-null values.\n\n"
        f"The following did not get replaced: {leftover}"
    )
    raise ValueError(msg)
677+
658678
def sort(
659679
self: Self, *, descending: bool = False, nulls_last: bool = False
660680
) -> ArrowSeries:

narwhals/_dask/expr.py

+7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Callable
77
from typing import Literal
88
from typing import NoReturn
9+
from typing import Sequence
910

1011
from narwhals._dask.utils import add_row_index
1112
from narwhals._dask.utils import maybe_evaluate
@@ -477,6 +478,12 @@ def head(self) -> NoReturn:
477478
msg = "`Expr.head` is not supported for the Dask backend. Please use `LazyFrame.head` instead."
478479
raise NotImplementedError(msg)
479480

481+
def replace_strict(
    self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> NoReturn:
    """Placeholder for ``replace_strict`` on the Dask backend.

    The parameters are accepted but never used. The return annotation is
    ``NoReturn`` because the method never returns, matching the other
    always-raising methods of this class (``head``, ``sort``).

    Raises:
        NotImplementedError: always.
    """
    msg = "`replace_strict` is not yet supported for Dask expressions"
    raise NotImplementedError(msg)
486+
480487
def sort(self, *, descending: bool = False, nulls_last: bool = False) -> NoReturn:
481488
# We can't (yet?) allow methods which modify the index
482489
msg = "`Expr.sort` is not supported for the Dask backend. Please use `LazyFrame.sort` instead."

narwhals/_pandas_like/dataframe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def with_columns(
429429
)
430430
else:
431431
# This is the logic in pandas' DataFrame.assign
432-
if self._backend_version < (2,): # pragma: no cover
432+
if self._backend_version < (2,):
433433
df = self._native_frame.copy(deep=True)
434434
else:
435435
df = self._native_frame.copy(deep=False)

narwhals/_pandas_like/expr.py

+9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55
from typing import Callable
66
from typing import Literal
7+
from typing import Sequence
78

89
from narwhals._expression_parsing import reuse_series_implementation
910
from narwhals._expression_parsing import reuse_series_namespace_implementation
@@ -14,6 +15,7 @@
1415

1516
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
1617
from narwhals._pandas_like.namespace import PandasLikeNamespace
18+
from narwhals.dtypes import DType
1719
from narwhals.typing import DTypes
1820
from narwhals.utils import Implementation
1921

@@ -271,6 +273,13 @@ def filter(self, *predicates: Any) -> Self:
271273
def drop_nulls(self) -> Self:
    """Delegate ``drop_nulls`` to the series-level method; no arguments are forwarded."""
    method_name = "drop_nulls"
    return reuse_series_implementation(self, method_name)
273275

276+
def replace_strict(
    self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> Self:
    """Expression-level ``replace_strict`` for pandas-like backends.

    Hands ``old``, ``new`` (positional) and ``return_dtype`` (keyword) to the
    series method ``replace_strict`` via ``reuse_series_implementation``.
    """
    forwarded_kwargs = {"return_dtype": return_dtype}
    return reuse_series_implementation(
        self, "replace_strict", old, new, **forwarded_kwargs
    )
282+
274283
def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self:
275284
return reuse_series_implementation(
276285
self, "sort", descending=descending, nulls_last=nulls_last

narwhals/_pandas_like/series.py

+30
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,36 @@ def diff(self) -> PandasLikeSeries:
491491
def shift(self, n: int) -> PandasLikeSeries:
    """Call the native series' ``shift`` with ``n`` and wrap what it returns."""
    shifted = self._native_series.shift(n)
    return self._from_native_series(shifted)
493493

494+
def replace_strict(
    self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> PandasLikeSeries:
    """Swap each value found in ``old`` for the entry of ``new`` at the same position.

    Implemented as a left merge between this series (as a one-column frame)
    and a two-column lookup frame built from ``old`` and ``new``; the ``new``
    column is created with the native counterpart of ``return_dtype``.

    Raises:
        ValueError: when the result's null count differs from the input's
            null count; the message lists the distinct non-null inputs that
            ended up null.
    """
    key_col = self.name
    tmp_name = f"{key_col}_tmp"
    native_dtype = narwhals_to_native_dtype(
        return_dtype,
        self._native_series.dtype,
        self._implementation,
        self._backend_version,
        self._dtypes,
    )
    ns = self.__native_namespace__()
    lookup = ns.DataFrame(
        {
            key_col: old,
            tmp_name: ns.Series(new, dtype=native_dtype),
        }
    )
    # NOTE(review): nothing here copies the input's index onto the merged
    # result - confirm whether callers rely on the original index.
    merged = self._native_series.to_frame().merge(lookup, on=key_col, how="left")
    replaced = self._from_native_series(merged[tmp_name].rename(key_col))

    if replaced.is_null().sum() == self.is_null().sum():
        return replaced

    leftover = self.filter(~self.is_null() & replaced.is_null()).unique().to_list()
    msg = (
        "replace_strict did not replace all non-null values.\n\n"
        f"The following did not get replaced: {leftover}"
    )
    raise ValueError(msg)
523+
494524
def sort(
495525
self, *, descending: bool = False, nulls_last: bool = False
496526
) -> PandasLikeSeries:

narwhals/_polars/dataframe.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ def func(*args: Any, **kwargs: Any) -> Any:
8484
return func
8585

8686
def __array__(self, dtype: Any | None = None, copy: bool | None = None) -> np.ndarray:
    """Return the native frame's ``__array__`` result.

    Args:
        dtype: Passed positionally to the native ``__array__``.
        copy: Forwarded to the native ``__array__`` on Polars >= 0.20.28.
            Must be ``None`` on older Polars.

    Raises:
        NotImplementedError: if ``copy`` is given on Polars < 0.20.28.
    """
    if self._backend_version < (0, 20, 28):
        if copy is not None:
            msg = "`copy` in `__array__` is only supported for Polars>=0.20.28"
            raise NotImplementedError(msg)
        return self._native_frame.__array__(dtype)
    # Both branches used to make the same call, so `copy` was validated above
    # but then silently dropped; forward it where the native method accepts it.
    return self._native_frame.__array__(dtype, copy=copy)
9393

9494
def collect_schema(self) -> dict[str, Any]:
95-
if self._backend_version < (1,): # pragma: no cover
95+
if self._backend_version < (1,):
9696
schema = self._native_frame.schema
9797
else:
9898
schema = dict(self._native_frame.collect_schema())
@@ -209,12 +209,12 @@ def group_by(self, *by: str, drop_null_keys: bool) -> Any:
209209
return PolarsGroupBy(self, list(by), drop_null_keys=drop_null_keys)
210210

211211
def with_row_index(self, name: str) -> Any:
    """Add a row-index column called ``name`` and wrap the resulting frame.

    Polars older than 0.20.4 is served through ``with_row_count``; newer
    versions use ``with_row_index``.
    """
    use_legacy_name = self._backend_version < (0, 20, 4)
    if use_legacy_name:
        indexed = self._native_frame.with_row_count(name)
    else:
        indexed = self._native_frame.with_row_index(name)
    return self._from_native_frame(indexed)
215215

216216
def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
217-
if self._backend_version < (1, 0, 0): # pragma: no cover
217+
if self._backend_version < (1, 0, 0):
218218
to_drop = parse_columns_to_drop(
219219
compliant_frame=self, columns=columns, strict=strict
220220
)
@@ -228,7 +228,7 @@ def unpivot(
228228
variable_name: str | None,
229229
value_name: str | None,
230230
) -> Self:
231-
if self._backend_version < (1, 0, 0): # pragma: no cover
231+
if self._backend_version < (1, 0, 0):
232232
return self._from_native_frame(
233233
self._native_frame.melt(
234234
id_vars=index,
@@ -296,7 +296,7 @@ def schema(self) -> dict[str, Any]:
296296
}
297297

298298
def collect_schema(self) -> dict[str, Any]:
299-
if self._backend_version < (1,): # pragma: no cover
299+
if self._backend_version < (1,):
300300
schema = self._native_frame.schema
301301
else:
302302
schema = dict(self._native_frame.collect_schema())
@@ -318,12 +318,12 @@ def group_by(self, *by: str, drop_null_keys: bool) -> Any:
318318
return PolarsLazyGroupBy(self, list(by), drop_null_keys=drop_null_keys)
319319

320320
def with_row_index(self, name: str) -> Any:
    """Add a row-index column called ``name`` to the native frame and wrap it.

    The native method is ``with_row_count`` for Polars < 0.20.4 and
    ``with_row_index`` otherwise.
    """
    method_name = (
        "with_row_count" if self._backend_version < (0, 20, 4) else "with_row_index"
    )
    add_index = getattr(self._native_frame, method_name)
    return self._from_native_frame(add_index(name))
324324

325325
def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
    """Drop ``columns`` from the native frame and wrap the result.

    ``strict`` is passed to the native ``drop`` only on Polars >= 1.0.0; for
    older versions the native call receives ``columns`` alone.
    """
    if self._backend_version >= (1, 0, 0):
        remaining = self._native_frame.drop(columns, strict=strict)
    else:
        remaining = self._native_frame.drop(columns)
    return self._from_native_frame(remaining)
329329

@@ -334,7 +334,7 @@ def unpivot(
334334
variable_name: str | None,
335335
value_name: str | None,
336336
) -> Self:
337-
if self._backend_version < (1, 0, 0): # pragma: no cover
337+
if self._backend_version < (1, 0, 0):
338338
return self._from_native_frame(
339339
self._native_frame.melt(
340340
id_vars=index,

narwhals/_polars/expr.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import TYPE_CHECKING
44
from typing import Any
5+
from typing import Sequence
56

67
from narwhals._polars.utils import extract_args_kwargs
78
from narwhals._polars.utils import extract_native
@@ -16,16 +17,21 @@
1617

1718

1819
class PolarsExpr:
19-
def __init__(
    self, expr: Any, dtypes: DTypes, backend_version: tuple[int, ...]
) -> None:
    """Store the wrapped native expression and its backend metadata.

    Args:
        expr: Kept as ``_native_expr``.
        dtypes: Kept as ``_dtypes``.
        backend_version: Kept as ``_backend_version``.
    """
    # The implementation tag is fixed for this class.
    self._implementation = Implementation.POLARS
    self._native_expr = expr
    self._dtypes = dtypes
    self._backend_version = backend_version
2327

2428
def __repr__(self) -> str:  # pragma: no cover
    """Return the fixed label ``"PolarsExpr"``."""
    label = "PolarsExpr"
    return label
2630

2731
def _from_native_expr(self, expr: Any) -> Self:
    """Build a new instance of this object's class around ``expr``.

    ``_dtypes`` and ``_backend_version`` are copied from the current instance.
    """
    cls = self.__class__
    return cls(expr, dtypes=self._dtypes, backend_version=self._backend_version)
2935

3036
def __getattr__(self, attr: str) -> Any:
3137
def func(*args: Any, **kwargs: Any) -> Any:
@@ -41,6 +47,18 @@ def cast(self, dtype: DType) -> Self:
4147
dtype = narwhals_to_native_dtype(dtype, self._dtypes)
4248
return self._from_native_expr(expr.cast(dtype))
4349

50+
def replace_strict(
    self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> Self:
    """Call the native expression's ``replace_strict`` and wrap the result.

    Args:
        old: Passed positionally to the native ``replace_strict``.
        new: Passed positionally to the native ``replace_strict``.
        return_dtype: Translated with ``narwhals_to_native_dtype`` and passed
            on as the native ``return_dtype``.

    Raises:
        NotImplementedError: if the backend version is below 1.0.
    """
    # Check the version before translating the dtype, so an unsupported Polars
    # is reported as such and not as a dtype-conversion failure.
    if self._backend_version < (1,):
        msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}"
        raise NotImplementedError(msg)
    # A separate local keeps the `return_dtype` parameter from being rebound
    # to a value of a different type.
    native_dtype = narwhals_to_native_dtype(return_dtype, self._dtypes)
    return self._from_native_expr(
        self._native_expr.replace_strict(old, new, return_dtype=native_dtype)
    )
61+
4462
def __eq__(self, other: object) -> Self:  # type: ignore[override]
    """Compare the native expression with ``other`` and wrap the result.

    ``other`` goes through ``extract_native`` before being handed to the
    native expression's ``__eq__``.
    """
    native_other = extract_native(other)
    compared = self._native_expr.__eq__(native_other)
    return self._from_native_expr(compared)
4664

0 commit comments

Comments
 (0)