@@ -11074,16 +11074,38 @@ def replace(
1107411074 raise TypeError (msg )
1107511075 new = list (old .values ())
1107611076 old = list (old .keys ())
11077+
11078+ old_is_seq = isinstance (old , Sequence ) and not isinstance (old , (str , pl .Series ))
11079+ new_is_seq = isinstance (new , Sequence ) and not isinstance (new , (str , pl .Series ))
11080+ has_expr_in_old = old_is_seq and any (isinstance (v , Expr ) for v in old ) # type: ignore[union-attr]
11081+ has_expr_in_new = new_is_seq and any (isinstance (v , Expr ) for v in new ) # type: ignore[union-attr]
11082+
11083+ if has_expr_in_old or has_expr_in_new :
11084+ old_list = list (old ) if old_is_seq else [old ] # type: ignore[arg-type, misc]
11085+ new_list = list (new ) if new_is_seq else [new ] # type: ignore[arg-type]
11086+
11087+ if len (new_list ) == 1 and len (old_list ) > 1 :
11088+ new_list = new_list * len (old_list )
11089+
11090+ if len (old_list ) != len (new_list ):
11091+ msg = f"lengths of `old` ({ len (old_list )} ) and `new` ({ len (new_list )} ) must match"
11092+ raise ValueError (msg )
11093+
11094+ # when(False).then(self) preserves the column name
11095+ result : Expr = F .when (F .lit (False )).then (self )
11096+ for old_val , new_val in zip (old_list , new_list ):
11097+ result = result .when (self == old_val ).then (new_val ) # type: ignore[attr-defined]
11098+ result = result .otherwise (self ) # type: ignore[attr-defined]
1107711099 else :
11078- if isinstance ( old , Sequence ) and not isinstance ( old , ( str , pl . Series )) :
11100+ if old_is_seq :
1107911101 old = pl .Series (old )
11080- if isinstance ( new , Sequence ) and not isinstance ( new , ( str , pl . Series )) :
11102+ if new_is_seq :
1108111103 new = pl .Series (new )
1108211104
11083- old_pyexpr = parse_into_expression (old , str_as_lit = True ) # type: ignore[arg-type]
11084- new_pyexpr = parse_into_expression (new , str_as_lit = True )
11105+ old_pyexpr = parse_into_expression (old , str_as_lit = True ) # type: ignore[arg-type]
11106+ new_pyexpr = parse_into_expression (new , str_as_lit = True ) # type: ignore[arg-type]
1108511107
11086- result = wrap_expr (self ._pyexpr .replace (old_pyexpr , new_pyexpr ))
11108+ result = wrap_expr (self ._pyexpr .replace (old_pyexpr , new_pyexpr ))
1108711109
1108811110 if return_dtype is not None :
1108911111 result = result .cast (return_dtype )
@@ -11271,6 +11293,37 @@ def replace_strict(
1127111293 new = list (old .values ())
1127211294 old = list (old .keys ())
1127311295
11296+ old_is_seq = isinstance (old , Sequence ) and not isinstance (old , (str , pl .Series ))
11297+ new_is_seq = isinstance (new , Sequence ) and not isinstance (new , (str , pl .Series ))
11298+ has_expr_in_old = old_is_seq and any (isinstance (v , Expr ) for v in old ) # type: ignore[union-attr]
11299+ has_expr_in_new = new_is_seq and any (isinstance (v , Expr ) for v in new ) # type: ignore[union-attr]
11300+
11301+ if has_expr_in_old or has_expr_in_new :
11302+ old_list = list (old ) if old_is_seq else [old ] # type: ignore[arg-type, misc]
11303+ new_list = list (new ) if new_is_seq else [new ] # type: ignore[arg-type]
11304+
11305+ if len (new_list ) == 1 and len (old_list ) > 1 :
11306+ new_list = new_list * len (old_list )
11307+
11308+ if len (old_list ) != len (new_list ):
11309+ msg = f"lengths of `old` ({ len (old_list )} ) and `new` ({ len (new_list )} ) must match"
11310+ raise ValueError (msg )
11311+
11312+ # when(False).then(self) preserves the column name
11313+ result : Expr = F .when (F .lit (False )).then (self )
11314+ for old_val , new_val in zip (old_list , new_list ):
11315+ result = result .when (self == old_val ).then (new_val ) # type: ignore[attr-defined]
11316+
11317+ if default is no_default :
11318+ result = result .otherwise (None ) # type: ignore[attr-defined]
11319+ else :
11320+ result = result .otherwise (default ) # type: ignore[attr-defined]
11321+
11322+ if return_dtype is not None :
11323+ result = result .cast (return_dtype )
11324+
11325+ return result
11326+
1127411327 old_pyexpr = parse_into_expression (old , str_as_lit = True ) # type: ignore[arg-type]
1127511328 new_pyexpr = parse_into_expression (new , str_as_lit = True ) # type: ignore[arg-type]
1127611329
0 commit comments